mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
test & correct last import index
Signed-off-by: ali <mohammed18200118@gmail.com>
This commit is contained in:
parent
e60608e4bf
commit
99cb90832a
2 changed files with 233 additions and 2 deletions
|
|
@ -121,6 +121,7 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
|
|||
return updated_node
|
||||
|
||||
def _find_insertion_index(self, updated_node: cst.Module) -> int:
|
||||
"""Find the position of the last import statement in the top-level of the module."""
|
||||
insert_index = 0
|
||||
for i, stmt in enumerate(updated_node.body):
|
||||
is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
|
||||
|
|
@ -135,9 +136,14 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
|
|||
|
||||
if is_top_level_import or is_conditional_import:
|
||||
insert_index = i + 1
|
||||
else:
|
||||
# stop when we find the first non-import statement
|
||||
|
||||
# Stop scanning once we reach a class or function definition.
|
||||
# Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file.
|
||||
# Without this check, a stray import later in the file
|
||||
# would incorrectly shift our insertion index below actual code definitions.
|
||||
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
|
||||
break
|
||||
|
||||
return insert_index
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
|
|
|
|||
|
|
@ -3228,3 +3228,228 @@ class HuggingFaceModel(Model):
|
|||
assert not re.search(r"^import aiohttp as aiohttp_\b", new_code, re.MULTILINE) # conditional alias import: import <name> as <alias>
|
||||
assert not re.search(r"^from math import pi as PI, sin as sine\b", new_code, re.MULTILINE) # conditional multiple aliases imports
|
||||
assert "from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool" not in new_code # conditional from import
|
||||
|
||||
def test_top_level_global_assignments() -> None:
|
||||
root_dir = Path(__file__).parent.parent.resolve()
|
||||
main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()
|
||||
|
||||
original_code = '''"""
|
||||
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import structlog
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.prompting import PromptEngine
|
||||
from skyvern.webeye.actions.actions import ActionType
|
||||
|
||||
LOG = structlog.get_logger(__name__)
|
||||
|
||||
# Initialize prompt engine
|
||||
prompt_engine = PromptEngine("skyvern")
|
||||
|
||||
|
||||
def hydrate_input_text_actions_with_field_names(
|
||||
actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Add field_name to input_text actions based on generated mappings.
|
||||
|
||||
Args:
|
||||
actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
|
||||
field_mappings: Dictionary mapping "task_id:action_id" to field names
|
||||
|
||||
Returns:
|
||||
Updated actions_by_task with field_name added to input_text actions
|
||||
"""
|
||||
updated_actions_by_task = {}
|
||||
|
||||
for task_id, actions in actions_by_task.items():
|
||||
updated_actions = []
|
||||
|
||||
for action in actions:
|
||||
action_copy = action.copy()
|
||||
|
||||
if action.get("action_type") == ActionType.INPUT_TEXT:
|
||||
action_id = action.get("action_id", "")
|
||||
mapping_key = f"{task_id}:{action_id}"
|
||||
|
||||
if mapping_key in field_mappings:
|
||||
action_copy["field_name"] = field_mappings[mapping_key]
|
||||
else:
|
||||
# Fallback field name if mapping not found
|
||||
intention = action.get("intention", "")
|
||||
if intention:
|
||||
# Simple field name generation from intention
|
||||
field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
|
||||
field_name = "".join(c for c in field_name if c.isalnum() or c == "_")
|
||||
action_copy["field_name"] = field_name or "unknown_field"
|
||||
else:
|
||||
action_copy["field_name"] = "unknown_field"
|
||||
|
||||
updated_actions.append(action_copy)
|
||||
|
||||
updated_actions_by_task[task_id] = updated_actions
|
||||
|
||||
return updated_actions_by_task
|
||||
'''
|
||||
main_file.write_text(original_code, encoding="utf-8")
|
||||
optim_code = f'''```python:{main_file.relative_to(root_dir)}
|
||||
from skyvern.webeye.actions.actions import ActionType
|
||||
from typing import Any, Dict, List
|
||||
import re
|
||||
|
||||
# Precompiled regex for efficiently generating simple field_name from intention
|
||||
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
|
||||
|
||||
def hydrate_input_text_actions_with_field_names(
|
||||
actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Add field_name to input_text actions based on generated mappings.
|
||||
|
||||
Args:
|
||||
actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
|
||||
field_mappings: Dictionary mapping "task_id:action_id" to field names
|
||||
|
||||
Returns:
|
||||
Updated actions_by_task with field_name added to input_text actions
|
||||
"""
|
||||
updated_actions_by_task = {{}}
|
||||
|
||||
input_text_type = ActionType.INPUT_TEXT # local variable for faster access
|
||||
intention_cleanup = _INTENTION_CLEANUP_RE
|
||||
|
||||
for task_id, actions in actions_by_task.items():
|
||||
updated_actions = []
|
||||
|
||||
for action in actions:
|
||||
action_copy = action.copy()
|
||||
|
||||
if action.get("action_type") == input_text_type:
|
||||
action_id = action.get("action_id", "")
|
||||
mapping_key = f"{{task_id}}:{{action_id}}"
|
||||
|
||||
if mapping_key in field_mappings:
|
||||
action_copy["field_name"] = field_mappings[mapping_key]
|
||||
else:
|
||||
# Fallback field name if mapping not found
|
||||
intention = action.get("intention", "")
|
||||
if intention:
|
||||
# Simple field name generation from intention
|
||||
field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
|
||||
# Use compiled regex instead of "".join(c for ...)
|
||||
field_name = intention_cleanup.sub("", field_name)
|
||||
action_copy["field_name"] = field_name or "unknown_field"
|
||||
else:
|
||||
action_copy["field_name"] = "unknown_field"
|
||||
|
||||
updated_actions.append(action_copy)
|
||||
|
||||
updated_actions_by_task[task_id] = updated_actions
|
||||
|
||||
return updated_actions_by_task
|
||||
```
|
||||
'''
|
||||
expected = '''"""
|
||||
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import structlog
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.prompting import PromptEngine
|
||||
from skyvern.webeye.actions.actions import ActionType
|
||||
import re
|
||||
|
||||
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
|
||||
|
||||
LOG = structlog.get_logger(__name__)
|
||||
|
||||
# Initialize prompt engine
|
||||
prompt_engine = PromptEngine("skyvern")
|
||||
|
||||
|
||||
def hydrate_input_text_actions_with_field_names(
|
||||
actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Add field_name to input_text actions based on generated mappings.
|
||||
|
||||
Args:
|
||||
actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
|
||||
field_mappings: Dictionary mapping "task_id:action_id" to field names
|
||||
|
||||
Returns:
|
||||
Updated actions_by_task with field_name added to input_text actions
|
||||
"""
|
||||
updated_actions_by_task = {}
|
||||
|
||||
input_text_type = ActionType.INPUT_TEXT # local variable for faster access
|
||||
intention_cleanup = _INTENTION_CLEANUP_RE
|
||||
|
||||
for task_id, actions in actions_by_task.items():
|
||||
updated_actions = []
|
||||
|
||||
for action in actions:
|
||||
action_copy = action.copy()
|
||||
|
||||
if action.get("action_type") == input_text_type:
|
||||
action_id = action.get("action_id", "")
|
||||
mapping_key = f"{task_id}:{action_id}"
|
||||
|
||||
if mapping_key in field_mappings:
|
||||
action_copy["field_name"] = field_mappings[mapping_key]
|
||||
else:
|
||||
# Fallback field name if mapping not found
|
||||
intention = action.get("intention", "")
|
||||
if intention:
|
||||
# Simple field name generation from intention
|
||||
field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
|
||||
# Use compiled regex instead of "".join(c for ...)
|
||||
field_name = intention_cleanup.sub("", field_name)
|
||||
action_copy["field_name"] = field_name or "unknown_field"
|
||||
else:
|
||||
action_copy["field_name"] = "unknown_field"
|
||||
|
||||
updated_actions.append(action_copy)
|
||||
|
||||
updated_actions_by_task[task_id] = updated_actions
|
||||
|
||||
return updated_actions_by_task
|
||||
'''
|
||||
|
||||
func = FunctionToOptimize(function_name="hydrate_input_text_actions_with_field_names", parents=[], file_path=main_file)
|
||||
test_config = TestConfig(
|
||||
tests_root=root_dir / "tests/pytest",
|
||||
tests_project_rootdir=root_dir,
|
||||
project_root_path=root_dir,
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
for helper_function_path in helper_function_paths:
|
||||
with helper_function_path.open(encoding="utf8") as f:
|
||||
helper_code = f.read()
|
||||
original_helper_code[helper_function_path] = helper_code
|
||||
|
||||
func_optimizer.args = Args()
|
||||
func_optimizer.replace_function_and_helpers_with_optimized_code(
|
||||
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code
|
||||
)
|
||||
|
||||
|
||||
new_code = main_file.read_text(encoding="utf-8")
|
||||
main_file.unlink(missing_ok=True)
|
||||
|
||||
assert new_code == expected
|
||||
|
|
|
|||
Loading…
Reference in a new issue