mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
assignments with if/else blocks are not modified
This commit is contained in:
parent
a8cf3ee881
commit
44d9229eea
2 changed files with 114 additions and 2 deletions
|
|
@ -29,6 +29,7 @@ class GlobalAssignmentCollector(cst.CSTVisitor):
|
|||
self.assignment_order: List[str] = []
|
||||
# Track scope depth to identify global assignments
|
||||
self.scope_depth = 0
|
||||
self.if_else_depth = 0
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
|
||||
self.scope_depth += 1
|
||||
|
|
@ -44,9 +45,20 @@ class GlobalAssignmentCollector(cst.CSTVisitor):
|
|||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
|
||||
self.scope_depth -= 1
|
||||
|
||||
def visit_If(self, node: cst.If) -> Optional[bool]:
|
||||
self.if_else_depth += 1
|
||||
return True
|
||||
|
||||
def leave_If(self, original_node: cst.If) -> None:
|
||||
self.if_else_depth -= 1
|
||||
|
||||
def visit_Else(self, node: cst.Else) -> Optional[bool]:
|
||||
# Else blocks are already counted as part of the if statement
|
||||
return True
|
||||
|
||||
def visit_Assign(self, node: cst.Assign) -> Optional[bool]:
|
||||
# Only process global assignments (not inside functions, classes, etc.)
|
||||
if self.scope_depth == 0: # We're at module level
|
||||
if self.scope_depth == 0 and self.if_else_depth == 0: # We're at module level
|
||||
for target in node.targets:
|
||||
if isinstance(target.target, cst.Name):
|
||||
name = target.target.value
|
||||
|
|
@ -65,6 +77,7 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
|
|||
self.new_assignment_order = new_assignment_order
|
||||
self.processed_assignments: Set[str] = set()
|
||||
self.scope_depth = 0
|
||||
self.if_else_depth = 0
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
self.scope_depth += 1
|
||||
|
|
@ -80,8 +93,19 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
|
|||
self.scope_depth -= 1
|
||||
return updated_node
|
||||
|
||||
def visit_If(self, node: cst.If) -> None:
|
||||
self.if_else_depth += 1
|
||||
|
||||
def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
|
||||
self.if_else_depth -= 1
|
||||
return updated_node
|
||||
|
||||
def visit_Else(self, node: cst.Else) -> None:
|
||||
# Else blocks are already counted as part of the if statement
|
||||
pass
|
||||
|
||||
def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.CSTNode:
|
||||
if self.scope_depth > 0:
|
||||
if self.scope_depth > 0 or self.if_else_depth > 0:
|
||||
return updated_node
|
||||
|
||||
# Check if this is a global assignment we need to replace
|
||||
|
|
|
|||
|
|
@ -2017,6 +2017,94 @@ class NewClass:
|
|||
return "I am still old"
|
||||
def new_function2(value):
|
||||
return cst.ensure_type(value, str)
|
||||
"""
|
||||
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
|
||||
code_path.write_text(original_code, encoding="utf-8")
|
||||
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
|
||||
project_root_path = (Path(__file__).parent / "..").resolve()
|
||||
func = FunctionToOptimize(function_name="some_fn", parents=[], file_path=code_path)
|
||||
test_config = TestConfig(
|
||||
tests_root=tests_root,
|
||||
tests_project_rootdir=project_root_path,
|
||||
project_root_path=project_root_path,
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
for helper_function_path in helper_function_paths:
|
||||
with helper_function_path.open(encoding="utf8") as f:
|
||||
helper_code = f.read()
|
||||
original_helper_code[helper_function_path] = helper_code
|
||||
func_optimizer.args = Args()
|
||||
func_optimizer.replace_function_and_helpers_with_optimized_code(
|
||||
code_context=code_context, optimized_code=optimized_code
|
||||
)
|
||||
new_code = code_path.read_text(encoding="utf-8")
|
||||
code_path.unlink(missing_ok=True)
|
||||
assert new_code.rstrip() == expected_code.rstrip()
|
||||
|
||||
original_code = """if 2<3:
|
||||
a=4
|
||||
else:
|
||||
a=5
|
||||
print("Hello world")
|
||||
def some_fn():
|
||||
print("did noting")
|
||||
class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def __call__(self, value):
|
||||
return "I am still old"
|
||||
def new_function2(value):
|
||||
return cst.ensure_type(value, str)
|
||||
"""
|
||||
optimized_code = """import numpy as np
|
||||
if 1<2:
|
||||
a=2
|
||||
else:
|
||||
a=3
|
||||
a = 6
|
||||
def some_fn():
|
||||
a=np.zeros(10)
|
||||
print("did something")
|
||||
class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def __call__(self, value):
|
||||
return "I am still old"
|
||||
def new_function2(value):
|
||||
return cst.ensure_type(value, str)
|
||||
print("Hello world")
|
||||
"""
|
||||
expected_code = """import numpy as np
|
||||
print("Hello world")
|
||||
|
||||
if 2<3:
|
||||
a=4
|
||||
else:
|
||||
a=5
|
||||
print("Hello world")
|
||||
def some_fn():
|
||||
a=np.zeros(10)
|
||||
print("did something")
|
||||
class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def __call__(self, value):
|
||||
return "I am still old"
|
||||
def new_function2(value):
|
||||
return cst.ensure_type(value, str)
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def __call__(self, value):
|
||||
return "I am still old"
|
||||
def new_function2(value):
|
||||
return cst.ensure_type(value, str)
|
||||
|
||||
a = 6
|
||||
"""
|
||||
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
|
||||
code_path.write_text(original_code, encoding="utf-8")
|
||||
|
|
|
|||
Loading…
Reference in a new issue