Fixing some types and paths.

This commit is contained in:
RD 2024-10-19 20:11:27 -07:00
parent 14c35c061f
commit 68061a3317
11 changed files with 60 additions and 63 deletions

View file

@ -32,7 +32,7 @@
<excludeFolder url="file://$MODULE_DIR$/js/cf-webapp/node_modules" />
<excludeFolder url="file://$MODULE_DIR$/js/common/node_modules" />
</content>
<orderEntry type="jdk" jdkName="$USER_HOME$/miniforge3/envs/codeflash311" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="codeflash312" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
<orderEntry type="module" module-name="langchain" />
</component>

View file

@ -1,7 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="PROD postgres@codeflash-pgsql-db-prod.postgres.database.azure.com" uuid="b1eb5197-a44e-453c-bc29-e38a66e6bdad">
<data-source source="LOCAL" name="postgres@codeflash-pgsql-db-prod.postgres.database.azure.com" uuid="b1eb5197-a44e-453c-bc29-e38a66e6bdad">
<driver-ref>postgresql</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>org.postgresql.Driver</jdbc-driver>

View file

@ -19,6 +19,5 @@
</option>
</inspection_tool>
<inspection_tool class="PyTypeCheckerInspection" enabled="false" level="WARNING" enabled_by_default="false" />
<inspection_tool class="RuffInspection" enabled="true" level="SERVER PROBLEM" enabled_by_default="true" editorAttributes="RUNTIME_ERROR" />
</profile>
</component>

View file

@ -2,8 +2,8 @@
<project version="4">
<component name="RuffConfigService">
<option name="alwaysUseGlobalRuff" value="true" />
<option name="projectRuffExecutablePath" value="$USER_HOME$/miniforge3/envs/codeflash311/bin/ruff" />
<option name="projectRuffLspExecutablePath" value="$USER_HOME$/miniforge3/envs/codeflash311/bin/ruff-lsp" />
<option name="projectRuffExecutablePath" value="$USER_HOME$/miniforge3/envs/codeflash312/bin/ruff" />
<option name="projectRuffLspExecutablePath" value="$USER_HOME$/miniforge3/envs/codeflash312/bin/ruff-lsp" />
<option name="ruffConfigPath" value="$PROJECT_DIR$/django/aiservice/pyproject.toml" />
<option name="runRuffOnSave" value="true" />
<option name="useRuffFormat" value="true" />

View file

@ -9,7 +9,7 @@
<env name="DJANGO_SETTINGS_MODULE" value="aiservice.settings" />
</envs>
<option name="SDK_HOME" value="$USER_HOME$/mambaforge/envs/aiservice/bin/python" />
<option name="SDK_NAME" value="$USER_HOME$/miniforge3/envs/aiservice" />
<option name="SDK_NAME" value="aiservice312" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/django/aiservice" />
<option name="IS_MODULE_SDK" value="false" />
<option name="ADD_CONTENT_ROOTS" value="true" />

View file

@ -8,7 +8,7 @@
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="SDK_NAME" value="$USER_HOME$/miniforge3/envs/codeflash311" />
<option name="SDK_NAME" value="codeflash312" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/cli" />
<option name="IS_MODULE_SDK" value="false" />
<option name="ADD_CONTENT_ROOTS" value="true" />
@ -26,7 +26,7 @@
</ENTRIES>
</EXTENSION>
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/cli/codeflash/main.py" />
<option name="PARAMETERS" value="--file code_to_optimize/bubble_sort.py --function sorter --test-framework unittest --tests-root code_to_optimize/tests/unittest --use-cached-tests" />
<option name="PARAMETERS" value="--verbose --file code_to_optimize/bubble_sort.py --function sorter --test-framework unittest --tests-root code_to_optimize/tests/unittest --module-root $PROJECT_DIR$/cli" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />

View file

@ -147,7 +147,7 @@ def is_github_app_installed_on_repo(owner: str, repo: str) -> bool:
return True
def get_blacklisted_functions() -> dict[str, str]:
def get_blocklisted_functions() -> dict[str, str]:
pr_number = get_pr_number()
if pr_number is None:
return {}
@ -166,6 +166,6 @@ def get_blacklisted_functions() -> dict[str, str]:
)
content: dict[str, list[str]] = req.json()
except Exception as e:
logger.error(f"Error getting blacklisted functions: {e}")
logger.error(f"Error getting blocklisted functions: {e}")
return {}
return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()}

View file

