Test and correct the last-import insertion index

Signed-off-by: ali <mohammed18200118@gmail.com>
This commit is contained in:
ali 2025-08-22 17:45:44 +03:00
parent e60608e4bf
commit 99cb90832a
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
2 changed files with 233 additions and 2 deletions

View file

@ -121,6 +121,7 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
return updated_node
def _find_insertion_index(self, updated_node: cst.Module) -> int:
"""Find the position of the last import statement in the top-level of the module."""
# Fallback of 0 means: if the module has no leading imports at all,
# new statements get inserted at the very top of the module body.
insert_index = 0
for i, stmt in enumerate(updated_node.body):
# A plain top-level import is a SimpleStatementLine containing import nodes.
# NOTE(review): the rest of this condition (the any(...) generator) is elided
# by the diff hunk boundary below — confirm against the full file.
is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
@ -135,9 +136,14 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
# NOTE(review): is_conditional_import is defined in the lines elided by the
# hunk boundary above (presumably imports wrapped in if/try blocks) — TODO confirm.
if is_top_level_import or is_conditional_import:
# Advance the insertion point to just past every import-bearing statement seen so far.
insert_index = i + 1
else:
# stop when we find the first non-import statement
# Stop scanning once we reach a class or function definition.
# Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file.
# Without this check, a stray import later in the file
# would incorrectly shift our insertion index below actual code definitions.
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
break
# Index into updated_node.body just after the last top-of-file import.
return insert_index
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:

View file

