diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 23539050e..6949bdbe1 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -9,6 +9,8 @@ from libcst.codemod import CodemodContext from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor from libcst.helpers import calculate_module_and_package +from codeflash.discovery.functions_to_optimize import FunctionParent + if TYPE_CHECKING: from libcst.helpers import ModuleNameAndPackage @@ -237,3 +239,22 @@ def extract_code( ) return None, set() return edited_code, contextual_dunder_methods + + +def find_preexisting_objects(source_code: str): + """Find all preexisting functions, classes or class methods in the source code""" + preexisting_objects: list[tuple[str, list[FunctionParent]]] = [] + try: + module_node: ast.Module = ast.parse(source_code) + except SyntaxError: + logging.exception("find_preexisting_objects - Syntax error while parsing code") + return preexisting_objects + for node in module_node.body: + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + preexisting_objects.append((node.name, [])) + elif isinstance(node, ast.ClassDef): + preexisting_objects.append((node.name, [])) + for cnode in node.body: + if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)): + preexisting_objects.append((cnode.name, [FunctionParent(node.name, "ClassDef")])) + return preexisting_objects diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index d74e64b5d..ab1a434f0 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -20,7 +20,11 @@ from codeflash.api.aiservice import ( OptimizedCandidate, ) from codeflash.code_utils import env_utils -from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code +from codeflash.code_utils.code_extractor import ( + add_needed_imports_from_module, + extract_code, + find_preexisting_objects, +) from codeflash.code_utils.code_replacer import replace_function_definitions_in_module from codeflash.code_utils.code_utils import ( get_run_tmp_file, @@ -606,11 +610,6 @@ class Optimizer: ) if code_to_optimize is None: return Failure("Could not find function to optimize.") - preexisting_objects: list[tuple[str, list[FunctionParent]]] = [ - (name, [FunctionParent(name=class_name, type="ClassDef")]) - for class_name, name in contextual_dunder_methods - ] - preexisting_objects.append((function_to_optimize.function_name, function_to_optimize.parents)) ( helper_code, helper_functions, @@ -653,14 +652,7 @@ class Optimizer: project_root, helper_functions, ) - preexisting_objects.extend( - [ - (qualified_name_list[-1], ([FunctionParent(name=qualified_name_list[-2], type="ClassDef")])) - if len(qualified_name_list := fn.qualified_name.split(".")) > 1 - else (qualified_name_list[-1], []) - for fn in helper_functions - ], - ) + preexisting_objects = find_preexisting_objects(code_to_optimize_with_helpers) contextual_dunder_methods.update(helper_dunder_methods) return Success( CodeOptimizationContext( diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index f69fdb54b..f0b3738a5 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -6,7 +6,6 @@ from argparse import Namespace from collections import defaultdict from pathlib import Path -import pytest from codeflash.code_utils.code_replacer import replace_functions_and_add_imports, replace_functions_in_file from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize from codeflash.optimization.optimizer import Optimizer @@ -946,7 +945,6 @@ def test_test_libcst_code_replacement13() -> None: assert new_code == original_code -@pytest.mark.skip() def test_different_class_code_replacement(): original_code = """from __future__ import annotations import sys @@ -1118,6 +1116,8 @@ class TestResults(BaseModel): ("__getitem__", [FunctionParent(name="TestResults", type="ClassDef")]), ("get_test_pass_fail_report_by_type", [FunctionParent(name="TestResults", type="ClassDef")]), ("TestType", []), + ("TestResults", []), + ("to_name", [FunctionParent(name="TestType", type="ClassDef")]), ] contextual_functions = { @@ -1152,6 +1152,7 @@ class TestResults(BaseModel): contextual_functions=contextual_functions, project_root_path=str(Path(__file__).parent.resolve()), ) + helper_functions_by_module_abspath = defaultdict(set) for helper_function in helper_functions: if helper_function.jedi_definition.type != "class": @@ -1173,7 +1174,83 @@ class TestResults(BaseModel): project_root_path=str(Path(__file__).parent.resolve()), ) - print("hi") + assert ( + new_code + == """from __future__ import annotations +import sys +from codeflash.verification.comparator import comparator +from enum import Enum +from pydantic import BaseModel +from typing import Iterator + +class TestType(Enum): + EXISTING_UNIT_TEST = 1 + INSPIRED_REGRESSION = 2 + GENERATED_REGRESSION = 3 + REPLAY_TEST = 4 + + def to_name(self) -> str: + names = { + TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests", + TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests", + TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests", + TestType.REPLAY_TEST: "⏪ Replay Tests", + } + return names[self] + +class TestResults(BaseModel): + def __iter__(self) -> Iterator[FunctionTestInvocation]: + return iter(self.test_results) + def __len__(self) -> int: + return len(self.test_results) + def __getitem__(self, index: int) -> FunctionTestInvocation: + return self.test_results[index] + def __setitem__(self, index: int, value: FunctionTestInvocation) -> None: + self.test_results[index] = value + def __delitem__(self, index: int) -> None: + del self.test_results[index] + def __contains__(self, value: FunctionTestInvocation) -> bool: + return value in self.test_results + def __bool__(self) -> bool: + return bool(self.test_results) + def __eq__(self, other: object) -> bool: + # Unordered comparison + if type(self) != type(other): + return False + if len(self) != len(other): + return False + original_recursion_limit = sys.getrecursionlimit() + for test_result in self: + other_test_result = other.get_by_id(test_result.id) + if other_test_result is None: + return False + + if original_recursion_limit < 5000: + sys.setrecursionlimit(5000) + if ( + test_result.file_name != other_test_result.file_name + or test_result.did_pass != other_test_result.did_pass + or test_result.runtime != other_test_result.runtime + or test_result.test_framework != other_test_result.test_framework + or test_result.test_type != other_test_result.test_type + or not comparator( + test_result.return_value, + other_test_result.return_value, + ) + ): + sys.setrecursionlimit(original_recursion_limit) + return False + sys.setrecursionlimit(original_recursion_limit) + return True + + def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]: + report = {test_type: {"passed": 0, "failed": 0} for test_type in TestType} + for test_result in self.test_results: + if test_result.test_type != TestType.EXISTING_UNIT_TEST or test_result.id.function_getting_tested: + key = "passed" if test_result.did_pass else "failed" + report[test_result.test_type][key] += 1 + return report""" + ) def test_code_replacement_type_annotation():