@ -681,8 +681,8 @@ def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test
)
finally:
# Delete the bubble_sort.py file after the test
pathlib.Path(bubble_sort_path).unlink(missing_ok=True)
pathlib.Path(bubble_sort_test_path).unlink(missing_ok=True)
Path(bubble_sort_path).unlink(missing_ok=True)
Path(bubble_sort_test_path).unlink(missing_ok=True)
click.echo(f"{LF}🗑️ Deleted {bubble_sort_path}")
click.echo(f"{LF}🗑️ Deleted {bubble_sort_test_path}")

View file

@ -1,11 +1,12 @@
from __future__ import annotations
import os
import re
import shlex
import unittest
from collections import defaultdict
from multiprocessing import Process, Queue
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Optional
import jedi
from pydantic.dataclasses import dataclass
@ -13,7 +14,9 @@ from pydantic.dataclasses import dataclass
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import module_name_from_file_path
from codeflash.verification.test_results import TestType
from codeflash.verification.verification_utils import TestConfig
if TYPE_CHECKING:
from codeflash.verification.verification_utils import TestConfig
@dataclass(frozen=True)
@ -34,7 +37,7 @@ class CodePosition:
@dataclass(frozen=True)
class FunctionCalledInTest:
test_file: Path
test_class: Optional[str] # This might be unused...
test_class: Optional[str] # This might be unused
test_function: str
test_suite: Optional[str]
test_type: TestType
@ -51,27 +54,25 @@ class TestFunction:
def discover_unit_tests(
cfg: TestConfig,
discover_only_these_tests: Optional[List[str]] = None,
) -> Dict[str, List[FunctionCalledInTest]]:
test_frameworks = {
"pytest": discover_tests_pytest,
"unittest": discover_tests_unittest,
}
discover_tests = test_frameworks.get(cfg.test_framework)
if discover_tests is None:
raise ValueError(f"Unsupported test framework: {cfg.test_framework}")
return discover_tests(cfg, discover_only_these_tests)
discover_only_these_tests: list[str] | None = None,
) -> dict[str, list[FunctionCalledInTest]]:
if cfg.test_framework == "pytest":
return discover_tests_pytest(cfg, discover_only_these_tests)
if cfg.test_framework == "unittest":
return discover_tests_unittest(cfg, discover_only_these_tests)
msg = f"Unsupported test framework: {cfg.test_framework}"
raise ValueError(msg)
def run_pytest_discovery_new_process(queue: Queue, cwd: str, tests_root: str) -> Tuple[int, List]:
def run_pytest_discovery_new_process(queue: Queue, cwd: str, tests_root: str) -> tuple[int, list] | None:
import pytest
os.chdir(cwd)
collected_tests = []
tests = []
tests: list[TestsInFile] = []
class PytestCollectionPlugin:
def pytest_collection_finish(self, session):
def pytest_collection_finish(self, session) -> None:
collected_tests.extend(session.items)
try:
@ -89,8 +90,8 @@ def run_pytest_discovery_new_process(queue: Queue, cwd: str, tests_root: str) ->
def parse_pytest_collection_results(
pytest_tests: str,
) -> List[TestsInFile]:
test_results = []
) -> list[TestsInFile]:
test_results: list[TestsInFile] = []
for test in pytest_tests:
test_class = None
test_file_path = str(test.path)
@ -111,17 +112,13 @@ def parse_pytest_collection_results(
def discover_tests_pytest(
cfg: TestConfig,
discover_only_these_tests: Optional[List[str]] = None,
) -> Dict[str, List[FunctionCalledInTest]]:
discover_only_these_tests: list[str] | None = None,
) -> dict[str, list[FunctionCalledInTest]]:
tests_root = cfg.tests_root
project_root = cfg.project_root_path
pytest_cmd_list = shlex.split(
cfg.pytest_cmd,
posix=os.name != "nt",
) # TODO: Do we need this for test collection?
q = Queue()
p = Process(target=run_pytest_discovery_new_process, args=(q, project_root, tests_root))
q: Queue = Queue()
p: Process = Process(target=run_pytest_discovery_new_process, args=(q, project_root, tests_root))
p.start()
exitcode, tests = q.get()
p.join()
@ -138,14 +135,14 @@ def discover_tests_pytest(
def discover_tests_unittest(
cfg: TestConfig,
discover_only_these_tests: Optional[List[str]] = None,
) -> Dict[str, List[FunctionCalledInTest]]:
tests_root = Path(cfg.tests_root)
loader = unittest.TestLoader()
tests = loader.discover(str(tests_root))
file_to_test_map = defaultdict(list)
discover_only_these_tests: list[str] | None = None,
) -> dict[str, list[FunctionCalledInTest]]:
tests_root: Path = cfg.tests_root
loader: unittest.TestLoader = unittest.TestLoader()
tests: unittest.TestSuite = loader.discover(str(tests_root))
file_to_test_map: defaultdict[str, list[TestsInFile]] = defaultdict(list)
def get_test_details(_test) -> Optional[TestsInFile]:
def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
_test_function, _test_module, _test_suite_name = (
_test._testMethodName,
_test.__class__.__module__,
@ -195,7 +192,7 @@ def discover_tests_unittest(
return process_test_files(file_to_test_map, cfg)
def discover_parameters_unittest(function_name: str):
def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | None]:
function_name = function_name.split("_")
if len(function_name) > 1 and function_name[-1].isdigit():
return True, "_".join(function_name[:-1]), function_name[-1]
@ -204,9 +201,9 @@ def discover_parameters_unittest(function_name: str):
def process_test_files(
file_to_test_map: Dict[str, List[TestsInFile]],
file_to_test_map: dict[str, list[TestsInFile]],
cfg: TestConfig,
) -> Dict[str, List[FunctionCalledInTest]]:
) -> dict[str, list[FunctionCalledInTest]]:
project_root_path = cfg.project_root_path
test_framework = cfg.test_framework
function_to_test_map = defaultdict(list)
@ -224,8 +221,8 @@ def process_test_files(
functions_to_search = [elem.test_function for elem in functions]
for i, function in enumerate(functions_to_search):
if "[" in function:
function_name = re.split(r"\[|\]", function)[0]
parameters = re.split(r"\[|\]", function)[1]
function_name = re.split(r"[\[\]]", function)[0]
parameters = re.split(r"[\[\]]", function)[1]
if name.name == function_name and name.type == "function":
test_functions.add(
TestFunction(name.name, None, parameters, functions[i].test_type),

View file

@ -7,15 +7,13 @@ from _ast import AsyncFunctionDef, ClassDef, FunctionDef
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional
import git
import libcst as cst
from libcst import CSTNode
from libcst.metadata import CodeRange
from pydantic.dataclasses import dataclass
from codeflash.api.cfapi import get_blacklisted_functions
from codeflash.api.cfapi import get_blocklisted_functions
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import (
is_class_defined_in_file,
@ -25,7 +23,12 @@ from codeflash.code_utils.code_utils import (
from codeflash.code_utils.git_utils import get_git_diff
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.verification_utils import TestConfig
if TYPE_CHECKING:
from libcst import CSTNode
from libcst.metadata import CodeRange
from codeflash.verification.verification_utils import TestConfig
@dataclass(frozen=True)
@ -109,7 +112,7 @@ class FunctionParent:
type: str
@dataclass(frozen=True, config=dict(arbitrary_types_allowed=True))
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
class FunctionToOptimize:
"""Represents a function that is a candidate for optimization.
@ -133,9 +136,6 @@ class FunctionToOptimize:
starting_line: Optional[int] = None
ending_line: Optional[int] = None
# # For "BubbleSort.sorter", returns "BubbleSort"
# # For "sorter", returns "sorter"
# # TODO: does not support nested classes or functions
@property
def top_level_parent_name(self) -> str:
return self.function_name if not self.parents else self.parents[0].name
@ -446,7 +446,7 @@ def filter_functions(
module_root: Path,
disable_logs: bool = False,
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
blocklist_funcs = get_blacklisted_functions()
blocklist_funcs = get_blocklisted_functions()
# Remove any function that we don't want to optimize
# Ignore files with submodule path, cache the submodule paths
@ -469,7 +469,8 @@ def filter_functions(
continue
if file_path in ignore_paths or any(
# file_path.startswith(ignore_path + os.sep) for ignore_path in ignore_paths if ignore_path
file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths
file_path.startswith(str(ignore_path) + os.sep)
for ignore_path in ignore_paths
):
ignore_paths_removed_count += 1
continue

View file

@ -763,7 +763,7 @@ class Optimizer:
continue
new_test_path = Path(
f"{os.path.splitext(test_file)[0]}__perfinstrumented{os.path.splitext(test_file)[1]}"
f"{os.path.splitext(test_file)[0]}__perfinstrumented{os.path.splitext(test_file)[1]}",
)
if injected_test is not None:
with new_test_path.open("w", encoding="utf8") as _f: