From abfa640578a899c0525089aab54cda2d1fe0336a Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sat, 24 Jan 2026 01:37:15 -0500 Subject: [PATCH] fix: insert global assignments after class definitions to prevent NameError When LLM-generated optimizations include module-level code like `_REIFIERS = {MessageKind.XXX: ...}`, the global assignment was being inserted right after imports, BEFORE the class definition it referenced, causing NameError at module load time. Changes: - GlobalAssignmentTransformer now inserts assignments after all class/function definitions instead of right after imports - GlobalStatementCollector now skips AnnAssign (annotated assignments) so they are handled by GlobalAssignmentCollector instead --- codeflash/code_utils/code_extractor.py | 10 ++-- tests/test_code_context_extractor.py | 69 +++++++++++++++++++++++++- 2 files changed, 74 insertions(+), 5 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 0bbcc5908..f72c0caa8 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -280,8 +280,12 @@ class GlobalAssignmentTransformer(cst.CSTTransformer): ] if assignments_to_append: - # after last top-level imports + # Start after imports, then advance past class/function definitions + # to ensure assignments can reference any classes defined in the module insert_index = find_insertion_index_after_imports(updated_node) + for i, stmt in enumerate(new_statements): + if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)): + insert_index = i + 1 assignment_lines = [ cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()]) @@ -331,8 +335,8 @@ class GlobalStatementCollector(cst.CSTVisitor): def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None: if not self.in_function_or_class: for statement in node.body: - # Skip imports - if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign)): + # Skip imports and assignments (both regular and annotated) + if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign, cst.AnnAssign)): self.global_statements.append(node) break diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 031a22524..7b22f4c48 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -2975,11 +2975,11 @@ class MyClass: return cached_helper(5) """ + # Global assignments are now inserted AFTER class/function definitions + # to ensure they can reference classes defined in the module expected = """\ from typing import Any -_LOCAL_CACHE: dict[str, int] = {} - class MyClass: def method(self): return cached_helper(5) @@ -2992,6 +2992,71 @@ def cached_helper(x: int) -> int: def regular_helper(): return "regular" + +_LOCAL_CACHE: dict[str, int] = {} +""" + + result = add_global_assignments(source_code, destination_code) + assert result == expected + + +def test_add_global_assignments_references_class_defined_in_module(): + """Test that global assignments referencing classes are placed after those class definitions. + + This test verifies the fix for a bug where LLM-generated optimization code like: + _REIFIERS = {MessageKind.XXX: lambda d: ...} + was placed BEFORE the MessageKind class definition, causing NameError at module load. + + The fix ensures that new global assignments are inserted AFTER all class/function + definitions in the module, so they can safely reference any class defined in the module. + """ + source_code = """\ +import enum + +class MessageKind(enum.StrEnum): + ASK = "ask" + REPLY = "reply" + +_MESSAGE_HANDLERS = { + MessageKind.ASK: lambda: "ask handler", + MessageKind.REPLY: lambda: "reply handler", +} + +def handle_message(kind): + return _MESSAGE_HANDLERS[kind]() +""" + + destination_code = """\ +import enum + +class MessageKind(enum.StrEnum): + ASK = "ask" + REPLY = "reply" + +def handle_message(kind): + if kind == MessageKind.ASK: + return "ask" + return "reply" +""" + + # Global assignments are now inserted AFTER class/function definitions + # to ensure they can reference classes defined in the module + expected = """\ +import enum + +class MessageKind(enum.StrEnum): + ASK = "ask" + REPLY = "reply" + +def handle_message(kind): + if kind == MessageKind.ASK: + return "ask" + return "reply" + +_MESSAGE_HANDLERS = { + MessageKind.ASK: lambda: "ask handler", + MessageKind.REPLY: lambda: "reply handler", +} """ result = add_global_assignments(source_code, destination_code)