mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
fix: insert global statements after function definitions to prevent NameError
When LLM-generated optimizations include module-level function calls like `_register(MessageKind.ASK, ...)`, they were being inserted right after imports, BEFORE the function definition they reference, causing NameError at module load time. Changes: - Add GlobalStatementTransformer to append global statements at module end - Reorder transformations: functions → assignments → statements - Remove unused ImportInserter class - Update test expectations to reflect new placement behavior
This commit is contained in:
parent
abfa640578
commit
50fba096f7
3 changed files with 143 additions and 59 deletions
|
|
@ -308,6 +308,39 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
|
|||
return updated_node.with_changes(body=new_statements)
|
||||
|
||||
|
||||
class GlobalStatementTransformer(cst.CSTTransformer):
|
||||
"""Transformer that appends global statements at the end of the module.
|
||||
|
||||
This ensures that global statements (like function calls at module level) are placed
|
||||
after all functions, classes, and assignments they might reference, preventing NameError
|
||||
at module load time.
|
||||
|
||||
This transformer should be run LAST after GlobalFunctionTransformer and
|
||||
GlobalAssignmentTransformer have already added their content.
|
||||
"""
|
||||
|
||||
def __init__(self, global_statements: list[cst.SimpleStatementLine]) -> None:
|
||||
super().__init__()
|
||||
self.global_statements = global_statements
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
if not self.global_statements:
|
||||
return updated_node
|
||||
|
||||
new_statements = list(updated_node.body)
|
||||
|
||||
# Add empty line before each statement for readability
|
||||
statement_lines = [
|
||||
stmt.with_changes(leading_lines=[cst.EmptyLine(), *stmt.leading_lines]) for stmt in self.global_statements
|
||||
]
|
||||
|
||||
# Append statements at the end of the module
|
||||
# This ensures they come after all functions, classes, and assignments
|
||||
new_statements.extend(statement_lines)
|
||||
|
||||
return updated_node.with_changes(body=new_statements)
|
||||
|
||||
|
||||
class GlobalStatementCollector(cst.CSTVisitor):
|
||||
"""Visitor that collects all global statements (excluding imports and functions/classes)."""
|
||||
|
||||
|
|
@ -431,40 +464,6 @@ class DottedImportCollector(cst.CSTVisitor):
|
|||
self._collect_imports_from_block(node.body)
|
||||
|
||||
|
||||
class ImportInserter(cst.CSTTransformer):
|
||||
"""Transformer that inserts global statements after the last import."""
|
||||
|
||||
def __init__(self, global_statements: list[cst.SimpleStatementLine], last_import_line: int) -> None:
|
||||
super().__init__()
|
||||
self.global_statements = global_statements
|
||||
self.last_import_line = last_import_line
|
||||
self.current_line = 0
|
||||
self.inserted = False
|
||||
|
||||
def leave_SimpleStatementLine(
|
||||
self,
|
||||
original_node: cst.SimpleStatementLine, # noqa: ARG002
|
||||
updated_node: cst.SimpleStatementLine,
|
||||
) -> cst.Module:
|
||||
self.current_line += 1
|
||||
|
||||
# If we're right after the last import and haven't inserted yet
|
||||
if self.current_line == self.last_import_line and not self.inserted:
|
||||
self.inserted = True
|
||||
return cst.Module(body=[updated_node, *self.global_statements])
|
||||
|
||||
return cst.Module(body=[updated_node])
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
# If there were no imports, add at the beginning of the module
|
||||
if self.last_import_line == 0 and not self.inserted:
|
||||
updated_body = list(updated_node.body)
|
||||
for stmt in reversed(self.global_statements):
|
||||
updated_body.insert(0, stmt)
|
||||
return updated_node.with_changes(body=updated_body)
|
||||
return updated_node
|
||||
|
||||
|
||||
def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]:
|
||||
"""Extract global statements from source code."""
|
||||
module = cst.parse_module(source_code)
|
||||
|
|
@ -516,20 +515,8 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
|
|||
continue
|
||||
unique_global_statements.append(stmt)
|
||||
|
||||
mod_dst_code = dst_module_code
|
||||
# Insert unique global statements if any
|
||||
if unique_global_statements:
|
||||
last_import_line = find_last_import_line(dst_module_code)
|
||||
# Reuse already-parsed dst_module
|
||||
transformer = ImportInserter(unique_global_statements, last_import_line)
|
||||
# Use visit inplace, don't parse again
|
||||
modified_module = dst_module.visit(transformer)
|
||||
mod_dst_code = modified_module.code
|
||||
# Parse the code after insertion
|
||||
original_module = cst.parse_module(mod_dst_code)
|
||||
else:
|
||||
# No new statements to insert, reuse already-parsed dst_module
|
||||
original_module = dst_module
|
||||
# Reuse already-parsed dst_module
|
||||
original_module = dst_module
|
||||
|
||||
# Parse the src_module_code once only (already done above: src_module)
|
||||
# Collect assignments from the new file
|
||||
|
|
@ -551,9 +538,19 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
|
|||
}
|
||||
new_function_order = [name for name in src_function_collector.function_order if name in new_functions]
|
||||
|
||||
# If there are no assignments and no new functions, return the current code
|
||||
if not new_assignment_collector.assignments and not new_functions:
|
||||
return mod_dst_code
|
||||
# If there are no assignments, no new functions, and no global statements, return unchanged
|
||||
if not new_assignment_collector.assignments and not new_functions and not unique_global_statements:
|
||||
return dst_module_code
|
||||
|
||||
# The order of transformations matters:
|
||||
# 1. Functions first - so assignments and statements can reference them
|
||||
# 2. Assignments second - so they come after functions but before statements
|
||||
# 3. Global statements last - so they can reference both functions and assignments
|
||||
|
||||
# Transform functions if any
|
||||
if new_functions:
|
||||
function_transformer = GlobalFunctionTransformer(new_functions, new_function_order)
|
||||
original_module = original_module.visit(function_transformer)
|
||||
|
||||
# Transform assignments if any
|
||||
if new_assignment_collector.assignments:
|
||||
|
|
@ -562,10 +559,12 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
|
|||
)
|
||||
original_module = original_module.visit(transformer)
|
||||
|
||||
# Transform functions if any
|
||||
if new_functions:
|
||||
function_transformer = GlobalFunctionTransformer(new_functions, new_function_order)
|
||||
original_module = original_module.visit(function_transformer)
|
||||
# Insert global statements (like function calls at module level) LAST,
|
||||
# after all functions and assignments are added, to ensure they can reference any
|
||||
# functions or variables defined in the module
|
||||
if unique_global_statements:
|
||||
statement_transformer = GlobalStatementTransformer(unique_global_statements)
|
||||
original_module = original_module.visit(statement_transformer)
|
||||
|
||||
return original_module.code
|
||||
|
||||
|
|
|
|||
|
|
@ -3063,6 +3063,84 @@ _MESSAGE_HANDLERS = {
|
|||
assert result == expected
|
||||
|
||||
|
||||
def test_add_global_assignments_function_calls_after_function_definitions():
|
||||
"""Test that global function calls are placed after the functions they reference.
|
||||
|
||||
This test verifies the fix for a bug where LLM-generated optimization code like:
|
||||
def _register(kind, factory):
|
||||
_factories[kind] = factory
|
||||
|
||||
_register(MessageKind.ASK, lambda: "ask")
|
||||
|
||||
would have the _register(...) calls placed BEFORE the _register function definition,
|
||||
causing NameError at module load time.
|
||||
|
||||
The fix ensures that new global statements (like function calls) are inserted AFTER
|
||||
all class/function definitions, so they can safely reference any function defined in
|
||||
the module.
|
||||
"""
|
||||
source_code = """\
|
||||
import enum
|
||||
|
||||
class MessageKind(enum.StrEnum):
|
||||
ASK = "ask"
|
||||
REPLY = "reply"
|
||||
|
||||
_factories = {}
|
||||
|
||||
def _register(kind, factory):
|
||||
_factories[kind] = factory
|
||||
|
||||
_register(MessageKind.ASK, lambda: "ask handler")
|
||||
_register(MessageKind.REPLY, lambda: "reply handler")
|
||||
|
||||
def handle_message(kind):
|
||||
return _factories[kind]()
|
||||
"""
|
||||
|
||||
destination_code = """\
|
||||
import enum
|
||||
|
||||
class MessageKind(enum.StrEnum):
|
||||
ASK = "ask"
|
||||
REPLY = "reply"
|
||||
|
||||
def handle_message(kind):
|
||||
if kind == MessageKind.ASK:
|
||||
return "ask"
|
||||
return "reply"
|
||||
"""
|
||||
|
||||
# Global statements (function calls) should be inserted AFTER all class/function
|
||||
# definitions to ensure they can reference any function defined in the module
|
||||
expected = """\
|
||||
import enum
|
||||
|
||||
class MessageKind(enum.StrEnum):
|
||||
ASK = "ask"
|
||||
REPLY = "reply"
|
||||
|
||||
def handle_message(kind):
|
||||
if kind == MessageKind.ASK:
|
||||
return "ask"
|
||||
return "reply"
|
||||
|
||||
|
||||
def _register(kind, factory):
|
||||
_factories[kind] = factory
|
||||
|
||||
_factories = {}
|
||||
|
||||
|
||||
_register(MessageKind.ASK, lambda: "ask handler")
|
||||
|
||||
_register(MessageKind.REPLY, lambda: "reply handler")
|
||||
"""
|
||||
|
||||
result = add_global_assignments(source_code, destination_code)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None:
|
||||
"""Test that when a class is instantiated, its __init__ method is tracked as a helper.
|
||||
|
||||
|
|
|
|||
|
|
@ -2116,10 +2116,12 @@ class NewClass:
|
|||
print("Hello world")
|
||||
```
|
||||
"""
|
||||
# Global assignments are now inserted AFTER class/function definitions
|
||||
# to ensure they can reference any classes defined in the module.
|
||||
# This prevents NameError when LLM-generated optimizations like
|
||||
# `_HANDLERS = {MessageKind.XXX: ...}` reference classes.
|
||||
expected_code = """import numpy as np
|
||||
|
||||
a = 6
|
||||
|
||||
if 2<3:
|
||||
a=4
|
||||
else:
|
||||
|
|
@ -2141,6 +2143,8 @@ class NewClass:
|
|||
return "I am still old"
|
||||
def new_function2(value):
|
||||
return cst.ensure_type(value, str)
|
||||
|
||||
a = 6
|
||||
"""
|
||||
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
|
||||
code_path.write_text(original_code, encoding="utf-8")
|
||||
|
|
@ -3367,6 +3371,9 @@ def hydrate_input_text_actions_with_field_names(
|
|||
return updated_actions_by_task
|
||||
```
|
||||
'''
|
||||
# Global assignments are now inserted AFTER class/function definitions
|
||||
# to ensure they can reference any classes defined in the module.
|
||||
# This prevents NameError when LLM-generated optimizations reference classes.
|
||||
expected = '''"""
|
||||
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
|
||||
"""
|
||||
|
|
@ -3381,8 +3388,6 @@ from skyvern.forge.sdk.prompting import PromptEngine
|
|||
from skyvern.webeye.actions.actions import ActionType
|
||||
import re
|
||||
|
||||
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
|
||||
|
||||
LOG = structlog.get_logger(__name__)
|
||||
|
||||
# Initialize prompt engine
|
||||
|
|
@ -3436,6 +3441,8 @@ def hydrate_input_text_actions_with_field_names(
|
|||
updated_actions_by_task[task_id] = updated_actions
|
||||
|
||||
return updated_actions_by_task
|
||||
|
||||
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
|
||||
'''
|
||||
|
||||
func = FunctionToOptimize(function_name="hydrate_input_text_actions_with_field_names", parents=[], file_path=main_file)
|
||||
|
|
|
|||
Loading…
Reference in a new issue