Merge branch 'main' of github.com:codeflash-ai/codeflash into fix/duplicate-global-assignments-when-reverting-helpers
commit 28f50cc1e0
11 changed files with 386 additions and 101 deletions
@@ -81,6 +81,19 @@ class AiServiceClient:
         # response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
         return response

+    def _get_valid_candidates(self, optimizations_json: list[dict[str, Any]]) -> list[OptimizedCandidate]:
+        candidates: list[OptimizedCandidate] = []
+        for opt in optimizations_json:
+            code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"])
+            if not code.code_strings:
+                continue
+            candidates.append(
+                OptimizedCandidate(
+                    source_code=code, explanation=opt["explanation"], optimization_id=opt["optimization_id"]
+                )
+            )
+        return candidates
+
     def optimize_python_code( # noqa: D417
         self,
         source_code: str,
@@ -135,14 +148,7 @@ class AiServiceClient:
             console.rule()
             end_time = time.perf_counter()
             logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
-            return [
-                OptimizedCandidate(
-                    source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
-                    explanation=opt["explanation"],
-                    optimization_id=opt["optimization_id"],
-                )
-                for opt in optimizations_json
-            ]
+            return self._get_valid_candidates(optimizations_json)
         try:
             error = response.json()["error"]
         except Exception:
@@ -205,14 +211,7 @@ class AiServiceClient:
             optimizations_json = response.json()["optimizations"]
             logger.info(f"Generated {len(optimizations_json)} candidate optimizations using line profiler information.")
             console.rule()
-            return [
-                OptimizedCandidate(
-                    source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
-                    explanation=opt["explanation"],
-                    optimization_id=opt["optimization_id"],
-                )
-                for opt in optimizations_json
-            ]
+            return self._get_valid_candidates(optimizations_json)
         try:
             error = response.json()["error"]
         except Exception:
@@ -262,14 +261,17 @@ class AiServiceClient:
             refined_optimizations = response.json()["refinements"]
             logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
             console.rule()
+
+            refinements = self._get_valid_candidates(refined_optimizations)
             return [
                 OptimizedCandidate(
-                    source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
-                    explanation=opt["explanation"],
-                    optimization_id=opt["optimization_id"][:-4] + "refi",
+                    source_code=c.source_code,
+                    explanation=c.explanation,
+                    optimization_id=c.optimization_id[:-4] + "refi",
                 )
-                for opt in refined_optimizations
+                for c in refinements
             ]
+
         try:
             error = response.json()["error"]
         except Exception:
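Hedged usage sketch of the new `_get_valid_candidates` helper. It mirrors the `test_generated_candidates_validation` test added at the bottom of this diff, so the only assumption is that `AiServiceClient()` can be constructed in your environment: candidates whose markdown block is not valid Python are dropped instead of being returned as `OptimizedCandidate`s.

```python
# Illustration only; the payloads copy the new test rather than real service output.
from codeflash.api.aiservice import AiServiceClient

ai_service = AiServiceClient()
bad = [{"source_code": "```python:file.py\nprint name\n```", "explanation": "", "optimization_id": ""}]
good = [{"source_code": "```python:file.py\nprint('Hello, World!')\n```", "explanation": "", "optimization_id": ""}]

assert ai_service._get_valid_candidates(bad) == []       # unparsable Python -> filtered out
assert len(ai_service._get_valid_candidates(good)) == 1  # valid block -> one candidate
```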
@@ -2,6 +2,7 @@
 from __future__ import annotations

 import ast
+from itertools import chain
 from typing import TYPE_CHECKING, Optional

 import libcst as cst
@@ -119,6 +120,32 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):

         return updated_node

+    def _find_insertion_index(self, updated_node: cst.Module) -> int:
+        """Find the position of the last import statement in the top-level of the module."""
+        insert_index = 0
+        for i, stmt in enumerate(updated_node.body):
+            is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
+                isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
+            )
+
+            is_conditional_import = isinstance(stmt, cst.If) and all(
+                isinstance(inner, cst.SimpleStatementLine)
+                and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
+                for inner in stmt.body.body
+            )
+
+            if is_top_level_import or is_conditional_import:
+                insert_index = i + 1
+
+            # Stop scanning once we reach a class or function definition.
+            # Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file.
+            # Without this check, a stray import later in the file
+            # would incorrectly shift our insertion index below actual code definitions.
+            if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
+                break
+
+        return insert_index
+
     def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
         # Add any new assignments that weren't in the original file
         new_statements = list(updated_node.body)
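A hedged, self-contained sketch of the scan above (the only assumption is that `libcst` is installed); it shows why a stray import late in a file does not drag the insertion point below real definitions:

```python
import libcst as cst

source = """\
import os

if True:
    import sys

def main():
    return os.getpid()

import json  # stray late import: ignored, because we stop at the first def/class
"""

module = cst.parse_module(source)
insert_index = 0
for i, stmt in enumerate(module.body):
    is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
        isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
    )
    is_conditional_import = isinstance(stmt, cst.If) and all(
        isinstance(inner, cst.SimpleStatementLine)
        and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
        for inner in stmt.body.body
    )
    if is_top_level_import or is_conditional_import:
        insert_index = i + 1
    if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
        break

print(insert_index)  # 2 -> new globals go right after the `if True:` import block
```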
@@ -131,18 +158,26 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
         ]

         if assignments_to_append:
-            # Add a blank line before appending new assignments if needed
-            if new_statements and not isinstance(new_statements[-1], cst.EmptyLine):
-                new_statements.append(cst.SimpleStatementLine([cst.Pass()], leading_lines=[cst.EmptyLine()]))
-                new_statements.pop() # Remove the Pass statement but keep the empty line
+            # after last top-level imports
+            insert_index = self._find_insertion_index(updated_node)

             # Add the new assignments
-            new_statements.extend(
-                [
-                    cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
-                    for assignment in assignments_to_append
-                ]
-            )
+            assignment_lines = [
+                cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
+                for assignment in assignments_to_append
+            ]
+
+            new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
+
+            # Add a blank line after the last assignment if needed
+            after_index = insert_index + len(assignment_lines)
+            if after_index < len(new_statements):
+                next_stmt = new_statements[after_index]
+                # If there's no empty line, add one
+                has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines)
+                if not has_empty:
+                    new_statements[after_index] = next_stmt.with_changes(
+                        leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines]
+                    )

         return updated_node.with_changes(body=new_statements)
@@ -341,6 +376,7 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
     new_added_global_statements = extract_global_statements(src_module_code)
     existing_global_statements = extract_global_statements(dst_module_code)

+    # make sure we don't have any statements applied multiple times at the global level.
     unique_global_statements = [
         stmt
         for stmt in new_added_global_statements
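A hedged before/after sketch of where the deduplicated globals end up. The import path of `add_global_assignments` is an assumption here, and the expected output is inferred from the test fixtures added later in this diff:

```python
# Assumed import location; adjust if the helper lives elsewhere in the codebase.
from codeflash.code_utils.code_replacer import add_global_assignments

optimized = "a = 6\n"            # src_module_code: carries the new global assignment
original = """import numpy as np

def f():
    return a
"""                               # dst_module_code: module being patched

print(add_global_assignments(optimized, original))
# Per the fixtures below, `a = 6` now lands right after the imports
# instead of being appended at the end of the file:
#   import numpy as np
#
#   a = 6
#
#   def f():
#       return a
```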
@@ -412,6 +412,7 @@ def replace_function_definitions_in_module(
     module_abspath: Path,
     preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
     project_root_path: Path,
+    global_assignments_added_before: bool = False, # noqa: FBT001, FBT002
 ) -> bool:
     source_code: str = module_abspath.read_text(encoding="utf8")
     code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
@@ -421,7 +422,7 @@ def replace_function_definitions_in_module(
         # because of an "edge case" where the optimized code introduced a new import and a global assignment using that import
         # and that import wasn't used before, so it was ignored when calling AddImportsVisitor.add_needed_import inside replace_functions_and_add_imports (because the global assignment wasn't added yet)
         # this was added at https://github.com/codeflash-ai/codeflash/pull/448
-        add_global_assignments(code_to_apply, source_code),
+        add_global_assignments(code_to_apply, source_code) if not global_assignments_added_before else source_code,
         function_names,
         code_to_apply,
         module_abspath,
@@ -537,6 +537,7 @@ def revert_unused_helper_functions(
                 module_abspath=file_path,
                 preexisting_objects=set(), # Empty set since we're reverting
                 project_root_path=project_root,
+                global_assignments_added_before=True, # since we revert helper functions after applying the optimization, we know that the file already has global assignments added; otherwise they would be added twice
             )

             if reverted_code:
@@ -110,7 +110,7 @@ def initialize_function_optimization(

     if count == 0:
         server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
-        cleanup_the_optimizer(server)
+        server.cleanup_the_optimizer()
         return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}

     fto = optimizable_funcs.popitem()[1][0]
@@ -217,6 +217,7 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams


 @server.feature("performFunctionOptimization")
+@server.thread()
 def perform_function_optimization( # noqa: PLR0911
     server: CodeflashLanguageServer, params: FunctionOptimizationParams
 ) -> dict[str, str]:
@@ -337,14 +338,4 @@ def perform_function_optimization( # noqa: PLR0911
                 "explanation": best_optimization.explanation_v2,
             }
     finally:
-        cleanup_the_optimizer(server)
-
-
-def cleanup_the_optimizer(server: CodeflashLanguageServer) -> None:
-    server.optimizer.cleanup_temporary_paths()
-    # restore args and test cfg
-    if server.optimizer.original_args_and_test_cfg:
-        server.optimizer.args, server.optimizer.test_cfg = server.optimizer.original_args_and_test_cfg
-    server.optimizer.args.function = None
-    server.optimizer.current_worktree = None
-    server.optimizer.current_function_optimizer = None
+        server.cleanup_the_optimizer()
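Context for the `@server.thread()` line added above: pygls ships a `thread()` decorator on `LanguageServer` that runs the handler in the server's thread pool, so a long optimization run does not block the JSON-RPC event loop. A minimal, hedged pygls sketch (standalone example, not codeflash code):

```python
from pygls.server import LanguageServer

server = LanguageServer("example-server", "v0.1")


@server.feature("example/longRunningTask")
@server.thread()  # handler runs in the thread pool, keeping the event loop responsive
def long_running_task(ls: LanguageServer, params) -> dict:
    # Blocking work is acceptable here; other requests keep being served.
    total = sum(range(10_000_000))
    return {"status": "done", "total": total}
```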
@@ -1,12 +1,14 @@
 from __future__ import annotations

+import sys
 from pathlib import Path
-from typing import TYPE_CHECKING, Any
+from threading import Event
+from typing import TYPE_CHECKING, Any, Optional, TextIO

 from lsprotocol.types import INITIALIZE, LogMessageParams, MessageType
 from pygls import uris
 from pygls.protocol import LanguageServerProtocol, lsp_method
-from pygls.server import LanguageServer
+from pygls.server import LanguageServer, StdOutTransportAdapter, aio_readline

 if TYPE_CHECKING:
     from lsprotocol.types import InitializeParams, InitializeResult
@@ -81,3 +83,39 @@ class CodeflashLanguageServer(LanguageServer):
         # Send log message to client (appears in output channel)
         log_params = LogMessageParams(type=lsp_message_type, message=message)
         self.lsp.notify("window/logMessage", log_params)
+
+    def cleanup_the_optimizer(self) -> None:
+        try:
+            self.optimizer.cleanup_temporary_paths()
+            # restore args and test cfg
+            if self.optimizer.original_args_and_test_cfg:
+                self.optimizer.args, self.optimizer.test_cfg = self.optimizer.original_args_and_test_cfg
+            self.optimizer.args.function = None
+            self.optimizer.current_worktree = None
+            self.optimizer.current_function_optimizer = None
+        except Exception:
+            self.show_message_log("Failed to cleanup optimizer", "Error")
+
+    def start_io(self, stdin: Optional[TextIO] = None, stdout: Optional[TextIO] = None) -> None:
+        self.show_message_log("Starting IO server", "Info")
+
+        self._stop_event = Event()
+        transport = StdOutTransportAdapter(stdin or sys.stdin.buffer, stdout or sys.stdout.buffer)
+        self.lsp.connection_made(transport)
+        try:
+            self.loop.run_until_complete(
+                aio_readline(
+                    self.loop,
+                    self.thread_pool_executor,
+                    self._stop_event,
+                    stdin or sys.stdin.buffer,
+                    self.lsp.data_received,
+                )
+            )
+        except BrokenPipeError:
+            self.show_message_log("Connection to the client is lost! Shutting down the server.", "Error")
+        except (KeyboardInterrupt, SystemExit):
+            pass
+        finally:
+            self.cleanup_the_optimizer()
+            self.shutdown()
@@ -19,7 +19,7 @@ from re import Pattern
 from typing import Annotated, Optional, cast

 from jedi.api.classes import Name
-from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr
+from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr, ValidationError
 from pydantic.dataclasses import dataclass

 from codeflash.cli_cmds.console import console, logger
@@ -239,10 +239,14 @@ class CodeStringsMarkdown(BaseModel):
         """
         matches = markdown_pattern.findall(markdown_code)
         results = CodeStringsMarkdown()
-        for file_path, code in matches:
-            path = file_path.strip()
-            results.code_strings.append(CodeString(code=code, file_path=Path(path)))
-        return results
+        try:
+            for file_path, code in matches:
+                path = file_path.strip()
+                results.code_strings.append(CodeString(code=code, file_path=Path(path)))
+            return results  # noqa: TRY300
+        except ValidationError:
+            # if any file is invalid, return an empty CodeStringsMarkdown for the entire context
+            return CodeStringsMarkdown()


 class CodeOptimizationContext(BaseModel):
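A hedged sketch of the new failure mode (inputs mirror the tests in this diff; the premise is that `CodeString` validation raises `pydantic.ValidationError` when the block is not valid Python):

```python
from codeflash.models.models import CodeStringsMarkdown

invalid = "```python:file.py\nprint name\n```"      # `print name` is not valid Python 3
valid = "```python:file.py\nprint('hello')\n```"

# The invalid block now yields an empty container instead of raising mid-parse.
assert CodeStringsMarkdown.parse_markdown_code(invalid).code_strings == []
assert len(CodeStringsMarkdown.parse_markdown_code(valid).code_strings) == 1
```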
@@ -1354,6 +1354,7 @@ class FunctionOptimizer:
         return

     def revert_code_and_helpers(self, original_helper_code: dict[Path, str]) -> None:
+        logger.info("Reverting code and helpers...")
         self.write_code_and_helpers(
             self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
         )
@@ -2099,6 +2099,8 @@ print("Hello world")
|
|||
"""
|
||||
expected_code = """import numpy as np
|
||||
|
||||
a = 6
|
||||
|
||||
if 2<3:
|
||||
a=4
|
||||
else:
|
||||
|
|
@@ -2120,8 +2122,6 @@ class NewClass:
|
|||
return "I am still old"
|
||||
def new_function2(value):
|
||||
return cst.ensure_type(value, str)
|
||||
|
||||
a = 6
|
||||
"""
|
||||
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
|
||||
code_path.write_text(original_code, encoding="utf-8")
|
||||
|
|
@@ -3223,67 +3223,274 @@ class HuggingFaceModel(Model):
|
|||
assert not re.search(r"^from math import pi as PI, sin as sine\b", new_code, re.MULTILINE) # conditional multiple aliases imports
|
||||
assert "from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool" not in new_code # conditional from import
|
||||
|
||||
def test_top_level_global_assignments() -> None:
|
||||
root_dir = Path(__file__).parent.parent.resolve()
|
||||
main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()
|
||||
|
||||
original_code = '''"""
|
||||
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import structlog
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.prompting import PromptEngine
|
||||
from skyvern.webeye.actions.actions import ActionType
|
||||
|
||||
LOG = structlog.get_logger(__name__)
|
||||
|
||||
# Initialize prompt engine
|
||||
prompt_engine = PromptEngine("skyvern")
|
||||
|
||||
|
||||
def hydrate_input_text_actions_with_field_names(
|
||||
actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Add field_name to input_text actions based on generated mappings.
|
||||
|
||||
Args:
|
||||
actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
|
||||
field_mappings: Dictionary mapping "task_id:action_id" to field names
|
||||
|
||||
Returns:
|
||||
Updated actions_by_task with field_name added to input_text actions
|
||||
"""
|
||||
updated_actions_by_task = {}
|
||||
|
||||
for task_id, actions in actions_by_task.items():
|
||||
updated_actions = []
|
||||
|
||||
for action in actions:
|
||||
action_copy = action.copy()
|
||||
|
||||
if action.get("action_type") == ActionType.INPUT_TEXT:
|
||||
action_id = action.get("action_id", "")
|
||||
mapping_key = f"{task_id}:{action_id}"
|
||||
|
||||
if mapping_key in field_mappings:
|
||||
action_copy["field_name"] = field_mappings[mapping_key]
|
||||
else:
|
||||
# Fallback field name if mapping not found
|
||||
intention = action.get("intention", "")
|
||||
if intention:
|
||||
# Simple field name generation from intention
|
||||
field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
|
||||
field_name = "".join(c for c in field_name if c.isalnum() or c == "_")
|
||||
action_copy["field_name"] = field_name or "unknown_field"
|
||||
else:
|
||||
action_copy["field_name"] = "unknown_field"
|
||||
|
||||
updated_actions.append(action_copy)
|
||||
|
||||
updated_actions_by_task[task_id] = updated_actions
|
||||
|
||||
return updated_actions_by_task
|
||||
'''
|
||||
main_file.write_text(original_code, encoding="utf-8")
|
||||
optim_code = f'''```python:{main_file.relative_to(root_dir)}
|
||||
from skyvern.webeye.actions.actions import ActionType
|
||||
from typing import Any, Dict, List
|
||||
import re
|
||||
|
||||
# Precompiled regex for efficiently generating simple field_name from intention
|
||||
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
|
||||
|
||||
def hydrate_input_text_actions_with_field_names(
|
||||
actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Add field_name to input_text actions based on generated mappings.
|
||||
|
||||
Args:
|
||||
actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
|
||||
field_mappings: Dictionary mapping "task_id:action_id" to field names
|
||||
|
||||
Returns:
|
||||
Updated actions_by_task with field_name added to input_text actions
|
||||
"""
|
||||
updated_actions_by_task = {{}}
|
||||
|
||||
input_text_type = ActionType.INPUT_TEXT # local variable for faster access
|
||||
intention_cleanup = _INTENTION_CLEANUP_RE
|
||||
|
||||
for task_id, actions in actions_by_task.items():
|
||||
updated_actions = []
|
||||
|
||||
for action in actions:
|
||||
action_copy = action.copy()
|
||||
|
||||
if action.get("action_type") == input_text_type:
|
||||
action_id = action.get("action_id", "")
|
||||
mapping_key = f"{{task_id}}:{{action_id}}"
|
||||
|
||||
if mapping_key in field_mappings:
|
||||
action_copy["field_name"] = field_mappings[mapping_key]
|
||||
else:
|
||||
# Fallback field name if mapping not found
|
||||
intention = action.get("intention", "")
|
||||
if intention:
|
||||
# Simple field name generation from intention
|
||||
field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
|
||||
# Use compiled regex instead of "".join(c for ...)
|
||||
field_name = intention_cleanup.sub("", field_name)
|
||||
action_copy["field_name"] = field_name or "unknown_field"
|
||||
else:
|
||||
action_copy["field_name"] = "unknown_field"
|
||||
|
||||
updated_actions.append(action_copy)
|
||||
|
||||
updated_actions_by_task[task_id] = updated_actions
|
||||
|
||||
return updated_actions_by_task
|
||||
```
|
||||
'''
|
||||
expected = '''"""
|
||||
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import structlog
|
||||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.prompting import PromptEngine
|
||||
from skyvern.webeye.actions.actions import ActionType
|
||||
import re
|
||||
|
||||
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
|
||||
|
||||
LOG = structlog.get_logger(__name__)
|
||||
|
||||
# Initialize prompt engine
|
||||
prompt_engine = PromptEngine("skyvern")
|
||||
|
||||
|
||||
def hydrate_input_text_actions_with_field_names(
|
||||
actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Add field_name to input_text actions based on generated mappings.
|
||||
|
||||
Args:
|
||||
actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
|
||||
field_mappings: Dictionary mapping "task_id:action_id" to field names
|
||||
|
||||
Returns:
|
||||
Updated actions_by_task with field_name added to input_text actions
|
||||
"""
|
||||
updated_actions_by_task = {}
|
||||
|
||||
input_text_type = ActionType.INPUT_TEXT # local variable for faster access
|
||||
intention_cleanup = _INTENTION_CLEANUP_RE
|
||||
|
||||
for task_id, actions in actions_by_task.items():
|
||||
updated_actions = []
|
||||
|
||||
for action in actions:
|
||||
action_copy = action.copy()
|
||||
|
||||
if action.get("action_type") == input_text_type:
|
||||
action_id = action.get("action_id", "")
|
||||
mapping_key = f"{task_id}:{action_id}"
|
||||
|
||||
if mapping_key in field_mappings:
|
||||
action_copy["field_name"] = field_mappings[mapping_key]
|
||||
else:
|
||||
# Fallback field name if mapping not found
|
||||
intention = action.get("intention", "")
|
||||
if intention:
|
||||
# Simple field name generation from intention
|
||||
field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
|
||||
# Use compiled regex instead of "".join(c for ...)
|
||||
field_name = intention_cleanup.sub("", field_name)
|
||||
action_copy["field_name"] = field_name or "unknown_field"
|
||||
else:
|
||||
action_copy["field_name"] = "unknown_field"
|
||||
|
||||
updated_actions.append(action_copy)
|
||||
|
||||
updated_actions_by_task[task_id] = updated_actions
|
||||
|
||||
return updated_actions_by_task
|
||||
'''
|
||||
|
||||
func = FunctionToOptimize(function_name="hydrate_input_text_actions_with_field_names", parents=[], file_path=main_file)
|
||||
test_config = TestConfig(
|
||||
tests_root=root_dir / "tests/pytest",
|
||||
tests_project_rootdir=root_dir,
|
||||
project_root_path=root_dir,
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
for helper_function_path in helper_function_paths:
|
||||
with helper_function_path.open(encoding="utf8") as f:
|
||||
helper_code = f.read()
|
||||
original_helper_code[helper_function_path] = helper_code
|
||||
|
||||
func_optimizer.args = Args()
|
||||
func_optimizer.replace_function_and_helpers_with_optimized_code(
|
||||
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code
|
||||
)
|
||||
|
||||
|
||||
new_code = main_file.read_text(encoding="utf-8")
|
||||
main_file.unlink(missing_ok=True)
|
||||
|
||||
assert new_code == expected
|
||||
|
||||
def test_duplicate_global_assignments_when_reverting_helpers():
|
||||
root_dir = Path(__file__).parent.parent.resolve()
|
||||
main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()
|
||||
|
||||
original_code = '''"""Chunking objects not specific to a particular chunking strategy."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import copy
|
||||
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
|
||||
|
||||
import regex
|
||||
from typing_extensions import Self, TypeAlias
|
||||
from unstructured.utils import lazyproperty
|
||||
from unstructured.documents.elements import Element
|
||||
|
||||
# ================================================================================================
|
||||
# MODEL
|
||||
# ================================================================================================
|
||||
|
||||
CHUNK_MAX_CHARS_DEFAULT: int = 500
|
||||
|
||||
# ================================================================================================
|
||||
# PRE-CHUNKER
|
||||
# ================================================================================================
|
||||
|
||||
|
||||
class PreChunker:
|
||||
"""Gathers sequential elements into pre-chunks as length constraints allow.
|
||||
|
||||
The pre-chunker's responsibilities are:
|
||||
|
||||
- **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
|
||||
either side of those boundaries into different sections. In this case, the primary indicator
|
||||
of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
|
||||
semantic boundary when `multipage_sections` is `False`.
|
||||
|
||||
- **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
|
||||
into sections as big as possible without exceeding the chunk window size.
|
||||
|
||||
- **Minimize chunks that must be split mid-text.** Precompute the text length of each section
|
||||
and only produce a section that exceeds the chunk window size when there is a single element
|
||||
with text longer than that window.
|
||||
|
||||
A Table element is placed into a section by itself. CheckBox elements are dropped.
|
||||
|
||||
The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
|
||||
a new "section", hence the "by-title" designation.
|
||||
"""
|
||||
|
||||
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
|
||||
self._elements = elements
|
||||
self._opts = opts
|
||||
|
||||
@lazyproperty
|
||||
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
|
||||
"""The semantic-boundary detectors to be applied to break pre-chunks."""
|
||||
return self._opts.boundary_predicates
|
||||
|
||||
def _is_in_new_semantic_unit(self, element: Element) -> bool:
|
||||
"""True when `element` begins a new semantic unit such as a section or page."""
|
||||
# -- all detectors need to be called to update state and avoid double counting
|
||||
|
|
@@ -3291,29 +3498,24 @@ class PreChunker:
|
|||
# -- Using `any()` would short-circuit on first True.
|
||||
semantic_boundaries = [pred(element) for pred in self._boundary_predicates]
|
||||
return any(semantic_boundaries)
|
||||
|
||||
'''
|
||||
main_file.write_text(original_code, encoding="utf-8")
|
||||
optim_code = f'''```python:{main_file.relative_to(root_dir)}
|
||||
# ================================================================================================
|
||||
# PRE-CHUNKER
|
||||
# ================================================================================================
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Iterable
|
||||
from unstructured.documents.elements import Element
|
||||
from unstructured.utils import lazyproperty
|
||||
|
||||
class PreChunker:
|
||||
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
|
||||
self._elements = elements
|
||||
self._opts = opts
|
||||
|
||||
@lazyproperty
|
||||
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
|
||||
"""The semantic-boundary detectors to be applied to break pre-chunks."""
|
||||
return self._opts.boundary_predicates
|
||||
|
||||
def _is_in_new_semantic_unit(self, element: Element) -> bool:
|
||||
"""True when `element` begins a new semantic unit such as a section or page."""
|
||||
# Use generator expression for lower memory usage and avoid building intermediate list
|
||||
|
|
@@ -3352,60 +3554,44 @@ class PreChunker:
|
|||
main_file.unlink(missing_ok=True)
|
||||
|
||||
expected = '''"""Chunking objects not specific to a particular chunking strategy."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import copy
|
||||
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
|
||||
|
||||
import regex
|
||||
from typing_extensions import Self, TypeAlias
|
||||
from unstructured.utils import lazyproperty
|
||||
from unstructured.documents.elements import Element
|
||||
|
||||
# ================================================================================================
|
||||
# MODEL
|
||||
# ================================================================================================
|
||||
|
||||
CHUNK_MAX_CHARS_DEFAULT: int = 500
|
||||
|
||||
# ================================================================================================
|
||||
# PRE-CHUNKER
|
||||
# ================================================================================================
|
||||
|
||||
|
||||
class PreChunker:
|
||||
"""Gathers sequential elements into pre-chunks as length constraints allow.
|
||||
|
||||
The pre-chunker's responsibilities are:
|
||||
|
||||
- **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
|
||||
either side of those boundaries into different sections. In this case, the primary indicator
|
||||
of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
|
||||
semantic boundary when `multipage_sections` is `False`.
|
||||
|
||||
- **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
|
||||
into sections as big as possible without exceeding the chunk window size.
|
||||
|
||||
- **Minimize chunks that must be split mid-text.** Precompute the text length of each section
|
||||
and only produce a section that exceeds the chunk window size when there is a single element
|
||||
with text longer than that window.
|
||||
|
||||
A Table element is placed into a section by itself. CheckBox elements are dropped.
|
||||
|
||||
The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
|
||||
a new "section", hence the "by-title" designation.
|
||||
"""
|
||||
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
|
||||
self._elements = elements
|
||||
self._opts = opts
|
||||
|
||||
@lazyproperty
|
||||
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
|
||||
"""The semantic-boundary detectors to be applied to break pre-chunks."""
|
||||
return self._opts.boundary_predicates
|
||||
|
||||
def _is_in_new_semantic_unit(self, element: Element) -> bool:
|
||||
"""True when `element` begins a new semantic unit such as a section or page."""
|
||||
# Use generator expression for lower memory usage and avoid building intermediate list
|
||||
|
|
@@ -3413,6 +3599,5 @@ class PreChunker:
|
|||
if pred(element):
|
||||
return True
|
||||
return False
|
||||
|
||||
'''
|
||||
assert new_code == expected
|
||||
assert new_code == expected
|
||||
|
|
@@ -18,6 +18,8 @@ from collections.abc import Sequence
|
|||
|
||||
from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent
|
||||
|
||||
_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
|
||||
|
||||
def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
|
||||
if not content:
|
||||
return 0
|
||||
|
|
@@ -34,9 +36,6 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
|
|||
# TODO(Marcelo): We need to study how we can estimate the tokens for AudioUrl or ImageUrl.
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
|
||||
""", encoding="utf-8")
|
||||
|
||||
main_file = (root_dir / "code_to_optimize/temp_main.py").resolve()
|
||||
|
|
@@ -131,6 +130,10 @@ from collections.abc import Sequence
|
|||
|
||||
from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent
|
||||
|
||||
_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
|
||||
|
||||
_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
|
||||
|
||||
def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
|
||||
if not content:
|
||||
return 0
|
||||
|
|
@@ -155,11 +158,6 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
|
|||
tokens += len(part.data)
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
|
||||
|
||||
_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
|
||||
"""
|
||||
|
||||
assert new_code.rstrip() == original_main.rstrip() # No Change
|
||||
|
|
|
|||
|
|
@@ -1,6 +1,7 @@
 import pytest
 from pydantic import ValidationError

+from codeflash.api.aiservice import AiServiceClient
 from codeflash.models.models import CodeString

@@ -41,3 +42,30 @@ def test_whitespace_only():
     whitespace_code = " "
     cs = CodeString(code=whitespace_code)
     assert cs.code == whitespace_code
+
+def test_generated_candidates_validation():
+    ai_service = AiServiceClient()
+    code = """```python:file.py
+print name
+```"""
+    mock_generate_candidates = [
+        {
+            "source_code": code,
+            "explanation": "",
+            "optimization_id": ""
+        }
+    ]
+    candidates = ai_service._get_valid_candidates(mock_generate_candidates)
+    assert len(candidates) == 0
+    code = """```python:file.py
+print('Hello, World!')
+```"""
+    mock_generate_candidates = [
+        {
+            "source_code": code,
+            "explanation": "",
+            "optimization_id": ""
+        }
+    ]
+    candidates = ai_service._get_valid_candidates(mock_generate_candidates)
+    assert len(candidates) == 1