mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
tests work now
This commit is contained in:
parent
18c989537e
commit
abbaec3b5e
2 changed files with 143 additions and 37 deletions
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Set
|
||||
|
||||
import libcst as cst
|
||||
import libcst.matchers as m
|
||||
|
|
@ -20,20 +20,106 @@ if TYPE_CHECKING:
|
|||
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
class ImportCollector(cst.CSTVisitor):
|
||||
"""Visitor that collects all import statements in a module."""
|
||||
class GlobalAssignmentCollector(cst.CSTVisitor):
|
||||
"""Collects all global assignment statements."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.imports = []
|
||||
self.assignments: Dict[str, cst.Assign] = {}
|
||||
self.assignment_order: List[str] = []
|
||||
# Track scope depth to identify global assignments
|
||||
self.scope_depth = 0
|
||||
|
||||
def visit_Import(self, node: cst.Import) -> None:
|
||||
self.imports.append(node)
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
|
||||
self.scope_depth += 1
|
||||
return True
|
||||
|
||||
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
|
||||
self.imports.append(node)
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
|
||||
self.scope_depth -= 1
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
|
||||
self.scope_depth += 1
|
||||
return True
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
|
||||
self.scope_depth -= 1
|
||||
|
||||
def visit_Assign(self, node: cst.Assign) -> Optional[bool]:
|
||||
# Only process global assignments (not inside functions, classes, etc.)
|
||||
if self.scope_depth == 0: # We're at module level
|
||||
for target in node.targets:
|
||||
if isinstance(target.target, cst.Name):
|
||||
name = target.target.value
|
||||
self.assignments[name] = node
|
||||
if name not in self.assignment_order:
|
||||
self.assignment_order.append(name)
|
||||
return True
|
||||
|
||||
|
||||
class GlobalAssignmentTransformer(cst.CSTTransformer):
|
||||
"""Transforms global assignments in the original file with those from the new file."""
|
||||
|
||||
def __init__(self, new_assignments: Dict[str, cst.Assign], new_assignment_order: List[str]):
|
||||
super().__init__()
|
||||
self.new_assignments = new_assignments
|
||||
self.new_assignment_order = new_assignment_order
|
||||
self.processed_assignments: Set[str] = set()
|
||||
self.scope_depth = 0
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
self.scope_depth += 1
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
|
||||
self.scope_depth -= 1
|
||||
return updated_node
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
self.scope_depth += 1
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
|
||||
self.scope_depth -= 1
|
||||
return updated_node
|
||||
|
||||
def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.CSTNode:
|
||||
if self.scope_depth > 0:
|
||||
return updated_node
|
||||
|
||||
# Check if this is a global assignment we need to replace
|
||||
for target in original_node.targets:
|
||||
if isinstance(target.target, cst.Name):
|
||||
name = target.target.value
|
||||
if name in self.new_assignments:
|
||||
self.processed_assignments.add(name)
|
||||
return self.new_assignments[name]
|
||||
|
||||
return updated_node
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
# Add any new assignments that weren't in the original file
|
||||
new_statements = list(updated_node.body)
|
||||
|
||||
# Find assignments to append
|
||||
assignments_to_append = []
|
||||
for name in self.new_assignment_order:
|
||||
if name not in self.processed_assignments and name in self.new_assignments:
|
||||
assignments_to_append.append(self.new_assignments[name])
|
||||
|
||||
if assignments_to_append:
|
||||
# Add a blank line before appending new assignments if needed
|
||||
if new_statements and not isinstance(new_statements[-1], cst.EmptyLine):
|
||||
new_statements.append(cst.SimpleStatementLine([cst.Pass()], leading_lines=[cst.EmptyLine()]))
|
||||
new_statements.pop() # Remove the Pass statement but keep the empty line
|
||||
|
||||
# Add the new assignments
|
||||
for assignment in assignments_to_append:
|
||||
new_statements.append(
|
||||
cst.SimpleStatementLine(
|
||||
[assignment],
|
||||
leading_lines=[cst.EmptyLine()]
|
||||
)
|
||||
)
|
||||
|
||||
return updated_node.with_changes(body=new_statements)
|
||||
|
||||
class GlobalStatementCollector(cst.CSTVisitor):
|
||||
"""Visitor that collects all global statements (excluding imports and functions/classes)."""
|
||||
|
|
@ -63,7 +149,7 @@ class GlobalStatementCollector(cst.CSTVisitor):
|
|||
if not self.in_function_or_class:
|
||||
for statement in node.body:
|
||||
# Skip imports
|
||||
if not isinstance(statement, (cst.Import, cst.ImportFrom)):
|
||||
if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign)):
|
||||
self.global_statements.append(node)
|
||||
break
|
||||
|
||||
|
|
@ -130,28 +216,6 @@ def find_last_import_line(target_code: str) -> int:
|
|||
module.visit(finder)
|
||||
return finder.last_import_line
|
||||
|
||||
|
||||
def merge_globals(source_code: str, target_code: str) -> str:
|
||||
"""Merge global statements from source into target just after imports."""
|
||||
# Extract global statements from source
|
||||
global_statements = extract_global_statements(source_code)
|
||||
|
||||
# Find the last import line in target
|
||||
last_import_line = find_last_import_line(target_code)
|
||||
|
||||
# Parse the target code
|
||||
target_module = cst.parse_module(target_code)
|
||||
|
||||
# Create transformer to insert global statements
|
||||
transformer = ImportInserter(global_statements, last_import_line)
|
||||
|
||||
# Apply transformation
|
||||
modified_module = target_module.visit(transformer)
|
||||
|
||||
# Return the modified code
|
||||
return modified_module.code
|
||||
|
||||
|
||||
class FutureAliasedImportTransformer(cst.CSTTransformer):
|
||||
def leave_ImportFrom(
|
||||
self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
|
||||
|
|
@ -180,7 +244,7 @@ def add_needed_imports_from_module(
|
|||
helper_functions: list[FunctionSource] | None = None,
|
||||
helper_functions_fqn: set[str] | None = None,
|
||||
) -> str:
|
||||
global_statements = extract_global_statements(src_module_code)
|
||||
non_assignment_global_statements = extract_global_statements(src_module_code)
|
||||
|
||||
# Find the last import line in target
|
||||
last_import_line = find_last_import_line(dst_module_code)
|
||||
|
|
@ -188,12 +252,27 @@ def add_needed_imports_from_module(
|
|||
# Parse the target code
|
||||
target_module = cst.parse_module(dst_module_code)
|
||||
|
||||
# Create transformer to insert global statements
|
||||
transformer = ImportInserter(global_statements, last_import_line)
|
||||
# Create transformer to insert non_assignment_global_statements
|
||||
transformer = ImportInserter(non_assignment_global_statements, last_import_line)
|
||||
#
|
||||
# # Apply transformation
|
||||
modified_module = target_module.visit(transformer)
|
||||
dst_module_code = modified_module.code
|
||||
|
||||
# Parse the code
|
||||
original_module = cst.parse_module(dst_module_code)
|
||||
new_module = cst.parse_module(src_module_code)
|
||||
|
||||
# Collect assignments from the new file
|
||||
new_collector = GlobalAssignmentCollector()
|
||||
new_module.visit(new_collector)
|
||||
|
||||
# Transform the original file
|
||||
transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
|
||||
transformed_module = original_module.visit(transformer)
|
||||
|
||||
dst_module_code = transformed_module.code
|
||||
|
||||
"""Add all needed and used source module code imports to the destination module code, and return it."""
|
||||
src_module_code = delete___future___aliased_imports(src_module_code)
|
||||
if not helper_functions_fqn:
|
||||
|
|
|
|||
|
|
@ -789,7 +789,22 @@ class MainClass:
|
|||
|
||||
|
||||
def test_code_replacement10() -> None:
|
||||
get_code_output = 'from __future__ import annotations\nimport os\n\nos.environ["CODEFLASH_API_KEY"] = "cf-test-key"\nclass HelperClass:\n def __init__(self, name):\n self.name = name\n\n def helper_method(self):\n return self.name\n\n\nclass MainClass:\n def __init__(self, name):\n self.name = name\n\n def main_method(self):\n return HelperClass(self.name).helper_method()\n'
|
||||
get_code_output = """from __future__ import annotations
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def helper_method(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class MainClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def main_method(self):
|
||||
return HelperClass(self.name).helper_method()"""
|
||||
file_path = Path(__file__).resolve()
|
||||
func_top_optimize = FunctionToOptimize(
|
||||
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
|
||||
|
|
@ -1636,6 +1651,17 @@ print("Hello world")
|
|||
print("Hello world")
|
||||
"""
|
||||
|
||||
modified_code = """print("Hello world")
|
||||
class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def __call__(self, value):
|
||||
return "I am still old"
|
||||
def new_function2(value):
|
||||
return cst.ensure_type(value, str)
|
||||
|
||||
print("Hello world")
|
||||
"""
|
||||
function_names: list[str] = ["NewClass.__init__", "NewClass.__call__", "NewClass.new_function2"]
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
|
|
@ -1646,7 +1672,7 @@ print("Hello world")
|
|||
preexisting_objects=preexisting_objects,
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == original_code
|
||||
assert new_code == modified_code
|
||||
|
||||
def test_global_reassignment() -> None:
|
||||
original_code = """a=1
|
||||
|
|
@ -1674,6 +1700,7 @@ print("Hello world")
|
|||
"""
|
||||
|
||||
modified_code = """import numpy as np
|
||||
|
||||
print("Hello world")
|
||||
a=2
|
||||
print("Hello world")
|
||||
|
|
|
|||
Loading…
Reference in a new issue