tests work now

This commit is contained in:
aseembits93 2025-04-30 18:14:00 -07:00
parent 18c989537e
commit abbaec3b5e
2 changed files with 143 additions and 37 deletions

View file

@ -2,7 +2,7 @@ from __future__ import annotations
import ast
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Dict, Optional, Set
import libcst as cst
import libcst.matchers as m
@ -20,20 +20,106 @@ if TYPE_CHECKING:
from typing import List, Union
class ImportCollector(cst.CSTVisitor):
"""Visitor that collects all import statements in a module."""
class GlobalAssignmentCollector(cst.CSTVisitor):
"""Collects all global assignment statements."""
def __init__(self):
super().__init__()
self.imports = []
self.assignments: Dict[str, cst.Assign] = {}
self.assignment_order: List[str] = []
# Track scope depth to identify global assignments
self.scope_depth = 0
def visit_Import(self, node: cst.Import) -> None:
self.imports.append(node)
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
self.scope_depth += 1
return True
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
self.imports.append(node)
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
self.scope_depth -= 1
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
self.scope_depth += 1
return True
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
self.scope_depth -= 1
def visit_Assign(self, node: cst.Assign) -> Optional[bool]:
# Only process global assignments (not inside functions, classes, etc.)
if self.scope_depth == 0: # We're at module level
for target in node.targets:
if isinstance(target.target, cst.Name):
name = target.target.value
self.assignments[name] = node
if name not in self.assignment_order:
self.assignment_order.append(name)
return True
class GlobalAssignmentTransformer(cst.CSTTransformer):
"""Transforms global assignments in the original file with those from the new file."""
def __init__(self, new_assignments: Dict[str, cst.Assign], new_assignment_order: List[str]):
super().__init__()
self.new_assignments = new_assignments
self.new_assignment_order = new_assignment_order
self.processed_assignments: Set[str] = set()
self.scope_depth = 0
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
self.scope_depth += 1
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
self.scope_depth -= 1
return updated_node
def visit_ClassDef(self, node: cst.ClassDef) -> None:
self.scope_depth += 1
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
self.scope_depth -= 1
return updated_node
def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.CSTNode:
if self.scope_depth > 0:
return updated_node
# Check if this is a global assignment we need to replace
for target in original_node.targets:
if isinstance(target.target, cst.Name):
name = target.target.value
if name in self.new_assignments:
self.processed_assignments.add(name)
return self.new_assignments[name]
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
# Add any new assignments that weren't in the original file
new_statements = list(updated_node.body)
# Find assignments to append
assignments_to_append = []
for name in self.new_assignment_order:
if name not in self.processed_assignments and name in self.new_assignments:
assignments_to_append.append(self.new_assignments[name])
if assignments_to_append:
# Add a blank line before appending new assignments if needed
if new_statements and not isinstance(new_statements[-1], cst.EmptyLine):
new_statements.append(cst.SimpleStatementLine([cst.Pass()], leading_lines=[cst.EmptyLine()]))
new_statements.pop() # Remove the Pass statement but keep the empty line
# Add the new assignments
for assignment in assignments_to_append:
new_statements.append(
cst.SimpleStatementLine(
[assignment],
leading_lines=[cst.EmptyLine()]
)
)
return updated_node.with_changes(body=new_statements)
class GlobalStatementCollector(cst.CSTVisitor):
"""Visitor that collects all global statements (excluding imports and functions/classes)."""
@ -63,7 +149,7 @@ class GlobalStatementCollector(cst.CSTVisitor):
if not self.in_function_or_class:
for statement in node.body:
# Skip imports
if not isinstance(statement, (cst.Import, cst.ImportFrom)):
if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign)):
self.global_statements.append(node)
break
@ -130,28 +216,6 @@ def find_last_import_line(target_code: str) -> int:
module.visit(finder)
return finder.last_import_line
def merge_globals(source_code: str, target_code: str) -> str:
"""Merge global statements from source into target just after imports."""
# Extract global statements from source
global_statements = extract_global_statements(source_code)
# Find the last import line in target
last_import_line = find_last_import_line(target_code)
# Parse the target code
target_module = cst.parse_module(target_code)
# Create transformer to insert global statements
transformer = ImportInserter(global_statements, last_import_line)
# Apply transformation
modified_module = target_module.visit(transformer)
# Return the modified code
return modified_module.code
class FutureAliasedImportTransformer(cst.CSTTransformer):
def leave_ImportFrom(
self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
@ -180,7 +244,7 @@ def add_needed_imports_from_module(
helper_functions: list[FunctionSource] | None = None,
helper_functions_fqn: set[str] | None = None,
) -> str:
global_statements = extract_global_statements(src_module_code)
non_assignment_global_statements = extract_global_statements(src_module_code)
# Find the last import line in target
last_import_line = find_last_import_line(dst_module_code)
@ -188,12 +252,27 @@ def add_needed_imports_from_module(
# Parse the target code
target_module = cst.parse_module(dst_module_code)
# Create transformer to insert global statements
transformer = ImportInserter(global_statements, last_import_line)
# Create transformer to insert non_assignment_global_statements
transformer = ImportInserter(non_assignment_global_statements, last_import_line)
#
# # Apply transformation
modified_module = target_module.visit(transformer)
dst_module_code = modified_module.code
# Parse the code
original_module = cst.parse_module(dst_module_code)
new_module = cst.parse_module(src_module_code)
# Collect assignments from the new file
new_collector = GlobalAssignmentCollector()
new_module.visit(new_collector)
# Transform the original file
transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
transformed_module = original_module.visit(transformer)
dst_module_code = transformed_module.code
"""Add all needed and used source module code imports to the destination module code, and return it."""
src_module_code = delete___future___aliased_imports(src_module_code)
if not helper_functions_fqn:

View file

@ -789,7 +789,22 @@ class MainClass:
def test_code_replacement10() -> None:
get_code_output = 'from __future__ import annotations\nimport os\n\nos.environ["CODEFLASH_API_KEY"] = "cf-test-key"\nclass HelperClass:\n def __init__(self, name):\n self.name = name\n\n def helper_method(self):\n return self.name\n\n\nclass MainClass:\n def __init__(self, name):\n self.name = name\n\n def main_method(self):\n return HelperClass(self.name).helper_method()\n'
get_code_output = """from __future__ import annotations
class HelperClass:
def __init__(self, name):
self.name = name
def helper_method(self):
return self.name
class MainClass:
def __init__(self, name):
self.name = name
def main_method(self):
return HelperClass(self.name).helper_method()"""
file_path = Path(__file__).resolve()
func_top_optimize = FunctionToOptimize(
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
@ -1636,6 +1651,17 @@ print("Hello world")
print("Hello world")
"""
modified_code = """print("Hello world")
class NewClass:
def __init__(self, name):
self.name = name
def __call__(self, value):
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
print("Hello world")
"""
function_names: list[str] = ["NewClass.__init__", "NewClass.__call__", "NewClass.new_function2"]
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
@ -1646,7 +1672,7 @@ print("Hello world")
preexisting_objects=preexisting_objects,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == original_code
assert new_code == modified_code
def test_global_reassignment() -> None:
original_code = """a=1
@ -1674,6 +1700,7 @@ print("Hello world")
"""
modified_code = """import numpy as np
print("Hello world")
a=2
print("Hello world")