refactoring to run only for code replacement before PR/behavior instead of code context extraction

This commit is contained in:
aseembits93 2025-04-30 19:46:29 -07:00
parent 3da7b739de
commit 28596b7b55
3 changed files with 64 additions and 34 deletions

View file

@ -235,15 +235,7 @@ def delete___future___aliased_imports(module_code: str) -> str:
return cst.parse_module(module_code).visit(FutureAliasedImportTransformer()).code
def add_needed_imports_from_module(
src_module_code: str,
dst_module_code: str,
src_path: Path,
dst_path: Path,
project_root: Path,
helper_functions: list[FunctionSource] | None = None,
helper_functions_fqn: set[str] | None = None,
) -> str:
def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
non_assignment_global_statements = extract_global_statements(src_module_code)
# Find the last import line in target
@ -272,7 +264,18 @@ def add_needed_imports_from_module(
transformed_module = original_module.visit(transformer)
dst_module_code = transformed_module.code
return dst_module_code
def add_needed_imports_from_module(
src_module_code: str,
dst_module_code: str,
src_path: Path,
dst_path: Path,
project_root: Path,
helper_functions: list[FunctionSource] | None = None,
helper_functions_fqn: set[str] | None = None,
) -> str:
"""Add all needed and used source module code imports to the destination module code, and return it."""
src_module_code = delete___future___aliased_imports(src_module_code)
if not helper_functions_fqn:

View file

@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional, TypeVar
import libcst as cst
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, add_global_assignments
from codeflash.models.models import FunctionParent
if TYPE_CHECKING:
@ -220,7 +220,8 @@ def replace_function_definitions_in_module(
)
if is_zero_diff(source_code, new_code):
return False
module_abspath.write_text(new_code, encoding="utf8")
code_with_global_assignments = add_global_assignments(optimized_code, new_code)
module_abspath.write_text(code_with_global_assignments, encoding="utf8")
return True

View file

@ -804,7 +804,8 @@ class MainClass:
self.name = name
def main_method(self):
return HelperClass(self.name).helper_method()"""
return HelperClass(self.name).helper_method()
"""
file_path = Path(__file__).resolve()
func_top_optimize = FunctionToOptimize(
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
@ -1665,6 +1666,8 @@ print("Hello world")
def test_global_reassignment() -> None:
original_code = """a=1
print("Hello world")
def some_fn():
print("did noting")
class NewClass:
def __init__(self, name):
self.name = name
@ -1672,44 +1675,67 @@ class NewClass:
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
"""
optim_code = """import numpy as np
"""
optimized_code = """import numpy as np
def some_fn():
a=np.zeros(10)
print("did something")
class NewClass:
def __init__(self, name):
self.name = name
def __call__(self, value):
w = np.array([1,2,3])
return "I am new"
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
a=2
print("Hello world")
"""
modified_code = """import numpy as np
"""
expected_code = """import numpy as np
print("Hello world")
a=2
print("Hello world")
def some_fn():
a=np.zeros(10)
print("did something")
class NewClass:
def __init__(self, name):
self.name = name
def __call__(self, value):
w = np.array([1,2,3])
return "I am new"
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
def __init__(self, name):
self.name = name
def __call__(self, value):
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
"""
function_names: list[str] = ["NewClass.__init__", "NewClass.__call__", "NewClass.new_function2"]
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
optimized_code=optim_code,
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
project_root_path=Path(__file__).resolve().parent.resolve(),
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
code_path.write_text(original_code, encoding="utf-8")
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
project_root_path = (Path(__file__).parent / "..").resolve()
func = FunctionToOptimize(function_name="some_fn", parents=[], file_path=code_path)
test_config = TestConfig(
tests_root=tests_root,
tests_project_rootdir=project_root_path,
project_root_path=project_root_path,
test_framework="pytest",
pytest_cmd="pytest",
)
assert new_code == modified_code
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
for helper_function_path in helper_function_paths:
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
assert new_code.rstrip() == expected_code.rstrip()