fix: insert global assignments after class definitions to prevent NameError

When an LLM-generated optimization included module-level code such as
`_REIFIERS = {MessageKind.XXX: ...}`, the global assignment was
inserted right after the imports, BEFORE the class definition it referenced,
causing a NameError at module load time.

Changes:
- GlobalAssignmentTransformer now inserts assignments after the last
  class/function definition (falling back to right after imports when
  there are none) instead of always right after imports
- GlobalStatementCollector now skips AnnAssign (annotated assignments)
  so they are handled by GlobalAssignmentCollector instead
This commit is contained in:
Kevin Turcios 2026-01-24 01:37:15 -05:00
parent 1bb9d147f4
commit abfa640578
2 changed files with 74 additions and 5 deletions

View file

@ -280,8 +280,12 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
]
if assignments_to_append:
# after last top-level imports
# Start after imports, then advance past class/function definitions
# to ensure assignments can reference any classes defined in the module
insert_index = find_insertion_index_after_imports(updated_node)
for i, stmt in enumerate(new_statements):
if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)):
insert_index = i + 1
assignment_lines = [
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
@ -331,8 +335,8 @@ class GlobalStatementCollector(cst.CSTVisitor):
def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
if not self.in_function_or_class:
for statement in node.body:
# Skip imports
if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign)):
# Skip imports and assignments (both regular and annotated)
if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign, cst.AnnAssign)):
self.global_statements.append(node)
break

View file

@ -2975,11 +2975,11 @@ class MyClass:
return cached_helper(5)
"""
# Global assignments are now inserted AFTER class/function definitions
# to ensure they can reference classes defined in the module
expected = """\
from typing import Any
_LOCAL_CACHE: dict[str, int] = {}
class MyClass:
def method(self):
return cached_helper(5)
@ -2992,6 +2992,71 @@ def cached_helper(x: int) -> int:
def regular_helper():
return "regular"
_LOCAL_CACHE: dict[str, int] = {}
"""
result = add_global_assignments(source_code, destination_code)
assert result == expected
def test_add_global_assignments_references_class_defined_in_module():
"""Test that global assignments referencing classes are placed after those class definitions.
This test verifies the fix for a bug where LLM-generated optimization code like:
_REIFIERS = {MessageKind.XXX: lambda d: ...}
was placed BEFORE the MessageKind class definition, causing NameError at module load.
The fix ensures that new global assignments are inserted AFTER all class/function
definitions in the module, so they can safely reference any class defined in the module.
"""
source_code = """\
import enum
class MessageKind(enum.StrEnum):
ASK = "ask"
REPLY = "reply"
_MESSAGE_HANDLERS = {
MessageKind.ASK: lambda: "ask handler",
MessageKind.REPLY: lambda: "reply handler",
}
def handle_message(kind):
return _MESSAGE_HANDLERS[kind]()
"""
destination_code = """\
import enum
class MessageKind(enum.StrEnum):
ASK = "ask"
REPLY = "reply"
def handle_message(kind):
if kind == MessageKind.ASK:
return "ask"
return "reply"
"""
# Global assignments are now inserted AFTER class/function definitions
# to ensure they can reference classes defined in the module
expected = """\
import enum
class MessageKind(enum.StrEnum):
ASK = "ask"
REPLY = "reply"
def handle_message(kind):
if kind == MessageKind.ASK:
return "ask"
return "reply"
_MESSAGE_HANDLERS = {
MessageKind.ASK: lambda: "ask handler",
MessageKind.REPLY: lambda: "reply handler",
}
"""
result = add_global_assignments(source_code, destination_code)