@ -3228,3 +3228,228 @@ class HuggingFaceModel(Model):
assert not re.search(r"^import aiohttp as aiohttp_\b", new_code, re.MULTILINE) # conditional alias import: import <name> as <alias>
assert not re.search(r"^from math import pi as PI, sin as sine\b", new_code, re.MULTILINE) # conditional multiple aliases imports
assert "from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool" not in new_code # conditional from import
def test_top_level_global_assignments() -> None:
"""End-to-end check that optimized code is merged into a real module with new
imports and new module-level globals inserted at the correct position.

Writes a temp module, runs replace_function_and_helpers_with_optimized_code
with an optimized snippet that adds `import re` and a module-level compiled
regex, then asserts the rewritten file matches `expected` exactly: the new
import lands after the last original import, and the new global
(_INTENTION_CLEANUP_RE) lands before the existing globals (LOG, prompt_engine).

NOTE(review): this diff view has had leading indentation and blank lines
stripped by extraction, so the literals below do not show the original
formatting — the byte-exact `new_code == expected` comparison depends on it.
"""
root_dir = Path(__file__).parent.parent.resolve()
main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()
# Original module under optimization, written verbatim to temp_main.py.
# It has five imports followed by two module-level globals (LOG, prompt_engine).
original_code = '''"""
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
"""
from typing import Any, Dict, List, Tuple
import structlog
from pydantic import BaseModel
from skyvern.forge import app
from skyvern.forge.sdk.prompting import PromptEngine
from skyvern.webeye.actions.actions import ActionType
LOG = structlog.get_logger(__name__)
# Initialize prompt engine
prompt_engine = PromptEngine("skyvern")
def hydrate_input_text_actions_with_field_names(
actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
) -> Dict[str, List[Dict[str, Any]]]:
"""
Add field_name to input_text actions based on generated mappings.
Args:
actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
field_mappings: Dictionary mapping "task_id:action_id" to field names
Returns:
Updated actions_by_task with field_name added to input_text actions
"""
updated_actions_by_task = {}
for task_id, actions in actions_by_task.items():
updated_actions = []
for action in actions:
action_copy = action.copy()
if action.get("action_type") == ActionType.INPUT_TEXT:
action_id = action.get("action_id", "")
mapping_key = f"{task_id}:{action_id}"
if mapping_key in field_mappings:
action_copy["field_name"] = field_mappings[mapping_key]
else:
# Fallback field name if mapping not found
intention = action.get("intention", "")
if intention:
# Simple field name generation from intention
field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
field_name = "".join(c for c in field_name if c.isalnum() or c == "_")
action_copy["field_name"] = field_name or "unknown_field"
else:
action_copy["field_name"] = "unknown_field"
updated_actions.append(action_copy)
updated_actions_by_task[task_id] = updated_actions
return updated_actions_by_task
'''
main_file.write_text(original_code, encoding="utf-8")
# Candidate "optimized" code in the fenced-markdown form the optimizer emits
# (```python:<relative path>). It adds a NEW import (re) and a NEW module-level
# global (_INTENTION_CLEANUP_RE) that the merge must hoist into the target file.
# f-string: literal braces inside the snippet are doubled ({{ }}).
optim_code = f'''```python:{main_file.relative_to(root_dir)}
from skyvern.webeye.actions.actions import ActionType
from typing import Any, Dict, List
import re
# Precompiled regex for efficiently generating simple field_name from intention
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
def hydrate_input_text_actions_with_field_names(
actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
) -> Dict[str, List[Dict[str, Any]]]:
"""
Add field_name to input_text actions based on generated mappings.
Args:
actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
field_mappings: Dictionary mapping "task_id:action_id" to field names
Returns:
Updated actions_by_task with field_name added to input_text actions
"""
updated_actions_by_task = {{}}
input_text_type = ActionType.INPUT_TEXT  # local variable for faster access
intention_cleanup = _INTENTION_CLEANUP_RE
for task_id, actions in actions_by_task.items():
updated_actions = []
for action in actions:
action_copy = action.copy()
if action.get("action_type") == input_text_type:
action_id = action.get("action_id", "")
mapping_key = f"{{task_id}}:{{action_id}}"
if mapping_key in field_mappings:
action_copy["field_name"] = field_mappings[mapping_key]
else:
# Fallback field name if mapping not found
intention = action.get("intention", "")
if intention:
# Simple field name generation from intention
field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
# Use compiled regex instead of "".join(c for ...)
field_name = intention_cleanup.sub("", field_name)
action_copy["field_name"] = field_name or "unknown_field"
else:
action_copy["field_name"] = "unknown_field"
updated_actions.append(action_copy)
updated_actions_by_task[task_id] = updated_actions
return updated_actions_by_task
```
'''
# Expected merge result: all ORIGINAL imports are preserved (including ones the
# optimized function no longer uses, e.g. Tuple/structlog/pydantic), the new
# `import re` is inserted immediately after the last original import, and the
# new global _INTENTION_CLEANUP_RE is placed before the pre-existing globals
# (LOG, prompt_engine) — i.e. at the "last import" insertion index under test.
expected = '''"""
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
"""
from typing import Any, Dict, List, Tuple
import structlog
from pydantic import BaseModel
from skyvern.forge import app
from skyvern.forge.sdk.prompting import PromptEngine
from skyvern.webeye.actions.actions import ActionType
import re
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
LOG = structlog.get_logger(__name__)
# Initialize prompt engine
prompt_engine = PromptEngine("skyvern")
def hydrate_input_text_actions_with_field_names(
actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
) -> Dict[str, List[Dict[str, Any]]]:
"""
Add field_name to input_text actions based on generated mappings.
Args:
actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
field_mappings: Dictionary mapping "task_id:action_id" to field names
Returns:
Updated actions_by_task with field_name added to input_text actions
"""
updated_actions_by_task = {}
input_text_type = ActionType.INPUT_TEXT  # local variable for faster access
intention_cleanup = _INTENTION_CLEANUP_RE
for task_id, actions in actions_by_task.items():
updated_actions = []
for action in actions:
action_copy = action.copy()
if action.get("action_type") == input_text_type:
action_id = action.get("action_id", "")
mapping_key = f"{task_id}:{action_id}"
if mapping_key in field_mappings:
action_copy["field_name"] = field_mappings[mapping_key]
else:
# Fallback field name if mapping not found
intention = action.get("intention", "")
if intention:
# Simple field name generation from intention
field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
# Use compiled regex instead of "".join(c for ...)
field_name = intention_cleanup.sub("", field_name)
action_copy["field_name"] = field_name or "unknown_field"
else:
action_copy["field_name"] = "unknown_field"
updated_actions.append(action_copy)
updated_actions_by_task[task_id] = updated_actions
return updated_actions_by_task
'''
# Drive the optimizer pipeline against the temp file.
func = FunctionToOptimize(function_name="hydrate_input_text_actions_with_field_names", parents=[], file_path=main_file)
test_config = TestConfig(
tests_root=root_dir / "tests/pytest",
tests_project_rootdir=root_dir,
project_root_path=root_dir,
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
# Snapshot helper-file contents so the replacement call can restore/compare them.
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
for helper_function_path in helper_function_paths:
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code
)
new_code = main_file.read_text(encoding="utf-8")
# Clean up the temp module before asserting so a failure doesn't leak the file.
main_file.unlink(missing_ok=True)
# Byte-exact comparison: position of inserted import AND global must both match.
assert new_code == expected