From 7b33e8b7f6b467658a128aeb5d23da608356ddda Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sat, 24 Jan 2026 06:29:39 -0500 Subject: [PATCH] refactor: smarter placement of global assignments based on dependencies Assignments that don't reference module-level definitions are now placed right after imports. Only assignments that reference classes/functions are placed after those definitions to prevent NameError. --- codeflash/code_utils/code_extractor.py | 81 ++++++++++++++++++----- tests/test_code_context_extractor.py | 12 ++-- tests/test_code_replacement.py | 14 +--- tests/test_get_read_writable_code.py | 4 -- tests/test_multi_file_code_replacement.py | 9 +-- 5 files changed, 74 insertions(+), 46 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index a7dd08fe9..6ddfe763a 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -115,6 +115,21 @@ class GlobalFunctionTransformer(cst.CSTTransformer): return updated_node.with_changes(body=new_statements) +def collect_referenced_names(node: cst.CSTNode) -> set[str]: + """Collect all names referenced in a CST node using recursive traversal.""" + names: set[str] = set() + + def _collect(n: cst.CSTNode) -> None: + if isinstance(n, cst.Name): + names.add(n.value) + # Recursively process all children + for child in n.children: + _collect(child) + + _collect(node) + return names + + class GlobalAssignmentCollector(cst.CSTVisitor): """Collects all global assignment statements.""" @@ -274,37 +289,69 @@ class GlobalAssignmentTransformer(cst.CSTTransformer): # Find assignments to append assignments_to_append = [ - self.new_assignments[name] + (name, self.new_assignments[name]) for name in self.new_assignment_order if name not in self.processed_assignments and name in self.new_assignments ] - if assignments_to_append: - # Start after imports, then advance past class/function definitions - # to ensure assignments can reference any classes defined in the module + if not assignments_to_append: + return updated_node.with_changes(body=new_statements) + + # Collect all class and function names defined in the module + # These are the names that assignments might reference + module_defined_names: set[str] = set() + for stmt in new_statements: + if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)): + module_defined_names.add(stmt.name.value) + + # Partition assignments: those that reference module definitions go at the end, + # those that don't can go right after imports + assignments_after_imports: list[tuple[str, cst.Assign | cst.AnnAssign]] = [] + assignments_after_definitions: list[tuple[str, cst.Assign | cst.AnnAssign]] = [] + + for name, assignment in assignments_to_append: + # Get the value being assigned + if isinstance(assignment, (cst.Assign, cst.AnnAssign)) and assignment.value is not None: + value_node = assignment.value + else: + # No value to analyze, safe to place after imports + assignments_after_imports.append((name, assignment)) + continue + + # Collect names referenced in the assignment value + referenced_names = collect_referenced_names(value_node) + + # Check if any referenced names are module-level definitions + if referenced_names & module_defined_names: + # This assignment references a class/function, place it after definitions + assignments_after_definitions.append((name, assignment)) + else: + # Safe to place right after imports + assignments_after_imports.append((name, assignment)) + + # Insert assignments that don't depend on module definitions right after imports + if assignments_after_imports: insert_index = find_insertion_index_after_imports(updated_node) + assignment_lines = [ + cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()]) + for _, assignment in assignments_after_imports + ] + new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:])) + + # Insert assignments that depend on module definitions after all class/function definitions + if assignments_after_definitions: + # Find the position after the last function or class definition + insert_index = find_insertion_index_after_imports(cst.Module(body=new_statements)) for i, stmt in enumerate(new_statements): if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)): insert_index = i + 1 assignment_lines = [ cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()]) - for assignment in assignments_to_append + for _, assignment in assignments_after_definitions ] - new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:])) - # Add a blank line after the last assignment if needed - after_index = insert_index + len(assignment_lines) - if after_index < len(new_statements): - next_stmt = new_statements[after_index] - # If there's no empty line, add one - has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines) - if not has_empty: - new_statements[after_index] = next_stmt.with_changes( - leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines] - ) - return updated_node.with_changes(body=new_statements) diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 57a951660..769e10a8c 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -2975,11 +2975,11 @@ class MyClass: return cached_helper(5) """ - # Global assignments are now inserted AFTER class/function definitions - # to ensure they can reference classes defined in the module expected = """\ from typing import Any +_LOCAL_CACHE: dict[str, int] = {} + class MyClass: def method(self): return cached_helper(5) @@ -2992,8 +2992,6 @@ def cached_helper(x: int) -> int: def regular_helper(): return "regular" - -_LOCAL_CACHE: dict[str, int] = {} """ result = add_global_assignments(source_code, destination_code) @@ -3111,11 +3109,11 @@ def handle_message(kind): return "reply" """ - # Global statements (function calls) should be inserted AFTER all class/function - # definitions to ensure they can reference any function defined in the module expected = """\ import enum +_factories = {} + class MessageKind(enum.StrEnum): ASK = "ask" REPLY = "reply" @@ -3129,8 +3127,6 @@ def handle_message(kind): def _register(kind, factory): _factories[kind] = factory -_factories = {} - _register(MessageKind.ASK, lambda: "ask handler") diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index f836f3d40..da83146a8 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -2116,12 +2116,9 @@ class NewClass: print("Hello world") ``` """ - # Global assignments are now inserted AFTER class/function definitions - # to ensure they can reference any classes defined in the module. - # This prevents NameError when LLM-generated optimizations like - # `_HANDLERS = {MessageKind.XXX: ...}` reference classes. expected_code = """import numpy as np +a = 6 if 2<3: a=4 else: @@ -2143,8 +2140,6 @@ class NewClass: return "I am still old" def new_function2(value): return cst.ensure_type(value, str) - -a = 6 """ code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve() code_path.write_text(original_code, encoding="utf-8") @@ -3371,9 +3366,6 @@ def hydrate_input_text_actions_with_field_names( return updated_actions_by_task ``` ''' - # Global assignments are now inserted AFTER class/function definitions - # to ensure they can reference any classes defined in the module. - # This prevents NameError when LLM-generated optimizations reference classes. expected = '''""" Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions. """ @@ -3388,6 +3380,8 @@ from skyvern.forge.sdk.prompting import PromptEngine from skyvern.webeye.actions.actions import ActionType import re +_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+") + LOG = structlog.get_logger(__name__) # Initialize prompt engine @@ -3441,8 +3435,6 @@ def hydrate_input_text_actions_with_field_names( updated_actions_by_task[task_id] = updated_actions return updated_actions_by_task - -_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+") ''' func = FunctionToOptimize(function_name="hydrate_input_text_actions_with_field_names", parents=[], file_path=main_file) diff --git a/tests/test_get_read_writable_code.py b/tests/test_get_read_writable_code.py index f08182427..952479d3a 100644 --- a/tests/test_get_read_writable_code.py +++ b/tests/test_get_read_writable_code.py @@ -218,10 +218,6 @@ def test_no_targets_found() -> None: def target(self): pass """ - # Nested class methods (MyClass.Inner.target) aren't directly targetable, - # but the outer class is kept when the qualified name starts with it. - # This is because the dependency tracking marks "MyClass" as used when it - # sees "MyClass.Inner.target" as a target function. result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"}) expected = dedent(""" class MyClass: diff --git a/tests/test_multi_file_code_replacement.py b/tests/test_multi_file_code_replacement.py index a1367be70..2d1f22509 100644 --- a/tests/test_multi_file_code_replacement.py +++ b/tests/test_multi_file_code_replacement.py @@ -124,15 +124,14 @@ def _get_string_usage(text: str) -> Usage: helper_file.unlink(missing_ok=True) main_file.unlink(missing_ok=True) - - # Global assignments are now inserted AFTER class/function definitions - # to prevent NameError when they reference classes or functions. - # See commit 50fba096 for details. + expected_helper = """import re from collections.abc import Sequence from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent +_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'} + _TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+') def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int: @@ -159,8 +158,6 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int: tokens += len(part.data) return tokens - -_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'} """ assert new_code.rstrip() == original_main.rstrip() # No Change