fix: insert global statements after function definitions to prevent NameError

Previously, when LLM-generated optimizations included module-level function calls
such as `_register(MessageKind.ASK, ...)`, those calls were inserted right after
the imports, BEFORE the definition of the function they reference, causing a
NameError at module load time.

Changes:
- Add GlobalStatementTransformer to append global statements at module end
- Reorder transformations: functions → assignments → statements
- Remove unused ImportInserter class
- Update test expectations to reflect new placement behavior
This commit is contained in:
Kevin Turcios 2026-01-24 02:09:38 -05:00
parent abfa640578
commit 50fba096f7
3 changed files with 143 additions and 59 deletions

View file

@ -308,6 +308,39 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
return updated_node.with_changes(body=new_statements)
class GlobalStatementTransformer(cst.CSTTransformer):
    """Append a given list of global statements to the end of a module.

    Placing the statements (for example module-level function calls) at the very
    end of the module body means they follow every function, class and assignment
    already present, so they cannot trigger a NameError at module load time by
    running before the names they use are defined.

    Run this transformer LAST, after GlobalFunctionTransformer and
    GlobalAssignmentTransformer have added their own content.
    """

    def __init__(self, global_statements: list[cst.SimpleStatementLine]) -> None:
        """Store the statements that will be appended when the module is left."""
        super().__init__()
        self.global_statements = global_statements

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:  # noqa: ARG002
        """Return the module with the stored statements appended to its body.

        The module is returned untouched when there is nothing to append.
        """
        pending = self.global_statements
        if not pending:
            return updated_node

        appended = []
        for statement in pending:
            # Prefix one blank line to each statement for readability, keeping
            # whatever leading lines (comments/blank lines) it already carried.
            spaced = statement.with_changes(leading_lines=[cst.EmptyLine(), *statement.leading_lines])
            appended.append(spaced)

        # Existing body first, new statements after it.
        return updated_node.with_changes(body=[*updated_node.body, *appended])
class GlobalStatementCollector(cst.CSTVisitor):
"""Visitor that collects all global statements (excluding imports and functions/classes)."""
@ -431,40 +464,6 @@ class DottedImportCollector(cst.CSTVisitor):
self._collect_imports_from_block(node.body)
class ImportInserter(cst.CSTTransformer):
    """Transformer that inserts global statements after the last import.

    The transformer counts every ``SimpleStatementLine`` it leaves and, when the
    count reaches ``last_import_line``, emits the configured statements directly
    after that statement. If ``last_import_line`` is 0 the statements are placed
    at the top of the module instead.

    NOTE(review): ``current_line`` is a count of simple statement lines visited,
    not a source line number - confirm the caller supplies ``last_import_line``
    in the same units.
    """

    def __init__(self, global_statements: list[cst.SimpleStatementLine], last_import_line: int) -> None:
        """Record the statements to insert and the position to insert them at."""
        super().__init__()
        self.global_statements = global_statements
        self.last_import_line = last_import_line
        self.current_line = 0  # number of SimpleStatementLine nodes left so far
        self.inserted = False  # guards against inserting more than once

    def leave_SimpleStatementLine(
        self,
        original_node: cst.SimpleStatementLine,  # noqa: ARG002
        updated_node: cst.SimpleStatementLine,
    ) -> "cst.BaseStatement | cst.FlattenSentinel[cst.BaseStatement]":
        """Insert the global statements right after the targeted statement.

        Returns the statement unchanged, except at the insertion point where a
        ``FlattenSentinel`` splices the statement plus the global statements into
        the parent body. A ``FlattenSentinel`` (rather than a nested ``Module``)
        keeps the resulting tree well-formed.
        """
        self.current_line += 1
        # If we're right after the last import and haven't inserted yet
        if self.current_line == self.last_import_line and not self.inserted:
            self.inserted = True
            return cst.FlattenSentinel([updated_node, *self.global_statements])
        return updated_node

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:  # noqa: ARG002
        """Prepend the global statements when the module has no imports."""
        # If there were no imports, add at the beginning of the module,
        # preserving the statements' relative order.
        if self.last_import_line == 0 and not self.inserted:
            return updated_node.with_changes(body=[*self.global_statements, *updated_node.body])
        return updated_node
