From 28596b7b55aaa5ac257408337c98cb84750feff5 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Wed, 30 Apr 2025 19:46:29 -0700 Subject: [PATCH] refactoring to run only for code replacement before PR/behavior instead of code context extraction --- codeflash/code_utils/code_extractor.py | 21 ++++---- codeflash/code_utils/code_replacer.py | 5 +- tests/test_code_replacement.py | 72 ++++++++++++++++++-------- 3 files changed, 64 insertions(+), 34 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 56dcf83fa..96b6dd845 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -235,15 +235,7 @@ def delete___future___aliased_imports(module_code: str) -> str: return cst.parse_module(module_code).visit(FutureAliasedImportTransformer()).code -def add_needed_imports_from_module( - src_module_code: str, - dst_module_code: str, - src_path: Path, - dst_path: Path, - project_root: Path, - helper_functions: list[FunctionSource] | None = None, - helper_functions_fqn: set[str] | None = None, -) -> str: +def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: non_assignment_global_statements = extract_global_statements(src_module_code) # Find the last import line in target @@ -272,7 +264,18 @@ def add_needed_imports_from_module( transformed_module = original_module.visit(transformer) dst_module_code = transformed_module.code + return dst_module_code + +def add_needed_imports_from_module( + src_module_code: str, + dst_module_code: str, + src_path: Path, + dst_path: Path, + project_root: Path, + helper_functions: list[FunctionSource] | None = None, + helper_functions_fqn: set[str] | None = None, +) -> str: """Add all needed and used source module code imports to the destination module code, and return it.""" src_module_code = delete___future___aliased_imports(src_module_code) if not helper_functions_fqn: diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index ad37bfbd2..ccb935f42 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional, TypeVar import libcst as cst from codeflash.cli_cmds.console import logger -from codeflash.code_utils.code_extractor import add_needed_imports_from_module +from codeflash.code_utils.code_extractor import add_needed_imports_from_module, add_global_assignments from codeflash.models.models import FunctionParent if TYPE_CHECKING: @@ -220,7 +220,8 @@ def replace_function_definitions_in_module( ) if is_zero_diff(source_code, new_code): return False - module_abspath.write_text(new_code, encoding="utf8") + code_with_global_assignments = add_global_assignments(optimized_code, new_code) + module_abspath.write_text(code_with_global_assignments, encoding="utf8") return True diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 5e166cda5..aa7132ff0 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -804,7 +804,8 @@ class MainClass: self.name = name def main_method(self): - return HelperClass(self.name).helper_method()""" + return HelperClass(self.name).helper_method() +""" file_path = Path(__file__).resolve() func_top_optimize = FunctionToOptimize( function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")] @@ -1665,6 +1666,8 @@ print("Hello world") def test_global_reassignment() -> None: original_code = """a=1 print("Hello world") +def some_fn(): + print("did noting") class NewClass: def __init__(self, name): self.name = name @@ -1672,44 +1675,67 @@ class NewClass: return "I am still old" def new_function2(value): return cst.ensure_type(value, str) -""" - - optim_code = """import numpy as np + """ + optimized_code = """import numpy as np +def some_fn(): + a=np.zeros(10) + print("did something") class NewClass: def __init__(self, name): self.name = name def __call__(self, value): - w = np.array([1,2,3]) - return "I am new" + return "I am still old" def new_function2(value): return cst.ensure_type(value, str) a=2 print("Hello world") -""" - - modified_code = """import numpy as np - + """ + expected_code = """import numpy as np print("Hello world") + a=2 print("Hello world") +def some_fn(): + a=np.zeros(10) + print("did something") class NewClass: def __init__(self, name): self.name = name def __call__(self, value): - w = np.array([1,2,3]) - return "I am new" + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" def new_function2(value): return cst.ensure_type(value, str) """ - - function_names: list[str] = ["NewClass.__init__", "NewClass.__call__", "NewClass.new_function2"] - preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code) - new_code: str = replace_functions_and_add_imports( - source_code=original_code, - function_names=function_names, - optimized_code=optim_code, - module_abspath=Path(__file__).resolve(), - preexisting_objects=preexisting_objects, - project_root_path=Path(__file__).resolve().parent.resolve(), + code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve() + code_path.write_text(original_code, encoding="utf-8") + tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/") + project_root_path = (Path(__file__).parent / "..").resolve() + func = FunctionToOptimize(function_name="some_fn", parents=[], file_path=code_path) + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=project_root_path, + project_root_path=project_root_path, + test_framework="pytest", + pytest_cmd="pytest", ) - assert new_code == modified_code \ No newline at end of file + func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config) + code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + func_optimizer.args = Args() + func_optimizer.replace_function_and_helpers_with_optimized_code( + code_context=code_context, optimized_code=optimized_code + ) + new_code = code_path.read_text(encoding="utf-8") + code_path.unlink(missing_ok=True) + assert new_code.rstrip() == expected_code.rstrip()