refactor: smarter placement of global assignments based on dependencies
Assignments that don't reference module-level definitions are now placed right after imports. Only assignments that reference classes/functions are placed after those definitions to prevent NameError.
This commit is contained in:
parent
257c5f2b8f
commit
7b33e8b7f6
5 changed files with 74 additions and 46 deletions
|
|
@ -115,6 +115,21 @@ class GlobalFunctionTransformer(cst.CSTTransformer):
|
||||||
return updated_node.with_changes(body=new_statements)
|
return updated_node.with_changes(body=new_statements)
|
||||||
|
|
||||||
|
|
||||||
|
def collect_referenced_names(node: cst.CSTNode) -> set[str]:
    """Return the identifiers of every ``cst.Name`` found in ``node``'s subtree.

    The traversal starts at ``node`` itself and follows each node's
    ``children``. Every ``cst.Name`` encountered contributes its ``value``,
    regardless of where in the subtree it appears.
    """
    found: set[str] = set()
    # Explicit work list instead of a nested recursive helper; the result is
    # a set, so visiting order does not matter.
    pending = [node]
    while pending:
        current = pending.pop()
        if isinstance(current, cst.Name):
            found.add(current.value)
        # Name nodes are descended into as well, exactly like any other node.
        pending.extend(current.children)
    return found