def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]:
"""Extract global statements from source code."""
module = cst.parse_module(source_code)
@ -516,20 +515,8 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
continue
unique_global_statements.append(stmt)
mod_dst_code = dst_module_code
# Insert unique global statements if any
if unique_global_statements:
last_import_line = find_last_import_line(dst_module_code)
# Reuse already-parsed dst_module
transformer = ImportInserter(unique_global_statements, last_import_line)
# Use visit inplace, don't parse again
modified_module = dst_module.visit(transformer)
mod_dst_code = modified_module.code
# Parse the code after insertion
original_module = cst.parse_module(mod_dst_code)
else:
# No new statements to insert, reuse already-parsed dst_module
original_module = dst_module
# Reuse already-parsed dst_module
original_module = dst_module
# Parse the src_module_code once only (already done above: src_module)
# Collect assignments from the new file
@ -551,9 +538,19 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
}
new_function_order = [name for name in src_function_collector.function_order if name in new_functions]
# If there are no assignments and no new functions, return the current code
if not new_assignment_collector.assignments and not new_functions:
return mod_dst_code
# If there are no assignments, no new functions, and no global statements, return unchanged
if not new_assignment_collector.assignments and not new_functions and not unique_global_statements:
return dst_module_code
# The order of transformations matters:
# 1. Functions first - so assignments and statements can reference them
# 2. Assignments second - so they come after functions but before statements
# 3. Global statements last - so they can reference both functions and assignments
# Transform functions if any
if new_functions:
function_transformer = GlobalFunctionTransformer(new_functions, new_function_order)
original_module = original_module.visit(function_transformer)
# Transform assignments if any
if new_assignment_collector.assignments:
@ -562,10 +559,12 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
)
original_module = original_module.visit(transformer)
# Transform functions if any
if new_functions:
function_transformer = GlobalFunctionTransformer(new_functions, new_function_order)
original_module = original_module.visit(function_transformer)
# Insert global statements (like function calls at module level) LAST,
# after all functions and assignments are added, to ensure they can reference any
# functions or variables defined in the module
if unique_global_statements:
statement_transformer = GlobalStatementTransformer(unique_global_statements)
original_module = original_module.visit(statement_transformer)
return original_module.code

View file

@ -3063,6 +3063,84 @@ _MESSAGE_HANDLERS = {
assert result == expected
def test_add_global_assignments_function_calls_after_function_definitions():
"""Test that global function calls are placed after the functions they reference.
This test verifies the fix for a bug where LLM-generated optimization code like:
def _register(kind, factory):
_factories[kind] = factory
_register(MessageKind.ASK, lambda: "ask")
would have the _register(...) calls placed BEFORE the _register function definition,
causing NameError at module load time.
The fix ensures that new global statements (like function calls) are inserted AFTER
all class/function definitions, so they can safely reference any function defined in
the module.
The expected output keeps the destination code as-is and appends, in this order,
the new function (_register), the new assignment (_factories), and finally the
new global statements (the _register(...) calls).
"""
source_code = """\
import enum
class MessageKind(enum.StrEnum):
ASK = "ask"
REPLY = "reply"
_factories = {}
def _register(kind, factory):
_factories[kind] = factory
_register(MessageKind.ASK, lambda: "ask handler")
_register(MessageKind.REPLY, lambda: "reply handler")
def handle_message(kind):
return _factories[kind]()
"""
destination_code = """\
import enum
class MessageKind(enum.StrEnum):
ASK = "ask"
REPLY = "reply"
def handle_message(kind):
if kind == MessageKind.ASK:
return "ask"
return "reply"
"""
# Global statements (function calls) should be inserted AFTER all class/function
# definitions to ensure they can reference any function defined in the module
expected = """\
import enum
class MessageKind(enum.StrEnum):
ASK = "ask"
REPLY = "reply"
def handle_message(kind):
if kind == MessageKind.ASK:
return "ask"
return "reply"
def _register(kind, factory):
_factories[kind] = factory
_factories = {}
_register(MessageKind.ASK, lambda: "ask handler")
_register(MessageKind.REPLY, lambda: "reply handler")
"""
result = add_global_assignments(source_code, destination_code)
assert result == expected
def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None:
"""Test that when a class is instantiated, its __init__ method is tracked as a helper.

View file

@ -2116,10 +2116,12 @@ class NewClass:
print("Hello world")
```
"""
# Global assignments are now inserted AFTER class/function definitions
# to ensure they can reference any classes defined in the module.
# This prevents NameError when LLM-generated optimizations like
# `_HANDLERS = {MessageKind.XXX: ...}` reference classes.
expected_code = """import numpy as np
a = 6
if 2<3:
a=4
else:
@ -2141,6 +2143,8 @@ class NewClass:
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
a = 6
"""
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
code_path.write_text(original_code, encoding="utf-8")
@ -3367,6 +3371,9 @@ def hydrate_input_text_actions_with_field_names(
return updated_actions_by_task
```
'''
# Global assignments are now inserted AFTER class/function definitions
# to ensure they can reference any classes defined in the module.
# This prevents NameError when LLM-generated optimizations reference classes.
expected = '''"""
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
"""
@ -3381,8 +3388,6 @@ from skyvern.forge.sdk.prompting import PromptEngine
from skyvern.webeye.actions.actions import ActionType
import re
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
LOG = structlog.get_logger(__name__)
# Initialize prompt engine
@ -3436,6 +3441,8 @@ def hydrate_input_text_actions_with_field_names(
updated_actions_by_task[task_id] = updated_actions
return updated_actions_by_task
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
'''
func = FunctionToOptimize(function_name="hydrate_input_text_actions_with_field_names", parents=[], file_path=main_file)