refactor: smarter placement of global assignments based on dependencies

Assignments that don't reference module-level definitions are now placed
right after imports. Only assignments that reference classes/functions
are placed after those definitions to prevent NameError.
This commit is contained in:
Kevin Turcios 2026-01-24 06:29:39 -05:00
parent 257c5f2b8f
commit 7b33e8b7f6
5 changed files with 74 additions and 46 deletions

View file

@ -115,6 +115,21 @@ class GlobalFunctionTransformer(cst.CSTTransformer):
return updated_node.with_changes(body=new_statements) return updated_node.with_changes(body=new_statements)
def collect_referenced_names(node: cst.CSTNode) -> set[str]:
"""Collect all names referenced in a CST node using recursive traversal."""
names: set[str] = set()
def _collect(n: cst.CSTNode) -> None:
if isinstance(n, cst.Name):
names.add(n.value)
# Recursively process all children
for child in n.children:
_collect(child)
_collect(node)
return names
class GlobalAssignmentCollector(cst.CSTVisitor): class GlobalAssignmentCollector(cst.CSTVisitor):
"""Collects all global assignment statements.""" """Collects all global assignment statements."""
@ -274,37 +289,69 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
# Find assignments to append # Find assignments to append
assignments_to_append = [ assignments_to_append = [
self.new_assignments[name] (name, self.new_assignments[name])
for name in self.new_assignment_order for name in self.new_assignment_order
if name not in self.processed_assignments and name in self.new_assignments if name not in self.processed_assignments and name in self.new_assignments
] ]
if assignments_to_append: if not assignments_to_append:
# Start after imports, then advance past class/function definitions return updated_node.with_changes(body=new_statements)
# to ensure assignments can reference any classes defined in the module
# Collect all class and function names defined in the module
# These are the names that assignments might reference
module_defined_names: set[str] = set()
for stmt in new_statements:
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
module_defined_names.add(stmt.name.value)
# Partition assignments: those that reference module definitions go at the end,
# those that don't can go right after imports
assignments_after_imports: list[tuple[str, cst.Assign | cst.AnnAssign]] = []
assignments_after_definitions: list[tuple[str, cst.Assign | cst.AnnAssign]] = []
for name, assignment in assignments_to_append:
# Get the value being assigned
if isinstance(assignment, (cst.Assign, cst.AnnAssign)) and assignment.value is not None:
value_node = assignment.value
else:
# No value to analyze, safe to place after imports
assignments_after_imports.append((name, assignment))
continue
# Collect names referenced in the assignment value
referenced_names = collect_referenced_names(value_node)
# Check if any referenced names are module-level definitions
if referenced_names & module_defined_names:
# This assignment references a class/function, place it after definitions
assignments_after_definitions.append((name, assignment))
else:
# Safe to place right after imports
assignments_after_imports.append((name, assignment))
# Insert assignments that don't depend on module definitions right after imports
if assignments_after_imports:
insert_index = find_insertion_index_after_imports(updated_node) insert_index = find_insertion_index_after_imports(updated_node)
assignment_lines = [
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
for _, assignment in assignments_after_imports
]
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
# Insert assignments that depend on module definitions after all class/function definitions
if assignments_after_definitions:
# Find the position after the last function or class definition
insert_index = find_insertion_index_after_imports(cst.Module(body=new_statements))
for i, stmt in enumerate(new_statements): for i, stmt in enumerate(new_statements):
if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)): if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)):
insert_index = i + 1 insert_index = i + 1
assignment_lines = [ assignment_lines = [
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()]) cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
for assignment in assignments_to_append for _, assignment in assignments_after_definitions
] ]
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:])) new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
# Add a blank line after the last assignment if needed
after_index = insert_index + len(assignment_lines)
if after_index < len(new_statements):
next_stmt = new_statements[after_index]
# If there's no empty line, add one
has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines)
if not has_empty:
new_statements[after_index] = next_stmt.with_changes(
leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines]
)
return updated_node.with_changes(body=new_statements) return updated_node.with_changes(body=new_statements)

View file

