fix: insert global assignments after class definitions to prevent NameError
When LLM-generated optimizations included module-level code like
`_REIFIERS = {MessageKind.XXX: ...}`, the global assignment was
inserted right after the imports, BEFORE the class definition it referenced,
causing a NameError at module load time.
Changes:
- GlobalAssignmentTransformer now inserts new assignments after the last
class/function definition in the module instead of right after imports
- GlobalStatementCollector now skips AnnAssign (annotated assignments)
so they are handled by GlobalAssignmentCollector instead
This commit is contained in:
parent
1bb9d147f4
commit
abfa640578
2 changed files with 74 additions and 5 deletions
|
|
@ -280,8 +280,12 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
|
|||
]
|
||||
|
||||
if assignments_to_append:
|
||||
# after last top-level imports
|
||||
# Start after imports, then advance past class/function definitions
|
||||
# to ensure assignments can reference any classes defined in the module
|
||||
insert_index = find_insertion_index_after_imports(updated_node)
|
||||
for i, stmt in enumerate(new_statements):
|
||||
if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)):
|
||||
insert_index = i + 1
|
||||
|
||||
assignment_lines = [
|
||||
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
|
||||
|
|
@ -331,8 +335,8 @@ class GlobalStatementCollector(cst.CSTVisitor):
|
|||
def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
|
||||
if not self.in_function_or_class:
|
||||
for statement in node.body:
|
||||
# Skip imports
|
||||
if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign)):
|
||||
# Skip imports and assignments (both regular and annotated)
|
||||
if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign, cst.AnnAssign)):
|
||||
self.global_statements.append(node)
|
||||
break
|
||||
|
||||
|
|
|
|||
|
|
@ -2975,11 +2975,11 @@ class MyClass:
|
|||
return cached_helper(5)
|
||||
"""
|
||||
|
||||
# Global assignments are now inserted AFTER class/function definitions
|
||||
# to ensure they can reference classes defined in the module
|
||||
expected = """\
|
||||
from typing import Any
|
||||
|
||||
_LOCAL_CACHE: dict[str, int] = {}
|
||||
|
||||
class MyClass:
|
||||
def method(self):
|
||||
return cached_helper(5)
|
||||
|
|
@ -2992,6 +2992,71 @@ def cached_helper(x: int) -> int:
|
|||
|
||||
def regular_helper():
|
||||
return "regular"
|
||||
|
||||
_LOCAL_CACHE: dict[str, int] = {}
|
||||
"""
|
||||
|
||||
result = add_global_assignments(source_code, destination_code)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_add_global_assignments_references_class_defined_in_module():
|
||||
"""Test that global assignments referencing classes are placed after those class definitions.
|
||||
|
||||
This test verifies the fix for a bug where LLM-generated optimization code like:
|
||||
_REIFIERS = {MessageKind.XXX: lambda d: ...}
|
||||
was placed BEFORE the MessageKind class definition, causing NameError at module load.
|
||||
|
||||
The fix ensures that new global assignments are inserted AFTER all class/function
|
||||
definitions in the module, so they can safely reference any class defined in the module.
|
||||
"""
|
||||
source_code = """\
|
||||
import enum
|
||||
|
||||
class MessageKind(enum.StrEnum):
|
||||
ASK = "ask"
|
||||
REPLY = "reply"
|
||||
|
||||
_MESSAGE_HANDLERS = {
|
||||
MessageKind.ASK: lambda: "ask handler",
|
||||
MessageKind.REPLY: lambda: "reply handler",
|
||||
}
|
||||
|
||||
def handle_message(kind):
|
||||
return _MESSAGE_HANDLERS[kind]()
|
||||
"""
|
||||
|
||||
destination_code = """\
|
||||
import enum
|
||||
|
||||
class MessageKind(enum.StrEnum):
|
||||
ASK = "ask"
|
||||
REPLY = "reply"
|
||||
|
||||
def handle_message(kind):
|
||||
if kind == MessageKind.ASK:
|
||||
return "ask"
|
||||
return "reply"
|
||||
"""
|
||||
|
||||
# Global assignments are now inserted AFTER class/function definitions
|
||||
# to ensure they can reference classes defined in the module
|
||||
expected = """\
|
||||
import enum
|
||||
|
||||
class MessageKind(enum.StrEnum):
|
||||
ASK = "ask"
|
||||
REPLY = "reply"
|
||||
|
||||
def handle_message(kind):
|
||||
if kind == MessageKind.ASK:
|
||||
return "ask"
|
||||
return "reply"
|
||||
|
||||
_MESSAGE_HANDLERS = {
|
||||
MessageKind.ASK: lambda: "ask handler",
|
||||
MessageKind.REPLY: lambda: "reply handler",
|
||||
}
|
||||
"""
|
||||
|
||||
result = add_global_assignments(source_code, destination_code)
|
||||
|
|
|
|||
Loading…
Reference in a new issue