Merge branch 'main' of github.com:codeflash-ai/codeflash into fix/duplicate-global-assignments-when-reverting-helpers

This commit is contained in:
ali 2025-08-25 21:36:35 +03:00
commit 28f50cc1e0
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
11 changed files with 386 additions and 101 deletions

View file

@ -81,6 +81,19 @@ class AiServiceClient:
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
return response
def _get_valid_candidates(self, optimizations_json: list[dict[str, Any]]) -> list[OptimizedCandidate]:
    """Build OptimizedCandidate objects from the service response, dropping entries whose markdown parses to no code."""
    valid: list[OptimizedCandidate] = []
    for entry in optimizations_json:
        parsed = CodeStringsMarkdown.parse_markdown_code(entry["source_code"])
        # An empty code_strings list means the markdown was invalid/unparseable — skip it.
        if parsed.code_strings:
            valid.append(
                OptimizedCandidate(
                    source_code=parsed,
                    explanation=entry["explanation"],
                    optimization_id=entry["optimization_id"],
                )
            )
    return valid
def optimize_python_code( # noqa: D417
self,
source_code: str,
@ -135,14 +148,7 @@ class AiServiceClient:
console.rule()
end_time = time.perf_counter()
logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
return [
OptimizedCandidate(
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
explanation=opt["explanation"],
optimization_id=opt["optimization_id"],
)
for opt in optimizations_json
]
return self._get_valid_candidates(optimizations_json)
try:
error = response.json()["error"]
except Exception:
@ -205,14 +211,7 @@ class AiServiceClient:
optimizations_json = response.json()["optimizations"]
logger.info(f"Generated {len(optimizations_json)} candidate optimizations using line profiler information.")
console.rule()
return [
OptimizedCandidate(
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
explanation=opt["explanation"],
optimization_id=opt["optimization_id"],
)
for opt in optimizations_json
]
return self._get_valid_candidates(optimizations_json)
try:
error = response.json()["error"]
except Exception:
@ -262,14 +261,17 @@ class AiServiceClient:
refined_optimizations = response.json()["refinements"]
logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
console.rule()
refinements = self._get_valid_candidates(refined_optimizations)
return [
OptimizedCandidate(
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
explanation=opt["explanation"],
optimization_id=opt["optimization_id"][:-4] + "refi",
source_code=c.source_code,
explanation=c.explanation,
optimization_id=c.optimization_id[:-4] + "refi",
)
for opt in refined_optimizations
for c in refinements
]
try:
error = response.json()["error"]
except Exception:

View file

@ -2,6 +2,7 @@
from __future__ import annotations
import ast
from itertools import chain
from typing import TYPE_CHECKING, Optional
import libcst as cst
@ -119,6 +120,32 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
return updated_node
def _find_insertion_index(self, updated_node: cst.Module) -> int:
    """Find the position just after the last top-level import statement in the module."""
    insertion_point = 0
    for position, statement in enumerate(updated_node.body):
        # Stop scanning once we reach a class or function definition.
        # Imports are supposed to be at the top of the file, but they can
        # technically appear anywhere, even at the bottom; without this guard
        # a stray late import would incorrectly shift the insertion point
        # below actual code definitions.
        if isinstance(statement, (cst.ClassDef, cst.FunctionDef)):
            break
        if isinstance(statement, cst.SimpleStatementLine):
            # Plain top-level import line.
            if any(isinstance(child, (cst.Import, cst.ImportFrom)) for child in statement.body):
                insertion_point = position + 1
        elif isinstance(statement, cst.If) and all(
            isinstance(inner, cst.SimpleStatementLine)
            and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
            for inner in statement.body.body
        ):
            # Conditional import block (e.g. under `if TYPE_CHECKING:`) —
            # counts only when every statement inside is an import.
            insertion_point = position + 1
    return insertion_point
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
# Add any new assignments that weren't in the original file
new_statements = list(updated_node.body)
@ -131,18 +158,26 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
]
if assignments_to_append:
# Add a blank line before appending new assignments if needed
if new_statements and not isinstance(new_statements[-1], cst.EmptyLine):
new_statements.append(cst.SimpleStatementLine([cst.Pass()], leading_lines=[cst.EmptyLine()]))
new_statements.pop() # Remove the Pass statement but keep the empty line
# after last top-level imports
insert_index = self._find_insertion_index(updated_node)
# Add the new assignments
new_statements.extend(
[
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
for assignment in assignments_to_append
]
)
assignment_lines = [
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
for assignment in assignments_to_append
]
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
# Add a blank line after the last assignment if needed
after_index = insert_index + len(assignment_lines)
if after_index < len(new_statements):
next_stmt = new_statements[after_index]
# If there's no empty line, add one
has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines)
if not has_empty:
new_statements[after_index] = next_stmt.with_changes(
leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines]
)
return updated_node.with_changes(body=new_statements)
@ -341,6 +376,7 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
new_added_global_statements = extract_global_statements(src_module_code)
existing_global_statements = extract_global_statements(dst_module_code)
# make sure we don't have any statements applied multiple times at the global level.
unique_global_statements = [
stmt
for stmt in new_added_global_statements

View file

@ -412,6 +412,7 @@ def replace_function_definitions_in_module(
module_abspath: Path,
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
project_root_path: Path,
global_assignments_added_before: bool = False, # noqa: FBT001, FBT002
) -> bool:
source_code: str = module_abspath.read_text(encoding="utf8")
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
@ -421,7 +422,7 @@ def replace_function_definitions_in_module(
# because of an "edge case" where the optimized code introduced a new import and a global assignment using that import
# and that import wasn't used before, so it was ignored when calling AddImportsVisitor.add_needed_import inside replace_functions_and_add_imports (because the global assignment wasn't added yet)
# this was added at https://github.com/codeflash-ai/codeflash/pull/448
add_global_assignments(code_to_apply, source_code),
add_global_assignments(code_to_apply, source_code) if not global_assignments_added_before else source_code,
function_names,
code_to_apply,
module_abspath,

View file

@ -537,6 +537,7 @@ def revert_unused_helper_functions(
module_abspath=file_path,
preexisting_objects=set(), # Empty set since we're reverting
project_root_path=project_root,
global_assignments_added_before=True, # since we revert helpers functions after applying the optimization, we know that the file already has global assignments added, otherwise they would be added twice.
)
if reverted_code:

View file

@ -110,7 +110,7 @@ def initialize_function_optimization(
if count == 0:
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
cleanup_the_optimizer(server)
server.cleanup_the_optimizer()
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
fto = optimizable_funcs.popitem()[1][0]
@ -217,6 +217,7 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
@server.feature("performFunctionOptimization")
@server.thread()
def perform_function_optimization( # noqa: PLR0911
server: CodeflashLanguageServer, params: FunctionOptimizationParams
) -> dict[str, str]:
@ -337,14 +338,4 @@ def perform_function_optimization( # noqa: PLR0911
"explanation": best_optimization.explanation_v2,
}
finally:
cleanup_the_optimizer(server)
def cleanup_the_optimizer(server: CodeflashLanguageServer) -> None:
server.optimizer.cleanup_temporary_paths()
# restore args and test cfg
if server.optimizer.original_args_and_test_cfg:
server.optimizer.args, server.optimizer.test_cfg = server.optimizer.original_args_and_test_cfg
server.optimizer.args.function = None
server.optimizer.current_worktree = None
server.optimizer.current_function_optimizer = None
server.cleanup_the_optimizer()

View file

@ -1,12 +1,14 @@
from __future__ import annotations
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any
from threading import Event
from typing import TYPE_CHECKING, Any, Optional, TextIO
from lsprotocol.types import INITIALIZE, LogMessageParams, MessageType
from pygls import uris
from pygls.protocol import LanguageServerProtocol, lsp_method
from pygls.server import LanguageServer
from pygls.server import LanguageServer, StdOutTransportAdapter, aio_readline
if TYPE_CHECKING:
from lsprotocol.types import InitializeParams, InitializeResult
@ -81,3 +83,39 @@ class CodeflashLanguageServer(LanguageServer):
# Send log message to client (appears in output channel)
log_params = LogMessageParams(type=lsp_message_type, message=message)
self.lsp.notify("window/logMessage", log_params)
def cleanup_the_optimizer(self) -> None:
    """Release temporary optimizer state and restore the pre-optimization args/test config."""
    try:
        optimizer = self.optimizer
        optimizer.cleanup_temporary_paths()
        # Restore the args/test-cfg snapshot taken before optimization started.
        saved = optimizer.original_args_and_test_cfg
        if saved:
            optimizer.args, optimizer.test_cfg = saved
            optimizer.args.function = None
        optimizer.current_worktree = None
        optimizer.current_function_optimizer = None
    except Exception:
        # Best-effort cleanup: never let a cleanup failure crash the language server.
        self.show_message_log("Failed to cleanup optimizer", "Error")
def start_io(self, stdin: Optional[TextIO] = None, stdout: Optional[TextIO] = None) -> None:
    """Run the language server over stdio, guaranteeing optimizer cleanup on exit."""
    self.show_message_log("Starting IO server", "Info")
    self._stop_event = Event()
    # Wire the LSP protocol to raw stdio (binary buffers, not text wrappers).
    transport = StdOutTransportAdapter(stdin or sys.stdin.buffer, stdout or sys.stdout.buffer)
    self.lsp.connection_made(transport)
    try:
        reader_task = aio_readline(
            self.loop,
            self.thread_pool_executor,
            self._stop_event,
            stdin or sys.stdin.buffer,
            self.lsp.data_received,
        )
        self.loop.run_until_complete(reader_task)
    except BrokenPipeError:
        self.show_message_log("Connection to the client is lost! Shutting down the server.", "Error")
    except (KeyboardInterrupt, SystemExit):
        pass
    finally:
        # Always restore optimizer state before shutting the server down.
        self.cleanup_the_optimizer()
        self.shutdown()

View file

@ -19,7 +19,7 @@ from re import Pattern
from typing import Annotated, Optional, cast
from jedi.api.classes import Name
from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr
from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr, ValidationError
from pydantic.dataclasses import dataclass
from codeflash.cli_cmds.console import console, logger
@ -239,10 +239,14 @@ class CodeStringsMarkdown(BaseModel):
"""
matches = markdown_pattern.findall(markdown_code)
results = CodeStringsMarkdown()
for file_path, code in matches:
path = file_path.strip()
results.code_strings.append(CodeString(code=code, file_path=Path(path)))
return results
try:
for file_path, code in matches:
path = file_path.strip()
results.code_strings.append(CodeString(code=code, file_path=Path(path)))
return results # noqa: TRY300
except ValidationError:
# if any file is invalid, return an empty CodeStringsMarkdown for the entire context
return CodeStringsMarkdown()
class CodeOptimizationContext(BaseModel):

View file

@ -1354,6 +1354,7 @@ class FunctionOptimizer:
return
def revert_code_and_helpers(self, original_helper_code: dict[Path, str]) -> None:
logger.info("Reverting code and helpers...")
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)

View file

@ -2099,6 +2099,8 @@ print("Hello world")
"""
expected_code = """import numpy as np
a = 6
if 2<3:
a=4
else:
@ -2120,8 +2122,6 @@ class NewClass:
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
a = 6
"""
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
code_path.write_text(original_code, encoding="utf-8")
@ -3223,67 +3223,274 @@ class HuggingFaceModel(Model):
assert not re.search(r"^from math import pi as PI, sin as sine\b", new_code, re.MULTILINE) # conditional multiple aliases imports
assert "from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool" not in new_code # conditional from import
def test_top_level_global_assignments() -> None:
root_dir = Path(__file__).parent.parent.resolve()
main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()
original_code = '''"""
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
"""
from typing import Any, Dict, List, Tuple
import structlog
from pydantic import BaseModel
from skyvern.forge import app
from skyvern.forge.sdk.prompting import PromptEngine
from skyvern.webeye.actions.actions import ActionType
LOG = structlog.get_logger(__name__)
# Initialize prompt engine
prompt_engine = PromptEngine("skyvern")
def hydrate_input_text_actions_with_field_names(
actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
) -> Dict[str, List[Dict[str, Any]]]:
"""
Add field_name to input_text actions based on generated mappings.
Args:
actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
field_mappings: Dictionary mapping "task_id:action_id" to field names
Returns:
Updated actions_by_task with field_name added to input_text actions
"""
updated_actions_by_task = {}
for task_id, actions in actions_by_task.items():
updated_actions = []
for action in actions:
action_copy = action.copy()
if action.get("action_type") == ActionType.INPUT_TEXT:
action_id = action.get("action_id", "")
mapping_key = f"{task_id}:{action_id}"
if mapping_key in field_mappings:
action_copy["field_name"] = field_mappings[mapping_key]
else:
# Fallback field name if mapping not found
intention = action.get("intention", "")
if intention:
# Simple field name generation from intention
field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
field_name = "".join(c for c in field_name if c.isalnum() or c == "_")
action_copy["field_name"] = field_name or "unknown_field"
else:
action_copy["field_name"] = "unknown_field"
updated_actions.append(action_copy)
updated_actions_by_task[task_id] = updated_actions
return updated_actions_by_task
'''
main_file.write_text(original_code, encoding="utf-8")
optim_code = f'''```python:{main_file.relative_to(root_dir)}
from skyvern.webeye.actions.actions import ActionType
from typing import Any, Dict, List
import re
# Precompiled regex for efficiently generating simple field_name from intention
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
def hydrate_input_text_actions_with_field_names(
actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
) -> Dict[str, List[Dict[str, Any]]]:
"""
Add field_name to input_text actions based on generated mappings.
Args:
actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
field_mappings: Dictionary mapping "task_id:action_id" to field names
Returns:
Updated actions_by_task with field_name added to input_text actions
"""
updated_actions_by_task = {{}}
input_text_type = ActionType.INPUT_TEXT # local variable for faster access
intention_cleanup = _INTENTION_CLEANUP_RE
for task_id, actions in actions_by_task.items():
updated_actions = []
for action in actions:
action_copy = action.copy()
if action.get("action_type") == input_text_type:
action_id = action.get("action_id", "")
mapping_key = f"{{task_id}}:{{action_id}}"
if mapping_key in field_mappings:
action_copy["field_name"] = field_mappings[mapping_key]
else:
# Fallback field name if mapping not found
intention = action.get("intention", "")
if intention:
# Simple field name generation from intention
field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
# Use compiled regex instead of "".join(c for ...)
field_name = intention_cleanup.sub("", field_name)
action_copy["field_name"] = field_name or "unknown_field"
else:
action_copy["field_name"] = "unknown_field"
updated_actions.append(action_copy)
updated_actions_by_task[task_id] = updated_actions
return updated_actions_by_task
```
'''
expected = '''"""
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
"""
from typing import Any, Dict, List, Tuple
import structlog
from pydantic import BaseModel
from skyvern.forge import app
from skyvern.forge.sdk.prompting import PromptEngine
from skyvern.webeye.actions.actions import ActionType
import re
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
LOG = structlog.get_logger(__name__)
# Initialize prompt engine
prompt_engine = PromptEngine("skyvern")
def hydrate_input_text_actions_with_field_names(
actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
) -> Dict[str, List[Dict[str, Any]]]:
"""
Add field_name to input_text actions based on generated mappings.
Args:
actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
field_mappings: Dictionary mapping "task_id:action_id" to field names
Returns:
Updated actions_by_task with field_name added to input_text actions
"""
updated_actions_by_task = {}
input_text_type = ActionType.INPUT_TEXT # local variable for faster access
intention_cleanup = _INTENTION_CLEANUP_RE
for task_id, actions in actions_by_task.items():
updated_actions = []
for action in actions:
action_copy = action.copy()
if action.get("action_type") == input_text_type:
action_id = action.get("action_id", "")
mapping_key = f"{task_id}:{action_id}"
if mapping_key in field_mappings:
action_copy["field_name"] = field_mappings[mapping_key]
else:
# Fallback field name if mapping not found
intention = action.get("intention", "")
if intention:
# Simple field name generation from intention
field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
# Use compiled regex instead of "".join(c for ...)
field_name = intention_cleanup.sub("", field_name)
action_copy["field_name"] = field_name or "unknown_field"
else:
action_copy["field_name"] = "unknown_field"
updated_actions.append(action_copy)
updated_actions_by_task[task_id] = updated_actions
return updated_actions_by_task
'''
func = FunctionToOptimize(function_name="hydrate_input_text_actions_with_field_names", parents=[], file_path=main_file)
test_config = TestConfig(
tests_root=root_dir / "tests/pytest",
tests_project_rootdir=root_dir,
project_root_path=root_dir,
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
for helper_function_path in helper_function_paths:
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code
)
new_code = main_file.read_text(encoding="utf-8")
main_file.unlink(missing_ok=True)
assert new_code == expected
def test_duplicate_global_assignments_when_reverting_helpers():
root_dir = Path(__file__).parent.parent.resolve()
main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()
original_code = '''"""Chunking objects not specific to a particular chunking strategy."""
from __future__ import annotations
import collections
import copy
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
import regex
from typing_extensions import Self, TypeAlias
from unstructured.utils import lazyproperty
from unstructured.documents.elements import Element
# ================================================================================================
# MODEL
# ================================================================================================
CHUNK_MAX_CHARS_DEFAULT: int = 500
# ================================================================================================
# PRE-CHUNKER
# ================================================================================================
class PreChunker:
"""Gathers sequential elements into pre-chunks as length constraints allow.
The pre-chunker's responsibilities are:
- **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
either side of those boundaries into different sections. In this case, the primary indicator
of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
semantic boundary when `multipage_sections` is `False`.
- **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
into sections as big as possible without exceeding the chunk window size.
- **Minimize chunks that must be split mid-text.** Precompute the text length of each section
and only produce a section that exceeds the chunk window size when there is a single element
with text longer than that window.
A Table element is placed into a section by itself. CheckBox elements are dropped.
The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
a new "section", hence the "by-title" designation.
"""
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
self._elements = elements
self._opts = opts
@lazyproperty
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
"""The semantic-boundary detectors to be applied to break pre-chunks."""
return self._opts.boundary_predicates
def _is_in_new_semantic_unit(self, element: Element) -> bool:
"""True when `element` begins a new semantic unit such as a section or page."""
# -- all detectors need to be called to update state and avoid double counting
@ -3291,29 +3498,24 @@ class PreChunker:
# -- Using `any()` would short-circuit on first True.
semantic_boundaries = [pred(element) for pred in self._boundary_predicates]
return any(semantic_boundaries)
'''
main_file.write_text(original_code, encoding="utf-8")
optim_code = f'''```python:{main_file.relative_to(root_dir)}
# ================================================================================================
# PRE-CHUNKER
# ================================================================================================
from __future__ import annotations
from typing import Iterable
from unstructured.documents.elements import Element
from unstructured.utils import lazyproperty
class PreChunker:
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
self._elements = elements
self._opts = opts
@lazyproperty
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
"""The semantic-boundary detectors to be applied to break pre-chunks."""
return self._opts.boundary_predicates
def _is_in_new_semantic_unit(self, element: Element) -> bool:
"""True when `element` begins a new semantic unit such as a section or page."""
# Use generator expression for lower memory usage and avoid building intermediate list
@ -3352,60 +3554,44 @@ class PreChunker:
main_file.unlink(missing_ok=True)
expected = '''"""Chunking objects not specific to a particular chunking strategy."""
from __future__ import annotations
import collections
import copy
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
import regex
from typing_extensions import Self, TypeAlias
from unstructured.utils import lazyproperty
from unstructured.documents.elements import Element
# ================================================================================================
# MODEL
# ================================================================================================
CHUNK_MAX_CHARS_DEFAULT: int = 500
# ================================================================================================
# PRE-CHUNKER
# ================================================================================================
class PreChunker:
"""Gathers sequential elements into pre-chunks as length constraints allow.
The pre-chunker's responsibilities are:
- **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
either side of those boundaries into different sections. In this case, the primary indicator
of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
semantic boundary when `multipage_sections` is `False`.
- **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
into sections as big as possible without exceeding the chunk window size.
- **Minimize chunks that must be split mid-text.** Precompute the text length of each section
and only produce a section that exceeds the chunk window size when there is a single element
with text longer than that window.
A Table element is placed into a section by itself. CheckBox elements are dropped.
The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
a new "section", hence the "by-title" designation.
"""
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
self._elements = elements
self._opts = opts
@lazyproperty
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
"""The semantic-boundary detectors to be applied to break pre-chunks."""
return self._opts.boundary_predicates
def _is_in_new_semantic_unit(self, element: Element) -> bool:
"""True when `element` begins a new semantic unit such as a section or page."""
# Use generator expression for lower memory usage and avoid building intermediate list
@ -3413,6 +3599,5 @@ class PreChunker:
if pred(element):
return True
return False
'''
assert new_code == expected
assert new_code == expected

View file

@ -18,6 +18,8 @@ from collections.abc import Sequence
from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent
_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
if not content:
return 0
@ -34,9 +36,6 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
# TODO(Marcelo): We need to study how we can estimate the tokens for AudioUrl or ImageUrl.
return tokens
_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
""", encoding="utf-8")
main_file = (root_dir / "code_to_optimize/temp_main.py").resolve()
@ -131,6 +130,10 @@ from collections.abc import Sequence
from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent
_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
if not content:
return 0
@ -155,11 +158,6 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
tokens += len(part.data)
return tokens
_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
"""
assert new_code.rstrip() == original_main.rstrip() # No Change

View file

@ -1,6 +1,7 @@
import pytest
from pydantic import ValidationError
from codeflash.api.aiservice import AiServiceClient
from codeflash.models.models import CodeString
@ -41,3 +42,30 @@ def test_whitespace_only():
whitespace_code = " "
cs = CodeString(code=whitespace_code)
assert cs.code == whitespace_code
def test_generated_candidates_validation():
    """Invalid candidate code is filtered out by _get_valid_candidates; valid code is kept."""
    ai_service = AiServiceClient()

    # Unparseable Python inside the markdown fence -> candidate is dropped.
    broken_code = """```python:file.py
print name
```"""
    dropped = ai_service._get_valid_candidates(
        [{"source_code": broken_code, "explanation": "", "optimization_id": ""}]
    )
    assert len(dropped) == 0

    # Valid Python inside the markdown fence -> candidate survives.
    good_code = """```python:file.py
print('Hello, World!')
```"""
    kept = ai_service._get_valid_candidates(
        [{"source_code": good_code, "explanation": "", "optimization_id": ""}]
    )
    assert len(kept) == 1