|
||||||
|
|
||||||
|
|
||||||
class GlobalAssignmentCollector(cst.CSTVisitor):
|
class GlobalAssignmentCollector(cst.CSTVisitor):
|
||||||
"""Collects all global assignment statements."""
|
"""Collects all global assignment statements."""
|
||||||
|
|
||||||
|
|
@ -274,37 +289,69 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
|
||||||
|
|
||||||
# Find assignments to append
|
# Find assignments to append
|
||||||
assignments_to_append = [
|
assignments_to_append = [
|
||||||
self.new_assignments[name]
|
(name, self.new_assignments[name])
|
||||||
for name in self.new_assignment_order
|
for name in self.new_assignment_order
|
||||||
if name not in self.processed_assignments and name in self.new_assignments
|
if name not in self.processed_assignments and name in self.new_assignments
|
||||||
]
|
]
|
||||||
|
|
||||||
if assignments_to_append:
|
if not assignments_to_append:
|
||||||
# Start after imports, then advance past class/function definitions
|
return updated_node.with_changes(body=new_statements)
|
||||||
# to ensure assignments can reference any classes defined in the module
|
|
||||||
|
# Collect all class and function names defined in the module
|
||||||
|
# These are the names that assignments might reference
|
||||||
|
module_defined_names: set[str] = set()
|
||||||
|
for stmt in new_statements:
|
||||||
|
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
|
||||||
|
module_defined_names.add(stmt.name.value)
|
||||||
|
|
||||||
|
# Partition assignments: those that reference module definitions go at the end,
|
||||||
|
# those that don't can go right after imports
|
||||||
|
assignments_after_imports: list[tuple[str, cst.Assign | cst.AnnAssign]] = []
|
||||||
|
assignments_after_definitions: list[tuple[str, cst.Assign | cst.AnnAssign]] = []
|
||||||
|
|
||||||
|
for name, assignment in assignments_to_append:
|
||||||
|
# Get the value being assigned
|
||||||
|
if isinstance(assignment, (cst.Assign, cst.AnnAssign)) and assignment.value is not None:
|
||||||
|
value_node = assignment.value
|
||||||
|
else:
|
||||||
|
# No value to analyze, safe to place after imports
|
||||||
|
assignments_after_imports.append((name, assignment))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Collect names referenced in the assignment value
|
||||||
|
referenced_names = collect_referenced_names(value_node)
|
||||||
|
|
||||||
|
# Check if any referenced names are module-level definitions
|
||||||
|
if referenced_names & module_defined_names:
|
||||||
|
# This assignment references a class/function, place it after definitions
|
||||||
|
assignments_after_definitions.append((name, assignment))
|
||||||
|
else:
|
||||||
|
# Safe to place right after imports
|
||||||
|
assignments_after_imports.append((name, assignment))
|
||||||
|
|
||||||
|
# Insert assignments that don't depend on module definitions right after imports
|
||||||
|
if assignments_after_imports:
|
||||||
insert_index = find_insertion_index_after_imports(updated_node)
|
insert_index = find_insertion_index_after_imports(updated_node)
|
||||||
|
assignment_lines = [
|
||||||
|
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
|
||||||
|
for _, assignment in assignments_after_imports
|
||||||
|
]
|
||||||
|
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
|
||||||
|
|
||||||
|
# Insert assignments that depend on module definitions after all class/function definitions
|
||||||
|
if assignments_after_definitions:
|
||||||
|
# Find the position after the last function or class definition
|
||||||
|
insert_index = find_insertion_index_after_imports(cst.Module(body=new_statements))
|
||||||
for i, stmt in enumerate(new_statements):
|
for i, stmt in enumerate(new_statements):
|
||||||
if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)):
|
if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)):
|
||||||
insert_index = i + 1
|
insert_index = i + 1
|
||||||
|
|
||||||
assignment_lines = [
|
assignment_lines = [
|
||||||
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
|
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
|
||||||
for assignment in assignments_to_append
|
for _, assignment in assignments_after_definitions
|
||||||
]
|
]
|
||||||
|
|
||||||
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
|
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
|
||||||
|
|
||||||
# Add a blank line after the last assignment if needed
|
|
||||||
after_index = insert_index + len(assignment_lines)
|
|
||||||
if after_index < len(new_statements):
|
|
||||||
next_stmt = new_statements[after_index]
|
|
||||||
# If there's no empty line, add one
|
|
||||||
has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines)
|
|
||||||
if not has_empty:
|
|
||||||
new_statements[after_index] = next_stmt.with_changes(
|
|
||||||
leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines]
|
|
||||||
)
|
|
||||||
|
|
||||||
return updated_node.with_changes(body=new_statements)
|
return updated_node.with_changes(body=new_statements)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2975,11 +2975,11 @@ class MyClass:
|
||||||
return cached_helper(5)
|
return cached_helper(5)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Global assignments are now inserted AFTER class/function definitions
|
|
||||||
# to ensure they can reference classes defined in the module
|
|
||||||
expected = """\
|
expected = """\
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
_LOCAL_CACHE: dict[str, int] = {}
|
||||||
|
|
||||||
class MyClass:
|
class MyClass:
|
||||||
def method(self):
|
def method(self):
|
||||||
return cached_helper(5)
|
return cached_helper(5)
|
||||||
|
|
@ -2992,8 +2992,6 @@ def cached_helper(x: int) -> int:
|
||||||
|
|
||||||
def regular_helper():
|
def regular_helper():
|
||||||
return "regular"
|
return "regular"
|
||||||
|
|
||||||
_LOCAL_CACHE: dict[str, int] = {}
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = add_global_assignments(source_code, destination_code)
|
result = add_global_assignments(source_code, destination_code)
|
||||||
|
|
@ -3111,11 +3109,11 @@ def handle_message(kind):
|
||||||
return "reply"
|
return "reply"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Global statements (function calls) should be inserted AFTER all class/function
|
|
||||||
# definitions to ensure they can reference any function defined in the module
|
|
||||||
expected = """\
|
expected = """\
|
||||||
import enum
|
import enum
|
||||||
|
|
||||||
|
_factories = {}
|
||||||
|
|
||||||
class MessageKind(enum.StrEnum):
|
class MessageKind(enum.StrEnum):
|
||||||
ASK = "ask"
|
ASK = "ask"
|
||||||
REPLY = "reply"
|
REPLY = "reply"
|
||||||
|
|
@ -3129,8 +3127,6 @@ def handle_message(kind):
|
||||||
def _register(kind, factory):
|
def _register(kind, factory):
|
||||||
_factories[kind] = factory
|
_factories[kind] = factory
|
||||||
|
|
||||||
_factories = {}
|
|
||||||
|
|
||||||
|
|
||||||
_register(MessageKind.ASK, lambda: "ask handler")
|
_register(MessageKind.ASK, lambda: "ask handler")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2116,12 +2116,9 @@ class NewClass:
|
||||||
print("Hello world")
|
print("Hello world")
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
# Global assignments are now inserted AFTER class/function definitions
|
|
||||||
# to ensure they can reference any classes defined in the module.
|
|
||||||
# This prevents NameError when LLM-generated optimizations like
|
|
||||||
# `_HANDLERS = {MessageKind.XXX: ...}` reference classes.
|
|
||||||
expected_code = """import numpy as np
|
expected_code = """import numpy as np
|
||||||
|
|
||||||
|
a = 6
|
||||||
if 2<3:
|
if 2<3:
|
||||||
a=4
|
a=4
|
||||||
else:
|
else:
|
||||||
|
|
@ -2143,8 +2140,6 @@ class NewClass:
|
||||||
return "I am still old"
|
return "I am still old"
|
||||||
def new_function2(value):
|
def new_function2(value):
|
||||||
return cst.ensure_type(value, str)
|
return cst.ensure_type(value, str)
|
||||||
|
|
||||||
a = 6
|
|
||||||
"""
|
"""
|
||||||
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
|
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
|
||||||
code_path.write_text(original_code, encoding="utf-8")
|
code_path.write_text(original_code, encoding="utf-8")
|
||||||
|
|
@ -3371,9 +3366,6 @@ def hydrate_input_text_actions_with_field_names(
|
||||||
return updated_actions_by_task
|
return updated_actions_by_task
|
||||||
```
|
```
|
||||||
'''
|
'''
|
||||||
# Global assignments are now inserted AFTER class/function definitions
|
|
||||||
# to ensure they can reference any classes defined in the module.
|
|
||||||
# This prevents NameError when LLM-generated optimizations reference classes.
|
|
||||||
expected = '''"""
|
expected = '''"""
|
||||||
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
|
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
|
||||||
"""
|
"""
|
||||||
|
|
@ -3388,6 +3380,8 @@ from skyvern.forge.sdk.prompting import PromptEngine
|
||||||
from skyvern.webeye.actions.actions import ActionType
|
from skyvern.webeye.actions.actions import ActionType
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
|
||||||
|
|
||||||
LOG = structlog.get_logger(__name__)
|
LOG = structlog.get_logger(__name__)
|
||||||
|
|
||||||
# Initialize prompt engine
|
# Initialize prompt engine
|
||||||
|
|
@ -3441,8 +3435,6 @@ def hydrate_input_text_actions_with_field_names(
|
||||||
updated_actions_by_task[task_id] = updated_actions
|
updated_actions_by_task[task_id] = updated_actions
|
||||||
|
|
||||||
return updated_actions_by_task
|
return updated_actions_by_task
|
||||||
|
|
||||||
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
func = FunctionToOptimize(function_name="hydrate_input_text_actions_with_field_names", parents=[], file_path=main_file)
|
func = FunctionToOptimize(function_name="hydrate_input_text_actions_with_field_names", parents=[], file_path=main_file)
|
||||||
|
|
|
||||||
|
|
@ -218,10 +218,6 @@ def test_no_targets_found() -> None:
|
||||||
def target(self):
|
def target(self):
|
||||||
pass
|
pass
|
||||||
"""
|
"""
|
||||||
# Nested class methods (MyClass.Inner.target) aren't directly targetable,
|
|
||||||
# but the outer class is kept when the qualified name starts with it.
|
|
||||||
# This is because the dependency tracking marks "MyClass" as used when it
|
|
||||||
# sees "MyClass.Inner.target" as a target function.
|
|
||||||
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
|
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
|
||||||
expected = dedent("""
|
expected = dedent("""
|
||||||
class MyClass:
|
class MyClass:
|
||||||
|
|
|
||||||
|
|
@ -124,15 +124,14 @@ def _get_string_usage(text: str) -> Usage:
|
||||||
|
|
||||||
helper_file.unlink(missing_ok=True)
|
helper_file.unlink(missing_ok=True)
|
||||||
main_file.unlink(missing_ok=True)
|
main_file.unlink(missing_ok=True)
|
||||||
|
|
||||||
# Global assignments are now inserted AFTER class/function definitions
|
|
||||||
# to prevent NameError when they reference classes or functions.
|
|
||||||
# See commit 50fba096 for details.
|
|
||||||
expected_helper = """import re
|
expected_helper = """import re
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent
|
from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent
|
||||||
|
|
||||||
|
_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
|
||||||
|
|
||||||
_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
|
_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
|
||||||
|
|
||||||
def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
|
def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
|
||||||
|
|
@ -159,8 +158,6 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
|
||||||
tokens += len(part.data)
|
tokens += len(part.data)
|
||||||
|
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert new_code.rstrip() == original_main.rstrip() # No Change
|
assert new_code.rstrip() == original_main.rstrip() # No Change
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue