mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
refactor: smarter placement of global assignments based on dependencies
Assignments that don't reference module-level definitions are now placed right after imports. Only assignments that reference classes/functions are placed after those definitions to prevent NameError.
This commit is contained in:
parent
257c5f2b8f
commit
7b33e8b7f6
5 changed files with 74 additions and 46 deletions
|
|
@ -115,6 +115,21 @@ class GlobalFunctionTransformer(cst.CSTTransformer):
|
|||
return updated_node.with_changes(body=new_statements)
|
||||
|
||||
|
||||
def collect_referenced_names(node: cst.CSTNode) -> set[str]:
|
||||
"""Collect all names referenced in a CST node using recursive traversal."""
|
||||
names: set[str] = set()
|
||||
|
||||
def _collect(n: cst.CSTNode) -> None:
|
||||
if isinstance(n, cst.Name):
|
||||
names.add(n.value)
|
||||
# Recursively process all children
|
||||
for child in n.children:
|
||||
_collect(child)
|
||||
|
||||
_collect(node)
|
||||
return names
|
||||
|
||||
|
||||
class GlobalAssignmentCollector(cst.CSTVisitor):
|
||||
"""Collects all global assignment statements."""
|
||||
|
||||
|
|
@ -274,37 +289,69 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
|
|||
|
||||
# Find assignments to append
|
||||
assignments_to_append = [
|
||||
self.new_assignments[name]
|
||||
(name, self.new_assignments[name])
|
||||
for name in self.new_assignment_order
|
||||
if name not in self.processed_assignments and name in self.new_assignments
|
||||
]
|
||||
|
||||
if assignments_to_append:
|
||||
# Start after imports, then advance past class/function definitions
|
||||
# to ensure assignments can reference any classes defined in the module
|
||||
if not assignments_to_append:
|
||||
return updated_node.with_changes(body=new_statements)
|
||||
|
||||
# Collect all class and function names defined in the module
|
||||
# These are the names that assignments might reference
|
||||
module_defined_names: set[str] = set()
|
||||
for stmt in new_statements:
|
||||
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
|
||||
module_defined_names.add(stmt.name.value)
|
||||
|
||||
# Partition assignments: those that reference module definitions go at the end,
|
||||
# those that don't can go right after imports
|
||||
assignments_after_imports: list[tuple[str, cst.Assign | cst.AnnAssign]] = []
|
||||
assignments_after_definitions: list[tuple[str, cst.Assign | cst.AnnAssign]] = []
|
||||
|
||||
for name, assignment in assignments_to_append:
|
||||
# Get the value being assigned
|
||||
if isinstance(assignment, (cst.Assign, cst.AnnAssign)) and assignment.value is not None:
|
||||
value_node = assignment.value
|
||||
else:
|
||||
# No value to analyze, safe to place after imports
|
||||
assignments_after_imports.append((name, assignment))
|
||||
continue
|
||||
|
||||
# Collect names referenced in the assignment value
|
||||
referenced_names = collect_referenced_names(value_node)
|
||||
|
||||
# Check if any referenced names are module-level definitions
|
||||
if referenced_names & module_defined_names:
|
||||
# This assignment references a class/function, place it after definitions
|
||||
assignments_after_definitions.append((name, assignment))
|
||||
else:
|
||||
# Safe to place right after imports
|
||||
assignments_after_imports.append((name, assignment))
|
||||
|
||||
# Insert assignments that don't depend on module definitions right after imports
|
||||
if assignments_after_imports:
|
||||
insert_index = find_insertion_index_after_imports(updated_node)
|
||||
assignment_lines = [
|
||||
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
|
||||
for _, assignment in assignments_after_imports
|
||||
]
|
||||
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
|
||||
|
||||
# Insert assignments that depend on module definitions after all class/function definitions
|
||||
if assignments_after_definitions:
|
||||
# Find the position after the last function or class definition
|
||||
insert_index = find_insertion_index_after_imports(cst.Module(body=new_statements))
|
||||
for i, stmt in enumerate(new_statements):
|
||||
if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)):
|
||||
insert_index = i + 1
|
||||
|
||||
assignment_lines = [
|
||||
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
|
||||
for assignment in assignments_to_append
|
||||
for _, assignment in assignments_after_definitions
|
||||
]
|
||||
|
||||
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
|
||||
|
||||
# Add a blank line after the last assignment if needed
|
||||
after_index = insert_index + len(assignment_lines)
|
||||
if after_index < len(new_statements):
|
||||
next_stmt = new_statements[after_index]
|
||||
# If there's no empty line, add one
|
||||
has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines)
|
||||
if not has_empty:
|
||||
new_statements[after_index] = next_stmt.with_changes(
|
||||
leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines]
|
||||
)
|
||||
|
||||
return updated_node.with_changes(body=new_statements)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2975,11 +2975,11 @@ class MyClass:
|
|||
return cached_helper(5)
|
||||
"""
|
||||
|
||||
# Global assignments are now inserted AFTER class/function definitions
|
||||
# to ensure they can reference classes defined in the module
|
||||
expected = """\
|
||||
from typing import Any
|
||||
|
||||
_LOCAL_CACHE: dict[str, int] = {}
|
||||
|
||||
class MyClass:
|
||||
def method(self):
|
||||
return cached_helper(5)
|
||||
|
|
@ -2992,8 +2992,6 @@ def cached_helper(x: int) -> int:
|
|||
|
||||
def regular_helper():
|
||||
return "regular"
|
||||
|
||||
_LOCAL_CACHE: dict[str, int] = {}
|
||||
"""
|
||||
|
||||
result = add_global_assignments(source_code, destination_code)
|
||||
|
|
@ -3111,11 +3109,11 @@ def handle_message(kind):
|
|||
return "reply"
|
||||
"""
|
||||
|
||||
# Global statements (function calls) should be inserted AFTER all class/function
|
||||
# definitions to ensure they can reference any function defined in the module
|
||||
expected = """\
|
||||
import enum
|
||||
|
||||
_factories = {}
|
||||
|
||||
class MessageKind(enum.StrEnum):
|
||||
ASK = "ask"
|
||||
REPLY = "reply"
|
||||
|
|
@ -3129,8 +3127,6 @@ def handle_message(kind):
|
|||
def _register(kind, factory):
|
||||
_factories[kind] = factory
|
||||
|
||||
_factories = {}
|
||||
|
||||
|
||||
_register(MessageKind.ASK, lambda: "ask handler")
|
||||
|
||||
|
|
|
|||
|
|
@ -2116,12 +2116,9 @@ class NewClass:
|
|||
print("Hello world")
|
||||
```
|
||||
"""
|
||||
# Global assignments are now inserted AFTER class/function definitions
|
||||
# to ensure they can reference any classes defined in the module.
|
||||
# This prevents NameError when LLM-generated optimizations like
|
||||
# `_HANDLERS = {MessageKind.XXX: ...}` reference classes.
|
||||
expected_code = """import numpy as np
|
||||
|
||||
a = 6
|
||||
if 2<3:
|
||||
a=4
|
||||
else:
|
||||
|
|
@ -2143,8 +2140,6 @@ class NewClass:
|
|||
return "I am still old"
|
||||
def new_function2(value):
|
||||
return cst.ensure_type(value, str)
|
||||
|
||||
a = 6
|
||||
"""
|
||||
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
|
||||
code_path.write_text(original_code, encoding="utf-8")
|
||||
|
|
@ -3371,9 +3366,6 @@ def hydrate_input_text_actions_with_field_names(
|
|||
return updated_actions_by_task
|
||||
```
|
||||
'''
|
||||
# Global assignments are now inserted AFTER class/function definitions
|
||||
# to ensure they can reference any classes defined in the module.
|
||||
# This prevents NameError when LLM-generated optimizations reference classes.
|
||||
expected = '''"""
|
||||
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
|
||||
"""
|
||||
|
|
@ -3388,6 +3380,8 @@ from skyvern.forge.sdk.prompting import PromptEngine
|
|||
from skyvern.webeye.actions.actions import ActionType
|
||||
import re
|
||||
|
||||
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
|
||||
|
||||
LOG = structlog.get_logger(__name__)
|
||||
|
||||
# Initialize prompt engine
|
||||
|
|
@ -3441,8 +3435,6 @@ def hydrate_input_text_actions_with_field_names(
|
|||
updated_actions_by_task[task_id] = updated_actions
|
||||
|
||||
return updated_actions_by_task
|
||||
|
||||
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
|
||||
'''
|
||||
|
||||
func = FunctionToOptimize(function_name="hydrate_input_text_actions_with_field_names", parents=[], file_path=main_file)
|
||||
|
|
|
|||
|
|
@ -218,10 +218,6 @@ def test_no_targets_found() -> None:
|
|||
def target(self):
|
||||
pass
|
||||
"""
|
||||
# Nested class methods (MyClass.Inner.target) aren't directly targetable,
|
||||
# but the outer class is kept when the qualified name starts with it.
|
||||
# This is because the dependency tracking marks "MyClass" as used when it
|
||||
# sees "MyClass.Inner.target" as a target function.
|
||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
|
|
|||
|
|
@ -124,15 +124,14 @@ def _get_string_usage(text: str) -> Usage:
|
|||
|
||||
helper_file.unlink(missing_ok=True)
|
||||
main_file.unlink(missing_ok=True)
|
||||
|
||||
# Global assignments are now inserted AFTER class/function definitions
|
||||
# to prevent NameError when they reference classes or functions.
|
||||
# See commit 50fba096 for details.
|
||||
|
||||
expected_helper = """import re
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent
|
||||
|
||||
_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
|
||||
|
||||
_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
|
||||
|
||||
def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
|
||||
|
|
@ -159,8 +158,6 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
|
|||
tokens += len(part.data)
|
||||
|
||||
return tokens
|
||||
|
||||
_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
|
||||
"""
|
||||
|
||||
assert new_code.rstrip() == original_main.rstrip() # No Change
|
||||
|
|
|
|||
Loading…
Reference in a new issue