@ -2975,11 +2975,11 @@ class MyClass:
return cached_helper(5) return cached_helper(5)
""" """
# Global assignments are now inserted AFTER class/function definitions
# to ensure they can reference classes defined in the module
expected = """\ expected = """\
from typing import Any from typing import Any
_LOCAL_CACHE: dict[str, int] = {}
class MyClass: class MyClass:
def method(self): def method(self):
return cached_helper(5) return cached_helper(5)
@ -2992,8 +2992,6 @@ def cached_helper(x: int) -> int:
def regular_helper(): def regular_helper():
return "regular" return "regular"
_LOCAL_CACHE: dict[str, int] = {}
""" """
result = add_global_assignments(source_code, destination_code) result = add_global_assignments(source_code, destination_code)
@ -3111,11 +3109,11 @@ def handle_message(kind):
return "reply" return "reply"
""" """
# Global statements (function calls) should be inserted AFTER all class/function
# definitions to ensure they can reference any function defined in the module
expected = """\ expected = """\
import enum import enum
_factories = {}
class MessageKind(enum.StrEnum): class MessageKind(enum.StrEnum):
ASK = "ask" ASK = "ask"
REPLY = "reply" REPLY = "reply"
@ -3129,8 +3127,6 @@ def handle_message(kind):
def _register(kind, factory): def _register(kind, factory):
_factories[kind] = factory _factories[kind] = factory
_factories = {}
_register(MessageKind.ASK, lambda: "ask handler") _register(MessageKind.ASK, lambda: "ask handler")

View file

@ -2116,12 +2116,9 @@ class NewClass:
print("Hello world") print("Hello world")
``` ```
""" """
# Global assignments are now inserted AFTER class/function definitions
# to ensure they can reference any classes defined in the module.
# This prevents NameError when LLM-generated optimizations like
# `_HANDLERS = {MessageKind.XXX: ...}` reference classes.
expected_code = """import numpy as np expected_code = """import numpy as np
a = 6
if 2<3: if 2<3:
a=4 a=4
else: else:
@ -2143,8 +2140,6 @@ class NewClass:
return "I am still old" return "I am still old"
def new_function2(value): def new_function2(value):
return cst.ensure_type(value, str) return cst.ensure_type(value, str)
a = 6
""" """
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve() code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
code_path.write_text(original_code, encoding="utf-8") code_path.write_text(original_code, encoding="utf-8")
@ -3371,9 +3366,6 @@ def hydrate_input_text_actions_with_field_names(
return updated_actions_by_task return updated_actions_by_task
``` ```
''' '''
# Global assignments are now inserted AFTER class/function definitions
# to ensure they can reference any classes defined in the module.
# This prevents NameError when LLM-generated optimizations reference classes.
expected = '''""" expected = '''"""
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions. Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
""" """
@ -3388,6 +3380,8 @@ from skyvern.forge.sdk.prompting import PromptEngine
from skyvern.webeye.actions.actions import ActionType from skyvern.webeye.actions.actions import ActionType
import re import re
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
LOG = structlog.get_logger(__name__) LOG = structlog.get_logger(__name__)
# Initialize prompt engine # Initialize prompt engine
@ -3441,8 +3435,6 @@ def hydrate_input_text_actions_with_field_names(
updated_actions_by_task[task_id] = updated_actions updated_actions_by_task[task_id] = updated_actions
return updated_actions_by_task return updated_actions_by_task
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
''' '''
func = FunctionToOptimize(function_name="hydrate_input_text_actions_with_field_names", parents=[], file_path=main_file) func = FunctionToOptimize(function_name="hydrate_input_text_actions_with_field_names", parents=[], file_path=main_file)

View file

@ -218,10 +218,6 @@ def test_no_targets_found() -> None:
def target(self): def target(self):
pass pass
""" """
# Nested class methods (MyClass.Inner.target) aren't directly targetable,
# but the outer class is kept when the qualified name starts with it.
# This is because the dependency tracking marks "MyClass" as used when it
# sees "MyClass.Inner.target" as a target function.
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"}) result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
expected = dedent(""" expected = dedent("""
class MyClass: class MyClass:

View file

@ -124,15 +124,14 @@ def _get_string_usage(text: str) -> Usage:
helper_file.unlink(missing_ok=True) helper_file.unlink(missing_ok=True)
main_file.unlink(missing_ok=True) main_file.unlink(missing_ok=True)
# Global assignments are now inserted AFTER class/function definitions
# to prevent NameError when they reference classes or functions.
# See commit 50fba096 for details.
expected_helper = """import re expected_helper = """import re
from collections.abc import Sequence from collections.abc import Sequence
from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent
_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+') _TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int: def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
@ -159,8 +158,6 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
tokens += len(part.data) tokens += len(part.data)
return tokens return tokens
_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
""" """
assert new_code.rstrip() == original_main.rstrip() # No Change assert new_code.rstrip() == original_main.rstrip() # No Change