diff --git a/.idea/codeflash.iml b/.idea/codeflash.iml
index d8fbc86ef..c56433c64 100644
--- a/.idea/codeflash.iml
+++ b/.idea/codeflash.iml
@@ -32,7 +32,7 @@
-
+
diff --git a/.idea/dataSources.xml b/.idea/dataSources.xml
index b7313f620..eedd195a9 100644
--- a/.idea/dataSources.xml
+++ b/.idea/dataSources.xml
@@ -1,7 +1,7 @@
-
+
postgresql
true
org.postgresql.Driver
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
index 0b914ba31..ca86ec240 100644
--- a/.idea/inspectionProfiles/Project_Default.xml
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -19,6 +19,5 @@
-
\ No newline at end of file
diff --git a/.idea/ruff.xml b/.idea/ruff.xml
index 3db20b6ec..f7fb96578 100644
--- a/.idea/ruff.xml
+++ b/.idea/ruff.xml
@@ -2,8 +2,8 @@
-
-
+
+
diff --git a/.idea/runConfigurations/aiservice.xml b/.idea/runConfigurations/aiservice.xml
index a97aaa082..ed2d240ce 100644
--- a/.idea/runConfigurations/aiservice.xml
+++ b/.idea/runConfigurations/aiservice.xml
@@ -9,7 +9,7 @@
-
+
diff --git a/.idea/runConfigurations/bubble_sort_cached_tests.xml b/.idea/runConfigurations/bubble_sort_cached_tests.xml
index 9581f047f..52fb12221 100644
--- a/.idea/runConfigurations/bubble_sort_cached_tests.xml
+++ b/.idea/runConfigurations/bubble_sort_cached_tests.xml
@@ -8,7 +8,7 @@
-
+
@@ -26,7 +26,7 @@
-
+
diff --git a/cli/codeflash/api/cfapi.py b/cli/codeflash/api/cfapi.py
index 2cea6b7ca..29b7c3ba4 100644
--- a/cli/codeflash/api/cfapi.py
+++ b/cli/codeflash/api/cfapi.py
@@ -147,7 +147,7 @@ def is_github_app_installed_on_repo(owner: str, repo: str) -> bool:
return True
-def get_blacklisted_functions() -> dict[str, str]:
+def get_blocklisted_functions() -> dict[str, str]:
pr_number = get_pr_number()
if pr_number is None:
return {}
@@ -166,6 +166,6 @@ def get_blacklisted_functions() -> dict[str, str]:
)
content: dict[str, list[str]] = req.json()
except Exception as e:
- logger.error(f"Error getting blacklisted functions: {e}")
+ logger.error(f"Error getting blocklisted functions: {e}")
return {}
return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()}
diff --git a/cli/codeflash/cli_cmds/cmd_init.py b/cli/codeflash/cli_cmds/cmd_init.py
index 8185c844a..dfc631eea 100644
--- a/cli/codeflash/cli_cmds/cmd_init.py
+++ b/cli/codeflash/cli_cmds/cmd_init.py
@@ -681,8 +681,8 @@ def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test
)
finally:
# Delete the bubble_sort.py file after the test
- pathlib.Path(bubble_sort_path).unlink(missing_ok=True)
- pathlib.Path(bubble_sort_test_path).unlink(missing_ok=True)
+ Path(bubble_sort_path).unlink(missing_ok=True)
+ Path(bubble_sort_test_path).unlink(missing_ok=True)
click.echo(f"{LF}🗑️ Deleted {bubble_sort_path}")
click.echo(f"{LF}🗑️ Deleted {bubble_sort_test_path}")
diff --git a/cli/codeflash/discovery/discover_unit_tests.py b/cli/codeflash/discovery/discover_unit_tests.py
index 19e8cd9b5..b177734a4 100644
--- a/cli/codeflash/discovery/discover_unit_tests.py
+++ b/cli/codeflash/discovery/discover_unit_tests.py
@@ -1,11 +1,12 @@
+from __future__ import annotations
+
import os
import re
-import shlex
import unittest
from collections import defaultdict
from multiprocessing import Process, Queue
from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Optional
import jedi
from pydantic.dataclasses import dataclass
@@ -13,7 +14,9 @@ from pydantic.dataclasses import dataclass
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import module_name_from_file_path
from codeflash.verification.test_results import TestType
-from codeflash.verification.verification_utils import TestConfig
+
+if TYPE_CHECKING:
+ from codeflash.verification.verification_utils import TestConfig
@dataclass(frozen=True)
@@ -34,7 +37,7 @@ class CodePosition:
@dataclass(frozen=True)
class FunctionCalledInTest:
test_file: Path
- test_class: Optional[str] # This might be unused...
+ test_class: Optional[str] # This might be unused…
test_function: str
test_suite: Optional[str]
test_type: TestType
@@ -51,27 +54,25 @@ class TestFunction:
def discover_unit_tests(
cfg: TestConfig,
- discover_only_these_tests: Optional[List[str]] = None,
-) -> Dict[str, List[FunctionCalledInTest]]:
- test_frameworks = {
- "pytest": discover_tests_pytest,
- "unittest": discover_tests_unittest,
- }
- discover_tests = test_frameworks.get(cfg.test_framework)
- if discover_tests is None:
- raise ValueError(f"Unsupported test framework: {cfg.test_framework}")
- return discover_tests(cfg, discover_only_these_tests)
+ discover_only_these_tests: list[str] | None = None,
+) -> dict[str, list[FunctionCalledInTest]]:
+ if cfg.test_framework == "pytest":
+ return discover_tests_pytest(cfg, discover_only_these_tests)
+ if cfg.test_framework == "unittest":
+ return discover_tests_unittest(cfg, discover_only_these_tests)
+ msg = f"Unsupported test framework: {cfg.test_framework}"
+ raise ValueError(msg)
-def run_pytest_discovery_new_process(queue: Queue, cwd: str, tests_root: str) -> Tuple[int, List]:
+def run_pytest_discovery_new_process(queue: Queue, cwd: str, tests_root: str) -> tuple[int, list] | None:
import pytest
os.chdir(cwd)
collected_tests = []
- tests = []
+ tests: list[TestsInFile] = []
class PytestCollectionPlugin:
- def pytest_collection_finish(self, session):
+ def pytest_collection_finish(self, session) -> None:
collected_tests.extend(session.items)
try:
@@ -89,8 +90,8 @@ def run_pytest_discovery_new_process(queue: Queue, cwd: str, tests_root: str) ->
def parse_pytest_collection_results(
pytest_tests: str,
-) -> List[TestsInFile]:
- test_results = []
+) -> list[TestsInFile]:
+ test_results: list[TestsInFile] = []
for test in pytest_tests:
test_class = None
test_file_path = str(test.path)
@@ -111,17 +112,13 @@ def parse_pytest_collection_results(
def discover_tests_pytest(
cfg: TestConfig,
- discover_only_these_tests: Optional[List[str]] = None,
-) -> Dict[str, List[FunctionCalledInTest]]:
+ discover_only_these_tests: list[str] | None = None,
+) -> dict[str, list[FunctionCalledInTest]]:
tests_root = cfg.tests_root
project_root = cfg.project_root_path
- pytest_cmd_list = shlex.split(
- cfg.pytest_cmd,
- posix=os.name != "nt",
- ) # TODO: Do we need this for test collection?
- q = Queue()
- p = Process(target=run_pytest_discovery_new_process, args=(q, project_root, tests_root))
+ q: Queue = Queue()
+ p: Process = Process(target=run_pytest_discovery_new_process, args=(q, project_root, tests_root))
p.start()
exitcode, tests = q.get()
p.join()
@@ -138,14 +135,14 @@ def discover_tests_pytest(
def discover_tests_unittest(
cfg: TestConfig,
- discover_only_these_tests: Optional[List[str]] = None,
-) -> Dict[str, List[FunctionCalledInTest]]:
- tests_root = Path(cfg.tests_root)
- loader = unittest.TestLoader()
- tests = loader.discover(str(tests_root))
- file_to_test_map = defaultdict(list)
+ discover_only_these_tests: list[str] | None = None,
+) -> dict[str, list[FunctionCalledInTest]]:
+ tests_root: Path = cfg.tests_root
+ loader: unittest.TestLoader = unittest.TestLoader()
+ tests: unittest.TestSuite = loader.discover(str(tests_root))
+ file_to_test_map: defaultdict[str, list[TestsInFile]] = defaultdict(list)
- def get_test_details(_test) -> Optional[TestsInFile]:
+ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
_test_function, _test_module, _test_suite_name = (
_test._testMethodName,
_test.__class__.__module__,
@@ -195,7 +192,7 @@ def discover_tests_unittest(
return process_test_files(file_to_test_map, cfg)
-def discover_parameters_unittest(function_name: str):
+def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | None]:
function_name = function_name.split("_")
if len(function_name) > 1 and function_name[-1].isdigit():
return True, "_".join(function_name[:-1]), function_name[-1]
@@ -204,9 +201,9 @@ def discover_parameters_unittest(function_name: str):
def process_test_files(
- file_to_test_map: Dict[str, List[TestsInFile]],
+ file_to_test_map: dict[str, list[TestsInFile]],
cfg: TestConfig,
-) -> Dict[str, List[FunctionCalledInTest]]:
+) -> dict[str, list[FunctionCalledInTest]]:
project_root_path = cfg.project_root_path
test_framework = cfg.test_framework
function_to_test_map = defaultdict(list)
@@ -224,8 +221,8 @@ def process_test_files(
functions_to_search = [elem.test_function for elem in functions]
for i, function in enumerate(functions_to_search):
if "[" in function:
- function_name = re.split(r"\[|\]", function)[0]
- parameters = re.split(r"\[|\]", function)[1]
+ function_name = re.split(r"[\[\]]", function)[0]
+ parameters = re.split(r"[\[\]]", function)[1]
if name.name == function_name and name.type == "function":
test_functions.add(
TestFunction(name.name, None, parameters, functions[i].test_type),
diff --git a/cli/codeflash/discovery/functions_to_optimize.py b/cli/codeflash/discovery/functions_to_optimize.py
index fc4097b4b..20a195bd0 100644
--- a/cli/codeflash/discovery/functions_to_optimize.py
+++ b/cli/codeflash/discovery/functions_to_optimize.py
@@ -7,15 +7,13 @@ from _ast import AsyncFunctionDef, ClassDef, FunctionDef
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
-from typing import Optional, Union
+from typing import TYPE_CHECKING, Optional
import git
import libcst as cst
-from libcst import CSTNode
-from libcst.metadata import CodeRange
from pydantic.dataclasses import dataclass
-from codeflash.api.cfapi import get_blacklisted_functions
+from codeflash.api.cfapi import get_blocklisted_functions
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import (
is_class_defined_in_file,
@@ -25,7 +23,12 @@ from codeflash.code_utils.code_utils import (
from codeflash.code_utils.git_utils import get_git_diff
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.telemetry.posthog_cf import ph
-from codeflash.verification.verification_utils import TestConfig
+
+if TYPE_CHECKING:
+ from libcst import CSTNode
+ from libcst.metadata import CodeRange
+
+ from codeflash.verification.verification_utils import TestConfig
@dataclass(frozen=True)
@@ -109,7 +112,7 @@ class FunctionParent:
type: str
-@dataclass(frozen=True, config=dict(arbitrary_types_allowed=True))
+@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
class FunctionToOptimize:
"""Represents a function that is a candidate for optimization.
@@ -133,9 +136,6 @@ class FunctionToOptimize:
starting_line: Optional[int] = None
ending_line: Optional[int] = None
- # # For "BubbleSort.sorter", returns "BubbleSort"
- # # For "sorter", returns "sorter"
- # # TODO: does not support nested classes or functions
@property
def top_level_parent_name(self) -> str:
return self.function_name if not self.parents else self.parents[0].name
@@ -446,7 +446,7 @@ def filter_functions(
module_root: Path,
disable_logs: bool = False,
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
- blocklist_funcs = get_blacklisted_functions()
+ blocklist_funcs = get_blocklisted_functions()
# Remove any function that we don't want to optimize
# Ignore files with submodule path, cache the submodule paths
@@ -469,7 +469,8 @@ def filter_functions(
continue
if file_path in ignore_paths or any(
# file_path.startswith(ignore_path + os.sep) for ignore_path in ignore_paths if ignore_path
- file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths
+ file_path.startswith(str(ignore_path) + os.sep)
+ for ignore_path in ignore_paths
):
ignore_paths_removed_count += 1
continue
diff --git a/cli/codeflash/optimization/optimizer.py b/cli/codeflash/optimization/optimizer.py
index 2dac5affd..5b7084680 100644
--- a/cli/codeflash/optimization/optimizer.py
+++ b/cli/codeflash/optimization/optimizer.py
@@ -763,7 +763,7 @@ class Optimizer:
continue
new_test_path = Path(
- f"{os.path.splitext(test_file)[0]}__perfinstrumented{os.path.splitext(test_file)[1]}"
+ f"{os.path.splitext(test_file)[0]}__perfinstrumented{os.path.splitext(test_file)[1]}",
)
if injected_test is not None:
with new_test_path.open("w", encoding="utf8") as _f: