Merge branch 'main' into multi-language

This commit is contained in:
Kevin Turcios 2026-01-26 23:59:03 -05:00 committed by GitHub
commit 764a3f8899
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 1032 additions and 185 deletions

View file

@ -89,12 +89,19 @@ jobs:
- name: Install Project Dependencies
run: |
uv sync --refresh --isolated --active
uv pip uninstall codeflash
uv pip install pytest-asyncio
uv pip install black
- name: Install Codeflash in separate venv
working-directory: .
run: |
uv venv .codeflash-venv
source .codeflash-venv/bin/activate
uv pip install pytest-asyncio black
uv pip install git+https://github.com/codeflash-ai/codeflash@main
- name: Run Codeflash to optimize code
id: optimize_code
working-directory: .
run: |
uv run --active codeflash --async --verbose
source .codeflash-venv/bin/activate
cd django/aiservice
codeflash --async --verbose

1
.gitignore vendored
View file

@ -162,6 +162,7 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.aider*
.serena/
/js/common/node_modules/
/node_modules/
*.xml

View file

@ -156,7 +156,7 @@ def make_number_node(value: int) -> cst.Integer | cst.UnaryOperation:
def evaluate_expression(node: cst.BaseExpression) -> int | None:
"""Evaluate a CST expression node to get its integer value, or None if not evaluable."""
if isinstance(node, cst.Integer):
return int(node.value)
return int(node.value, 0)
if isinstance(node, cst.Float):
return int(float(node.value))
if isinstance(node, cst.BinaryOperation):

View file

@ -70,7 +70,9 @@ def split_markdown_code(markdown: str, language: str = "python") -> dict[str, st
matches = pattern.findall(markdown)
result: dict[str, str] = {}
for file_path, code in matches:
result.setdefault(file_path.strip(), code)
stripped_path = file_path.strip()
if stripped_path not in result:
result[stripped_path] = code
return result
@ -79,7 +81,18 @@ def extract_code_block_with_context(text: str, language: str = "python") -> tupl
Returns (before_text, code_content, after_text) or None if not found.
Handles optional filepath syntax (e.g., ```python:filepath.py).
Prefers code blocks with filepath annotations over plain code blocks,
as these are more likely to contain the actual code in multi-block responses
where the LLM may include example or explanation blocks before the real code.
"""
# First try to find a code block with filepath (e.g., ```python:file.py)
# These are more likely to contain the actual optimized code
pattern_with_path = rf"(.*?)```{language}:[^\n]+(?:\n|\\n)(.*?)```(.*)"
if match := re.match(pattern_with_path, text, re.DOTALL | re.MULTILINE):
return match.group(1).strip(), match.group(2), match.group(3).strip()
# Fall back to any code block (with or without filepath)
pattern = rf"(.*?)```{language}(?::[^\n]*)?(?:\n|\\n)(.*?)```(.*)"
if match := re.match(pattern, text, re.DOTALL | re.MULTILINE):
return match.group(1).strip(), match.group(2), match.group(3).strip()

View file

@ -235,6 +235,14 @@ async def jit_rewrite(
if len(jit_rewrite_response_items) == 0:
await asyncio.to_thread(ph, request.user, "aiservice-jit-rewrite-no-optimizations-found")
debug_log_sensitive_data(f"No optimizations found for source:\n{data.source_code}")
logging.error(
"Could not generate any optimizations (jit_rewrite). trace_id=%s, repo=%s/%s, n_candidates=%d, source_len=%d",
data.trace_id,
data.repo_owner,
data.repo_name,
data.n_candidates,
len(data.source_code) if data.source_code else 0,
)
return 500, OptimizeErrorResponseSchema(error="Could not generate any optimizations. Please try again.")
await asyncio.to_thread(
ph,

View file

@ -19,12 +19,7 @@ from optimizer.context_utils.constants import (
LINE_PROF_CONTEXT_PROMPT,
MARKDOWN_CONTEXT_PROMPT,
)
from optimizer.context_utils.context_helpers import (
group_code,
is_markdown_structure_changed,
is_multi_context,
split_markdown_code,
)
from optimizer.context_utils.context_helpers import group_code, is_multi_context, split_markdown_code
from optimizer.diff_patches_utils.diff import SEARCH_AND_REPLACE_FORMAT_PROMPT, V4A_DIFF_FORMAT_PROMPT, DiffMethod
from optimizer.diff_patches_utils.seach_and_replace import SearchAndReplaceDiff
from optimizer.diff_patches_utils.v4a_diff import V4ADiff
@ -206,6 +201,8 @@ class MultiOptimizerContext(BaseOptimizerContext):
super().__init__(base_system_prompt, base_user_prompt, source_code)
self.original_file_to_code = split_markdown_code(source_code)
self._original_files_set = set(self.original_file_to_code.keys())
def get_system_prompt(self, python_version_str: str) -> str:
# critical instructions should be bullet points, short and direct, further instructions should be a full paragraph which is added after the main system prompt
critical_instructions = ""
@ -274,8 +271,8 @@ class MultiOptimizerContext(BaseOptimizerContext):
if markdown_code == "":
return None
original_code_files = self.original_file_to_code.items()
extracted_code_files = split_markdown_code(markdown_code).items()
original_code_files = dict(self.original_file_to_code.items())
extracted_code_files = dict(split_markdown_code(markdown_code).items())
op_id = str(uuid.uuid4())
new_post_processed_code: dict[str, str] = {}
@ -285,26 +282,34 @@ class MultiOptimizerContext(BaseOptimizerContext):
code=extracted.code, explanation=extracted.explanation
)
if len(original_code_files) != len(extracted_code_files):
# Check if any extracted files exist in the original context
# LLMs may only return files they actually modified
valid_extracted_files = {f: c for f, c in extracted_code_files.items() if f in original_code_files}
if not valid_extracted_files:
debug_log_sensitive_data(
f"original files {original_code_files} are not the same as extracted files {extracted_code_files}"
f"No matching files found. Original: {list(original_code_files.keys())}, "
f"Extracted: {list(extracted_code_files.keys())}"
)
return None
# Log when LLM returns fewer files (common with large contexts)
if len(valid_extracted_files) < len(original_code_files):
debug_log_sensitive_data(
f"LLM returned {len(valid_extracted_files)} of {len(original_code_files)} files. "
f"Processing returned files: {list(valid_extracted_files.keys())}"
)
# optimization is ignored when **all** files in the context are rejected by postprocessing pipeline
optimization_ignored = True
# post process each file
for (original_file, original_code), (new_file, new_code) in zip(
original_code_files, extracted_code_files, strict=True
):
if original_file != new_file:
# shouldn't happen but just in case
continue
# post process each file that was returned by the LLM
for original_file, new_code in valid_extracted_files.items():
original_code = original_code_files[original_file]
try:
new_cst_module = parse_module_to_cst(new_code)
code_and_explanation = CodeExplanationAndID(
cst_module=new_cst_module, id=f"{new_file}:post-processing", explanation=extracted.explanation
cst_module=new_cst_module, id=f"{original_file}:post-processing", explanation=extracted.explanation
)
postprocessed_list: list[CodeExplanationAndID] = optimizations_postprocessing_pipeline(
original_code, [code_and_explanation]
@ -344,12 +349,18 @@ class MultiOptimizerContext(BaseOptimizerContext):
def is_valid_code(self) -> bool:
if self.extracted_code_and_expl is None:
return False
changed = is_markdown_structure_changed(self.source_code, self.extracted_code_and_expl.code)
if changed:
extracted_files = set(split_markdown_code(self.extracted_code_and_expl.code).keys())
# All extracted files must be in the original set (no new files allowed)
invalid_files = extracted_files - self._original_files_set
if invalid_files:
debug_log_sensitive_data(
f"code markdown have been changed:=======\n{self.source_code}\\=======\n{self.extracted_code_and_expl.code}=======\n"
f"LLM returned files not in original context: {invalid_files}. "
f"Original: {self._original_files_set}, Extracted: {extracted_files}"
)
return super().is_valid_code() and (not changed)
return False
return super().is_valid_code()
def validate_and_parse_source_code(self, code: str, feature_version: tuple) -> None:
final_code = code or self.source_code

View file

@ -377,7 +377,14 @@ async def optimize_python(
if len(optimization_response_items) == 0:
await asyncio.to_thread(ph, request.user, "aiservice-optimize-no-optimizations-found")
debug_log_sensitive_data(f"No optimizations found for source:\n{data.source_code}")
logging.error(f"Could not generate any optimizations. trace_id={data.trace_id}")
logging.error(
"Could not generate any optimizations. trace_id=%s, repo=%s/%s, n_candidates=%d, source_len=%d",
data.trace_id,
data.repo_owner,
data.repo_name,
data.n_candidates,
len(data.source_code) if data.source_code else 0,
)
return 500, OptimizeErrorResponseSchema(error="Could not generate any optimizations. Please try again.")
await asyncio.to_thread(
ph,

View file

@ -660,7 +660,13 @@ async def optimize(request, data: OptimizeSchemaLP) -> tuple[int, OptimizeRespon
if len(optimization_response_items) == 0:
await asyncio.to_thread(ph, request.user, "aiservice-optimize-no-optimizations-found")
debug_log_sensitive_data(f"No optimizations found for source:\n{data.source_code}")
logging.error(f"Could not generate any optimizations (line profiler). trace_id={data.trace_id}")
logging.error(
"Could not generate any optimizations (line profiler). trace_id=%s, n_candidates=%d, source_len=%d, has_line_profiler=%s",
data.trace_id,
data.n_candidates,
len(data.source_code) if data.source_code else 0,
bool(data.line_profiler_results),
)
return 500, OptimizeErrorResponseSchema(error="Error generating optimizations. Internal server error.")
await asyncio.to_thread(
ph,

View file

@ -25,6 +25,7 @@ dependencies = [
"stamina>=25.1.0",
"jedi>=0.19.2",
"anthropic>=0.75.0",
"wcwidth>=0.2.15",
]
[project.urls]

View file

@ -2,7 +2,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from libcst import CSTTransformer, ImportAlias, ImportFrom, MetadataWrapper, Name, SimpleStatementLine, parse_module
from libcst import CSTTransformer, ImportAlias, ImportFrom, MetadataWrapper, Name, SimpleStatementLine
from aiservice.common.cst_utils import DepthTrackingMixin, build_module_path, parse_module_to_cst
@ -57,52 +57,56 @@ def replace_first_function_definition(module_str: str, function_name: str, stmt:
return modified_module_cst.code
def replace_definition_with_import(source_code: str, function: FunctionToOptimize, module_path: str) -> str: # noqa: D417
"""Replace a function or class definition with an import statement in the given CST tree.
class _ImportInserter(DepthTrackingMixin, CSTTransformer):
"""Transformer that replaces a function or class definition with an import statement."""
def __init__(self, target_name: str, import_stmt: SimpleStatementLine) -> None:
DepthTrackingMixin.__init__(self)
CSTTransformer.__init__(self)
self.target_name = target_name
self.import_stmt = import_stmt
self.replaced_count = 0
def visit_ClassDef(self, node: ClassDef) -> None: # noqa: ARG002
self._visit_class()
def leave_ClassDef(
self, original_node: ClassDef, updated_node: ClassDef
) -> BaseStatement | FlattenSentinel[BaseStatement] | RemovalSentinel:
if original_node.name.value == self.target_name and self._class_depth == 1 and self.replaced_count == 0:
self.replaced_count += 1
self._leave_class()
return self.import_stmt
self._leave_class()
return updated_node
def leave_FunctionDef(
self, original_node: FunctionDef, updated_node: FunctionDef
) -> BaseStatement | FlattenSentinel[BaseStatement] | RemovalSentinel:
if original_node.name.value == self.target_name and self._class_depth == 0 and self.replaced_count == 0:
self.replaced_count += 1
return self.import_stmt
return updated_node
def replace_definition_with_import(module: Module, function: FunctionToOptimize, module_path: str) -> Module:
"""Replace a function or class definition with an import statement.
Parameters
----------
- tree (CSTModule): The CST tree representing the module.
- function (FunctionToOptimize): The function to replace with an import.
- module_path (str): Full module path of the function or class being replaced
module : Module
The CST module representing the test code.
function : FunctionToOptimize
The function to replace with an import.
module_path : str
Full module path of the function or class being replaced.
Returns
-------
- CSTModule: The modified CST tree.
Module
The modified CST module.
"""
class ImportInserter(DepthTrackingMixin, CSTTransformer):
def __init__(self, target_name: str, import_stmt: SimpleStatementLine) -> None:
DepthTrackingMixin.__init__(self)
CSTTransformer.__init__(self)
self.target_name = target_name
self.import_stmt = import_stmt
self.replaced_count = 0
def visit_ClassDef(self, node: ClassDef) -> None: # noqa: ARG002
self._visit_class()
def leave_ClassDef(
self, original_node: ClassDef, updated_node: ClassDef
) -> BaseStatement | FlattenSentinel[BaseStatement] | RemovalSentinel:
if original_node.name.value == self.target_name and self._class_depth == 1 and self.replaced_count == 0:
self.replaced_count += 1
self._leave_class()
return self.import_stmt
self._leave_class()
return updated_node
def leave_FunctionDef(
self, original_node: FunctionDef, updated_node: FunctionDef
) -> BaseStatement | FlattenSentinel[BaseStatement] | RemovalSentinel:
if original_node.name.value == self.target_name and self._class_depth == 0 and self.replaced_count == 0:
self.replaced_count += 1
return self.import_stmt
return updated_node
tree = parse_module(source_code)
if not function.parents:
new_import = SimpleStatementLine(
body=[
@ -111,8 +115,8 @@ def replace_definition_with_import(source_code: str, function: FunctionToOptimiz
)
]
)
transformer = ImportInserter(function.function_name, new_import)
return tree.visit(transformer).code
transformer = _ImportInserter(function.function_name, new_import)
return module.visit(transformer)
if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
new_import = SimpleStatementLine(
body=[
@ -122,6 +126,6 @@ def replace_definition_with_import(source_code: str, function: FunctionToOptimiz
)
]
)
transformer = ImportInserter(function.parents[0].name, new_import)
return tree.visit(transformer).code
return tree.code
transformer = _ImportInserter(function.parents[0].name, new_import)
return module.visit(transformer)
return module

View file

@ -555,7 +555,15 @@ def add_missing_imports_from_multi_context_source(
return AddImportsVisitor(context).transform_module(module)
except Exception as e: # noqa: BLE001
logging.warning("add_missing_imports_from_multi_context_source failed: %s", e)
source_files = list(source_code_blocks.keys()) if source_code_blocks else []
logging.warning(
"add_missing_imports_from_multi_context_source failed: %s, "
"default_module_path=%s, source_files=%s, module_is_none=%s",
e,
default_module_path,
source_files,
module is None,
)
sentry_sdk.capture_exception(e)
return module

View file

@ -1,8 +1,13 @@
from typing import TYPE_CHECKING, Any
from __future__ import annotations
from libcst import Module
from typing import TYPE_CHECKING
from aiservice.models.functions_to_optimize import FunctionToOptimize
from optimizer.context_utils.context_helpers import is_multi_context, split_markdown_code
from testgen.instrumentation.edit_generated_test import replace_definition_with_import
from testgen.postprocessing.add_missing_imports import (
add_missing_imports_from_multi_context_source,
add_missing_imports_from_source,
)
from testgen.postprocessing.range_modifier import modify_large_loops
from testgen.postprocessing.remove_unused_definitions import remove_unused_definitions_from_pytest_file
from testgen.postprocessing.removeassert_transformer import remove_asserts_from_test
@ -11,27 +16,56 @@ from testgen.postprocessing.topdef_terminator import delete_top_def_nodes
if TYPE_CHECKING:
from collections.abc import Callable
from typing import Any
from libcst import Module
from aiservice.models.functions_to_optimize import FunctionToOptimize
def postprocessing_testgen_pipeline(
module: Module, helper_function_names: list[str], function_to_optimize: FunctionToOptimize, module_path: str
module: Module,
helper_function_names: list[str],
function_to_optimize: FunctionToOptimize,
module_path: str,
source_code_being_tested: str,
) -> Module:
"""Apply postprocessing functions in a pipeline, a list of (function, kwargs)."""
"""Full postprocessing pipeline for generated test code.
Applies all CST transformations in sequence:
1. Clean up definitions (remove helper functions, unused definitions)
2. Modify constructs (large loops, tensors)
3. Remove asserts
4. Add missing imports
5. Replace function definition with import
"""
add_imports_func, add_imports_kwargs = (
(
add_missing_imports_from_multi_context_source,
{"source_code_blocks": split_markdown_code(source_code_being_tested), "default_module_path": module_path},
)
if is_multi_context(source_code_being_tested)
else (add_missing_imports_from_source, {"source_code": source_code_being_tested, "module_path": module_path})
)
pipeline: list[tuple[Callable[..., Module], dict[str, Any]]] = [
(delete_top_def_nodes, {"deletable_list": helper_function_names}),
(remove_unused_definitions_from_pytest_file, {}),
(modify_large_loops, {}),
(modify_tensors, {}),
(
remove_asserts_from_test, # should be the last function in the pipeline since it transforms the code from result = extract_input_variables(nodes) to codeflash_output = extract_input_variables(nodes); result = codeflash_output and raises an error, probably because of the semicolon
remove_asserts_from_test,
{
"function_to_optimize": function_to_optimize,
"helper_function_names": helper_function_names,
"module_path": module_path,
},
),
(add_imports_func, add_imports_kwargs),
(replace_definition_with_import, {"function": function_to_optimize, "module_path": module_path}),
]
for func, kwargs in pipeline:
module = func(module, **kwargs)
return module

View file

@ -28,6 +28,7 @@ from libcst import (
SimpleStatementSuite,
UnaryOperation,
parse_expression,
parse_module,
)
if TYPE_CHECKING:
@ -44,7 +45,11 @@ class StatementHandler:
self, updated_node: SimpleStatementLine | SimpleStatementSuite
) -> SimpleStatementLine | SimpleStatementSuite | RemovalSentinel | FlattenSentinel[BaseSmallStatement] | None:
if not updated_node.body:
return updated_node
# Empty body is malformed - handle it properly
if isinstance(updated_node, SimpleStatementSuite):
return updated_node.with_changes(body=[Pass()])
# SimpleStatementLine with empty body should be removed
return RemoveFromParent()
if len(updated_node.body) == 1:
return self._handle_single_statement(updated_node)
@ -90,8 +95,12 @@ class StatementHandler:
if new_body and hasattr(new_body[-1], "semicolon"):
new_body[-1] = new_body[-1].with_changes(semicolon=None)
if isinstance(updated_node, SimpleStatementSuite) and not new_body:
new_body = [Pass()]
if not new_body:
if isinstance(updated_node, SimpleStatementSuite):
new_body = [Pass()]
else:
# SimpleStatementLine with empty body is malformed - remove it entirely
return RemoveFromParent()
return updated_node.with_changes(body=new_body)
@ -373,8 +382,27 @@ class RemoveAssertTransformer(CSTTransformer):
) -> Any: # noqa: ANN401
return self.statement_handler.handle_statement(updated_node)
def leave_FunctionDef(self, original_node: FunctionDef, updated_node: FunctionDef) -> RemovalSentinel | FunctionDef: # noqa: ARG002
return updated_node if updated_node.body.body else RemoveFromParent()
def leave_FunctionDef(self, original_node: FunctionDef, updated_node: FunctionDef) -> RemovalSentinel | FunctionDef:
body_block = updated_node.body
# For IndentedBlock (regular multi-line functions)
if hasattr(body_block, "body") and isinstance(body_block.body, (list, tuple)):
body = body_block.body
if not body:
return RemoveFromParent()
# Only remove test functions that only contain pass (useless test stubs)
# Non-test functions with pass are legitimate and should be kept
if original_node.name.value.startswith("test_") and len(body) == 1:
first = body[0]
# Check if it's a SimpleStatementLine with just Pass()
if isinstance(first, SimpleStatementLine) and len(first.body) == 1 and isinstance(first.body[0], Pass):
return RemoveFromParent()
# Check if it's a SimpleStatementSuite (one-liner) with just Pass()
if isinstance(first, Pass):
return RemoveFromParent()
return updated_node
def is_target_function_node(self, node: Call) -> bool:
return (isinstance(node.func, Name) and node.func.value == self.only_function_name) or (
@ -463,4 +491,7 @@ def remove_asserts_from_test(
]
)
new_body = [add_import, *list(modified_tree.body)]
return modified_tree.with_changes(body=tuple(new_body))
result = modified_tree.with_changes(body=tuple(new_body))
# Re-parse to ensure valid CST structure. The transformer may create nodes with
# semicolon=None instead of MaybeSentinel, which causes errors during subsequent visits.
return parse_module(result.code)

View file

@ -23,7 +23,8 @@ from aiservice.models.functions_to_optimize import FunctionToOptimize
from authapp.auth import AuthenticatedRequest
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from optimizer.context_utils.context_helpers import is_multi_context, split_markdown_code
from libcst import parse_module
from testgen.instrumentation.edit_generated_test import replace_definition_with_import
from testgen.instrumentation.instrument_new_tests import instrument_test_source
from testgen.models import (
@ -34,10 +35,6 @@ from testgen.models import (
TestGenSchema,
TestingMode,
)
from testgen.postprocessing.add_missing_imports import (
add_missing_imports_from_multi_context_source,
add_missing_imports_from_source,
)
from testgen.postprocessing.code_validator import CodeValidationError, has_test_functions, validate_testgen_code
from testgen.postprocessing.postprocess_pipeline import postprocessing_testgen_pipeline
from testgen.testgen_context import BaseTestGenContext, TestGenContextData
@ -205,7 +202,7 @@ def parse_and_validate_llm_output(
if function_to_optimize is not None and module_path is not None:
try:
code = replace_definition_with_import(code, function_to_optimize, module_path)
code = replace_definition_with_import(parse_module(code), function_to_optimize, module_path).code
except Exception: # noqa: BLE001
# If replacement fails (e.g., parsing error), continue with original code
logging.warning("replace_definition_with_import failed, continuing with original code")
@ -330,24 +327,10 @@ async def generate_regression_tests_from_function(
data.helper_function_names or [],
data.function_to_optimize,
data.module_path,
data.source_code_being_tested,
)
# Add missing imports for symbols defined in source module but not imported in test.
# This handles cases where the LLM redefines some classes locally but forgets others.
if is_multi_context(data.source_code_being_tested):
# Multi-context: use per-file module mapping for correct imports
source_code_blocks = split_markdown_code(data.source_code_being_tested)
processed_cst = add_missing_imports_from_multi_context_source(
processed_cst, source_code_blocks, data.module_path
)
else:
processed_cst = add_missing_imports_from_source(
processed_cst, data.source_code_being_tested, data.module_path
)
generated_test_source = replace_definition_with_import(
processed_cst.code, data.function_to_optimize, data.module_path
)
generated_test_source = processed_cst.code
instrumented_behavior_tests, instrumented_perf_tests = instrument_tests(
generated_test_source, data, python_version

View file

@ -50,6 +50,27 @@ def test_evaluate_expression_integer() -> None:
assert evaluate_expression(node) == 42
def test_evaluate_expression_hexadecimal() -> None:
"""Test that hexadecimal integer literals (0x prefix) are correctly evaluated."""
assert evaluate_expression(cst.Integer("0x00")) == 0
assert evaluate_expression(cst.Integer("0xFF")) == 255
assert evaluate_expression(cst.Integer("0xDEADBEEF")) == 0xDEADBEEF
def test_evaluate_expression_octal() -> None:
"""Test that octal integer literals (0o prefix) are correctly evaluated."""
assert evaluate_expression(cst.Integer("0o0")) == 0
assert evaluate_expression(cst.Integer("0o77")) == 63
assert evaluate_expression(cst.Integer("0o755")) == 493
def test_evaluate_expression_binary() -> None:
"""Test that binary integer literals (0b prefix) are correctly evaluated."""
assert evaluate_expression(cst.Integer("0b0")) == 0
assert evaluate_expression(cst.Integer("0b1010")) == 10
assert evaluate_expression(cst.Integer("0b11111111")) == 255
def test_evaluate_expression_float() -> None:
node = cst.Float("3.14")
assert evaluate_expression(node) == 3

View file

@ -248,3 +248,86 @@ def test_wrap_and_extract_roundtrip() -> None:
wrapped = wrap_code_in_markdown(original_code)
extracted = extract_code_block(wrapped)
assert extracted == original_code
def test_extract_code_block_with_context_prefers_filepath_block() -> None:
"""Test that filepath-annotated blocks are preferred over plain blocks.
This handles LLM responses that include explanation code blocks before
the actual optimized code with a filepath annotation. This test is based on
an actual failing JIT rewrite response (trace 59be9279-dd59-4fad-aaf4-6bc541e44cb8)
where the LLM returned an explanation with example code before the real optimized code.
"""
# Actual response pattern from gpt-5-mini JIT rewrite that caused extraction failure
text = """**Optimization Explanation:**
I applied targeted micro-optimizations that reduce attribute lookups.
```python
# Example of the original slow pattern:
def slow_convert(sequence):
result = []
for item in sequence:
result.append(tuple(item))
return tuple(result)
```
The optimized code uses list comprehension for better performance:
```python:unstructured/documents/elements.py
from __future__ import annotations
import dataclasses as dc
@dc.dataclass
class DataSourceMetadata:
url: str | None = None
```
These yield measurable runtime improvements."""
result = extract_code_block_with_context(text)
assert result is not None
before, code, after = result
# Should extract the filepath-annotated block, not the example block
assert (
before
== """**Optimization Explanation:**
I applied targeted micro-optimizations that reduce attribute lookups.
```python
# Example of the original slow pattern:
def slow_convert(sequence):
result = []
for item in sequence:
result.append(tuple(item))
return tuple(result)
```
The optimized code uses list comprehension for better performance:"""
)
assert (
code
== """from __future__ import annotations
import dataclasses as dc
@dc.dataclass
class DataSourceMetadata:
url: str | None = None
"""
)
assert after == "These yield measurable runtime improvements."
def test_extract_code_block_with_context_falls_back_to_plain_block() -> None:
"""Test fallback to plain block when no filepath-annotated block exists."""
text = """Here's the code:
```python
def simple_function():
return 42
```
"""
result = extract_code_block_with_context(text)
assert result is not None
before, code, after = result
assert before == "Here's the code:"
assert code == "def simple_function():\n return 42\n"
assert after == ""

View file

@ -703,3 +703,126 @@ def gcd(a: int, b: int) -> int:
assert ctx.extracted_code_and_expl.code == expected_code
assert ctx.extracted_code_and_expl.explanation == "I replaced the recursive implementation with an iterative loop."
assert ctx.is_valid_code()
def test_multi_optimizer_partial_file_return() -> None:
"""Test that LLM returning only modified files (not all original files) is handled correctly.
This tests the fix for trace_id=5e4306f4-a909-48b1-aeb4-932c1b451129 where the optimizer
received 7 files but LLMs only returned 1-2 files they actually modified. The strict
validation previously required exact file count matching, causing all candidates to fail.
"""
# Original code has 3 files
original_code = """```python:app/utils.py
def helper():
return "helper"
```
```python:app/main.py
from app.utils import helper
def process(data):
for i in range(len(data)):
data[i] = data[i] * 2
return data
```
```python:app/constants.py
MAX_SIZE = 100
DEFAULT_VALUE = 0
```
"""
# LLM only returns 1 file (the one it actually modified)
llm_response = """I optimized the loop using list comprehension for better performance.
```python:app/main.py
from app.utils import helper
def process(data):
return [x * 2 for x in data]
```
"""
ctx = create_optimizer_context(original_code)
assert isinstance(ctx, MultiOptimizerContext)
ctx.extract_code_and_explanation_from_llm_res(llm_response)
# Verify extraction worked correctly
assert ctx.extracted_code_and_expl is not None
assert (
ctx.extracted_code_and_expl.code
== """```python:app/main.py
from app.utils import helper
def process(data):
return [x * 2 for x in data]
```"""
)
assert (
ctx.extracted_code_and_expl.explanation
== "I optimized the loop using list comprehension for better performance."
)
# Should succeed - partial return is valid
assert ctx.is_valid_code() is True
candidate = ctx.parse_and_generate_candidate_schema()
assert candidate is not None
assert candidate.optimization_id is not None
# The returned code should only contain the modified file (with post-processing applied)
assert (
candidate.source_code
== """```python:app/main.py
from app.utils import helper
def process(data):
return [x * 2 for x in data]
```"""
)
assert candidate.explanation == "I optimized the loop using list comprehension for better performance."
def test_multi_optimizer_rejects_new_files_not_in_original() -> None:
"""Test that LLM returning files NOT in the original context is rejected.
While partial returns (subset of original files) are allowed, returning
new files that weren't in the original context should be rejected.
"""
# Need at least 2 files to get a MultiOptimizerContext
original_code = """```python:app/main.py
def process(data):
return data * 2
```
```python:app/utils.py
def helper():
return "helper"
```
"""
# LLM returns a file that wasn't in the original context
llm_response = """I added a new helper module.
```python:app/new_helpers.py
def double(x):
return x * 2
```
"""
ctx = create_optimizer_context(original_code)
assert isinstance(ctx, MultiOptimizerContext)
ctx.extract_code_and_explanation_from_llm_res(llm_response)
# Verify extraction worked (the code was extracted, just not valid)
assert ctx.extracted_code_and_expl is not None
assert (
ctx.extracted_code_and_expl.code
== """```python:app/new_helpers.py
def double(x):
return x * 2
```"""
)
assert ctx.extracted_code_and_expl.explanation == "I added a new helper module."
# Should fail - new file not in original context
assert ctx.is_valid_code() is False
# parse_and_generate_candidate_schema should return None for invalid files
candidate = ctx.parse_and_generate_candidate_schema()
assert candidate is None

View file

@ -1,5 +1,7 @@
import ast
from libcst import parse_module
from aiservice.models.functions_to_optimize import FunctionParent, FunctionToOptimize
from testgen.instrumentation.edit_generated_test import replace_definition_with_import
@ -24,7 +26,7 @@ def test_something():
)
module_path = "app.api.routes.agent"
result = replace_definition_with_import(source_code, function, module_path)
result = replace_definition_with_import(parse_module(source_code), function, module_path).code
expected = """
import pytest
@ -68,7 +70,7 @@ def test_something():
)
module_path = "app.api.routes.agent"
result = replace_definition_with_import(source_code, function, module_path)
result = replace_definition_with_import(parse_module(source_code), function, module_path).code
expected = '''
import pytest
@ -118,7 +120,7 @@ async def clone_agent_template(request, template_id):
)
module_path = "app.api.routes.agent"
result = replace_definition_with_import(source_code, function, module_path)
result = replace_definition_with_import(parse_module(source_code), function, module_path).code
expected = '''
import pytest
@ -163,7 +165,7 @@ class OuterClass:
)
module_path = "app.api.routes.agent"
result = replace_definition_with_import(source_code, function, module_path)
result = replace_definition_with_import(parse_module(source_code), function, module_path).code
expected = '''
import pytest
@ -207,7 +209,7 @@ def test_something():
)
module_path = "app.api.routes.agent"
result = replace_definition_with_import(source_code, function, module_path)
result = replace_definition_with_import(parse_module(source_code), function, module_path).code
expected = """
import pytest
@ -241,7 +243,7 @@ class SomeClass:
)
module_path = "app.api.routes.agent"
result = replace_definition_with_import(source_code, function, module_path)
result = replace_definition_with_import(parse_module(source_code), function, module_path).code
expected = """
import pytest
@ -286,7 +288,7 @@ class MockLogger:
)
module_path = "app.api.routes.agent"
result = replace_definition_with_import(source_code, function, module_path)
result = replace_definition_with_import(parse_module(source_code), function, module_path).code
expected = """
import pytest

View file

@ -1,7 +1,12 @@
from libcst import Pass, RemoveFromParent, SimpleStatementLine, SimpleStatementSuite
from libcst import parse_module as parse_module_to_cst
from aiservice.models.functions_to_optimize import FunctionParent, FunctionToOptimize
from testgen.postprocessing.removeassert_transformer import remove_asserts_from_test
from testgen.postprocessing.removeassert_transformer import (
RemoveAssertTransformer,
StatementHandler,
remove_asserts_from_test,
)
def test_remove_asserts() -> None:
@ -796,3 +801,90 @@ class TestPigLatin(unittest.TestCase):
module_path="code_to_optimize.bubble_sort",
).code
assert result == expected
def test_remove_asserts_empty_simple_statement_line_removed() -> None:
"""Test that a SimpleStatementLine with all statements removed is properly removed.
This tests the fix for the bug where an empty SimpleStatementLine would create
a malformed CST that could cause 'NoneType' object has no attribute 'visit' errors.
When all asserts on a line are irrelevant (not related to the function being optimized),
the entire line should be removed rather than leaving an empty body.
"""
# Input: one physical line carrying three semicolon-separated asserts, none of
# which reference the function being optimized ("some_function").
original_test = """def test_multiple_irrelevant_asserts():
x = 5
assert 1 == 1; assert True; assert 2 > 1
y = 10
"""
# Expected: the whole assert line vanishes (no empty statement line left behind),
# and the pipeline prepends the import plus its trailing explanatory comment.
expected = """from some_file import some_function
def test_multiple_irrelevant_asserts():
x = 5
y = 10
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
"""
# Target function never appears in the test body, so every assert is irrelevant.
function_to_optimize = FunctionToOptimize(
function_name="some_function", file_path="some_file.py", parents=[], starting_line=None, ending_line=None
)
result = remove_asserts_from_test(
module=parse_module_to_cst(original_test),
function_to_optimize=function_to_optimize,
helper_function_names=[],
module_path="some_file",
).code
assert result == expected
def test_remove_asserts_one_line_function_all_removed() -> None:
"""Test that one-line functions with only irrelevant asserts are removed entirely.
When a one-line function body (like `def test(): assert x`) has all its statements
removed (because they're irrelevant asserts), the entire function should be removed
rather than leaving a useless `def test(): pass` stub.
"""
# A single-line def whose entire suite is one irrelevant assert.
original_test = """def test_irrelevant(): assert 1 == 1"""
# Expected: the def disappears completely; only the injected import and the
# standard trailing comment remain.
expected = """from some_file import some_function
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."""
function_to_optimize = FunctionToOptimize(
function_name="some_function", file_path="some_file.py", parents=[], starting_line=None, ending_line=None
)
result = remove_asserts_from_test(
module=parse_module_to_cst(original_test),
function_to_optimize=function_to_optimize,
helper_function_names=[],
module_path="some_file",
).code
assert result == expected
def test_handle_statement_with_empty_body() -> None:
"""Test that handle_statement properly handles nodes with empty body.
This tests the fix for the bug where a SimpleStatementLine with an empty body
(from prior transformations or malformed input) would be returned unchanged,
causing CST tree corruption and 'NoneType' object has no attribute 'visit' errors
in subsequent transforms.
- SimpleStatementLine with empty body should return RemoveFromParent()
- SimpleStatementSuite with empty body should return a node with Pass()
"""
# Minimal transformer setup; the target function is irrelevant to the nodes below.
function_to_optimize = FunctionToOptimize(
function_name="some_function", file_path="some_file.py", parents=[], starting_line=None, ending_line=None
)
transformer = RemoveAssertTransformer(function_to_optimize, [])
handler = StatementHandler(transformer)
# Test SimpleStatementLine with empty body returns RemoveFromParent()
empty_line = SimpleStatementLine(body=[])
result = handler.handle_statement(empty_line)
# NOTE(review): relies on RemoveFromParent sentinels comparing equal — confirm libcst semantics.
assert result == RemoveFromParent()
# Test SimpleStatementSuite with empty body returns node with Pass()
empty_suite = SimpleStatementSuite(body=[])
result = handler.handle_statement(empty_suite)
# A suite cannot be removed outright, so it must be padded with a single Pass.
assert isinstance(result, SimpleStatementSuite)
assert len(result.body) == 1
assert isinstance(result.body[0], Pass)

View file

@ -1,6 +1,8 @@
from libcst import parse_module as parse_module_to_cst
from aiservice.models.functions_to_optimize import FunctionToOptimize
from optimizer.context_utils.context_helpers import group_code
from testgen.postprocessing.code_validator import validate_testgen_code
from testgen.postprocessing.postprocess_pipeline import postprocessing_testgen_pipeline
@ -67,6 +69,13 @@ def test_few_shot_variables():
'''
module = parse_module_to_cst(code)
# Source code that defines the function being tested
source_code_being_tested = '''
def extract_input_variables(nodes):
"""Extracts input variables from the template."""
pass
'''
function_to_optimize = FunctionToOptimize(
function_name="extract_input_variables",
file_path="testgen/postprocessing/tests/test_validate_pipeline.py",
@ -76,44 +85,19 @@ def test_few_shot_variables():
)
module_path = "test_validate_pipeline"
result = postprocessing_testgen_pipeline(
module, ["function_to_remove"], function_to_optimize, module_path
) # function_to_remove is in the deletable_list, execute delete_top_def_nodes in the pipeline
module, ["function_to_remove"], function_to_optimize, module_path, source_code_being_tested
)
expected = r'''from test_validate_pipeline import extract_input_variables
# After consolidation, the function definition is removed by add_missing_imports_from_source
# (which detects local redefinitions of public symbols) and replaced with an import.
# The import is added at the top by AddImportsVisitor.
expected = r"""from test_validate_pipeline import extract_input_variables
import re
from typing import Any
# imports
import pytest # used for our unit tests
def extract_input_variables(nodes: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Extracts input variables from the template and adds them to the input_variables field."""
prompt_pattern = re.compile(r"\{(.*?)\}")
for node in nodes:
try:
data_node = node["data"]["node"]
template_info = data_node["template"]
template_type = template_info["_type"]
if "input_variables" in template_info:
if template_type == "prompt":
value = template_info["template"]["value"]
variables = prompt_pattern.findall(value)
elif template_type == "few_shot":
prefix = template_info["prefix"]["value"]
suffix = template_info["suffix"]["value"]
variables = prompt_pattern.findall(prefix + suffix)
else:
variables = []
template_info["input_variables"]["value"] = variables
except (KeyError, TypeError):
# Exception suppressed as in the original code
pass
return nodes
# unit tests
@ -129,6 +113,465 @@ def test_few_shot_variables():
nodes = [{"data": {"node": {"template": {"_type": "few_shot", "prefix": {"value": "{var1}"}, "suffix": {"value": "{var2}"}, "input_variables": {}}}}}]
codeflash_output = extract_input_variables(nodes); result = codeflash_output
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
'''
"""
assert result.code == expected
def test_postprocessing_pipeline_with_multi_context_imports() -> None:
"""Test the full testgen pipeline including all processing stages.
This tests the complete flow that generated test code goes through:
1. validate_testgen_code - validates and cleans raw LLM output
2. postprocessing_testgen_pipeline - transforms the code (includes all stages)
This also tests the fix for the bug where CST tree corruption could cause
'NoneType' object has no attribute 'visit' errors.
"""
# Raw test code (as if from LLM output)
raw_test_code = """
import pytest
from unittest.mock import Mock, MagicMock, patch
from typing import Any, Optional
import numpy as np
from unstructured.documents.elements import Text, ListItem, PageBreak
from unstructured.partition.pdf import document_to_element_list
def test_document_to_element_list_empty_document():
\"\"\"Test that an empty document returns an empty list of elements.\"\"\"
mock_document = Mock()
mock_document.pages = []
result = document_to_element_list(mock_document)
assert result == []
assert isinstance(result, list)
def test_document_to_element_list_single_element():
\"\"\"Test basic conversion with one element.\"\"\"
mock_layout_element = Mock()
mock_layout_element.text = "Hello World"
mock_layout_element.type = "Text"
mock_layout_element.bbox = Mock()
mock_layout_element.bbox.x1 = np.nan
mock_layout_element.parent = None
mock_page = Mock()
mock_page.elements_array = Mock()
mock_page.elements_array.element_class_id_map = {}
mock_page.elements_array.element_class_ids = np.array([])
mock_page.elements_array.iter_elements = Mock(return_value=[mock_layout_element])
mock_document = Mock()
mock_document.pages = [mock_page]
with patch('unstructured.partition.pdf.normalize_layout_element') as mock_normalize:
text_element = Text(text="Hello World")
mock_normalize.return_value = text_element
result = document_to_element_list(mock_document)
assert len(result) == 1
assert isinstance(result[0], Text)
"""
# Source code being tested (what the LLM saw)
source_code = """
def document_to_element_list(document, sortable=False, include_page_breaks=False,
last_modification_date=None, detection_origin=None):
\"\"\"Convert a document to a list of elements.\"\"\"
elements = []
for page in document.pages:
for layout_element in page.elements_array.iter_elements():
element = normalize_layout_element(layout_element)
elements.append(element)
return elements
"""
# Source code blocks simulating multi-context
source_code_blocks = {
"unstructured/partition/pdf.py": source_code,
"unstructured/documents/elements.py": """
class Element:
def __init__(self, text=""):
self.text = text
class Text(Element):
pass
class ListItem(Element):
pass
class PageBreak(Element):
pass
""",
}
# Expected module text after validation + the full postprocessing pipeline:
# imports sorted/merged, asserts on the optimized function rewritten to
# codeflash_output captures, and the function import prepended.
expected = """from unstructured.partition.pdf import document_to_element_list
from typing import Any, Optional
from unittest.mock import MagicMock, Mock, patch
import numpy as np
import pytest
from unstructured.documents.elements import ListItem, PageBreak, Text
def test_document_to_element_list_empty_document():
\"\"\"Test that an empty document returns an empty list of elements.\"\"\"
mock_document = Mock()
mock_document.pages = []
codeflash_output = document_to_element_list(mock_document); result = codeflash_output
def test_document_to_element_list_single_element():
\"\"\"Test basic conversion with one element.\"\"\"
mock_layout_element = Mock()
mock_layout_element.text = "Hello World"
mock_layout_element.type = "Text"
mock_layout_element.bbox = Mock()
mock_layout_element.bbox.x1 = np.nan
mock_layout_element.parent = None
mock_page = Mock()
mock_page.elements_array = Mock()
mock_page.elements_array.element_class_id_map = {}
mock_page.elements_array.element_class_ids = np.array([])
mock_page.elements_array.iter_elements = Mock(return_value=[mock_layout_element])
mock_document = Mock()
mock_document.pages = [mock_page]
with patch('unstructured.partition.pdf.normalize_layout_element') as mock_normalize:
text_element = Text(text="Hello World")
mock_normalize.return_value = text_element
codeflash_output = document_to_element_list(mock_document); result = codeflash_output
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."""
# Function under optimization lives in the first multi-context file.
function_to_optimize = FunctionToOptimize(
function_name="document_to_element_list",
file_path="unstructured/partition/pdf.py",
parents=[],
starting_line=None,
ending_line=None,
)
module_path = "unstructured.partition.pdf"
python_version = (3, 11)
# Step 1: Validate testgen code (simulates what happens after LLM response)
validated_code = validate_testgen_code(raw_test_code, python_version)
# Step 2: postprocessing_testgen_pipeline (includes add_missing_imports and replace_definition_with_import)
source_code_being_tested = group_code(source_code_blocks)
processed_module = postprocessing_testgen_pipeline(
parse_module_to_cst(validated_code), [], function_to_optimize, module_path, source_code_being_tested
)
assert processed_module.code == expected
def test_postprocessing_pipeline_with_unstructured_test_code() -> None:
"""Test the full testgen pipeline with complex test code from unstructured codebase.
This tests the complete flow with helper functions and multiple test cases:
1. validate_testgen_code - validates and cleans raw LLM output
2. postprocessing_testgen_pipeline - transforms the code (includes all stages)
Uses a simplified version of actual generated test code that triggered
'NoneType' object has no attribute 'visit' error.
"""
# Raw test code with helper functions (as if from LLM output)
raw_test_code = """
import math
from types import SimpleNamespace
import numpy as np
import pytest
from unstructured.partition.pdf import document_to_element_list
from unstructured.documents.elements import Text, ListItem, PageBreak, Title, Element
from unstructured.documents.elements import ElementType
def make_layout_element_dict_like(
*,
text: str = "",
element_type: str | None = None,
coordinates: tuple | None = None,
prob: float | None = None,
bbox_x1: float = 0.0,
bbox_x2: float = 1.0,
bbox_y1: float = 0.0,
bbox_y2: float = 1.0,
parent: object | None = None,
):
bbox = SimpleNamespace(x1=bbox_x1, x2=bbox_x2, y1=bbox_y1, y2=bbox_y2)
def to_dict():
out = {"text": text}
if element_type is not None:
out["type"] = element_type
if coordinates is not None:
out["coordinates"] = coordinates
if prob is not None:
out["prob"] = prob
return out
le = SimpleNamespace(
bbox=bbox,
to_dict=to_dict,
parent=parent,
)
return le
def make_page(elements, *, image_metadata=None, image=None):
class ElementsArray:
def __init__(self, elements):
self._elements = elements
self.element_class_id_map = {}
self.element_class_ids = np.array([], dtype=int)
def iter_elements(self):
for el in self._elements:
yield el
elements_array = ElementsArray(elements)
page = SimpleNamespace(
elements_array=elements_array,
image_metadata=image_metadata,
image=image,
)
return page
def test_single_text_element_basic():
coords = ((0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0))
layout_el = make_layout_element_dict_like(
text="Hello World",
element_type=ElementType.TEXT,
coordinates=coords,
bbox_x1=5.0,
)
image_metadata = {"format": "PNG", "width": 100, "height": 200}
page = make_page([layout_el], image_metadata=image_metadata)
document = SimpleNamespace(pages=[page])
elements = document_to_element_list(
document,
sortable=False,
include_page_breaks=False,
last_modification_date="2022-01-01",
detection_origin="detected-by-model",
)
assert len(elements) == 1
el = elements[0]
assert isinstance(el, Text)
assert str(el) == "Hello World"
assert el.metadata.page_number == 1
def test_list_items_are_inferred():
list_text = "1. First item\\n2. Second item"
layout_el = make_layout_element_dict_like(
text=list_text,
element_type=ElementType.LIST,
coordinates=None,
bbox_x1=5.0,
)
page = make_page([layout_el], image_metadata={"format": "JPEG", "width": 10, "height": 20})
document = SimpleNamespace(pages=[page])
elements = document_to_element_list(
document,
infer_list_items=True,
last_modification_date="2023-07-07",
)
assert len(elements) == 2
for item in elements:
assert isinstance(item, ListItem)
"""
# Source code being tested (multi-context format)
source_code = """
from unstructured.documents.elements import Element, Text, ListItem, PageBreak
from unstructured.partition.common.common import normalize_layout_element
def document_to_element_list(document, sortable=False, include_page_breaks=False,
last_modification_date=None, detection_origin=None,
starting_page_number=1, infer_list_items=True, **kwargs):
elements = []
for page_idx, page in enumerate(document.pages):
page_number = starting_page_number + page_idx
for layout_element in page.elements_array.iter_elements():
element = normalize_layout_element(layout_element)
if isinstance(element, list):
elements.extend(element)
else:
elements.append(element)
if include_page_breaks and page_idx < len(document.pages):
elements.append(PageBreak())
return elements
"""
# Source code blocks from unstructured codebase
source_code_blocks = {
"unstructured/partition/pdf.py": source_code,
"unstructured/documents/elements.py": """
from typing import Any, Optional
class ElementMetadata:
def __init__(self):
self.page_number = None
self.parent_id = None
self.coordinates = None
self.last_modified = None
class Element:
def __init__(self, text=""):
self.text = text
self.metadata = ElementMetadata()
self.id = id(self)
def __str__(self):
return self.text
class Text(Element):
pass
class ListItem(Element):
pass
class PageBreak(Element):
pass
class Title(Element):
pass
class ElementType:
TEXT = "Text"
LIST = "List"
TITLE = "Title"
""",
}
# Expected output after the full pipeline: helper functions preserved,
# asserts on the optimized function rewritten to codeflash_output captures,
# and imports consolidated (raw string so the embedded \n survives verbatim).
expected = r"""from unstructured.partition.pdf import document_to_element_list
import math
from types import SimpleNamespace
import numpy as np
import pytest
from unstructured.documents.elements import (Element, ElementType, ListItem,
PageBreak, Text, Title)
def make_layout_element_dict_like(
*,
text: str = "",
element_type: str | None = None,
coordinates: tuple | None = None,
prob: float | None = None,
bbox_x1: float = 0.0,
bbox_x2: float = 1.0,
bbox_y1: float = 0.0,
bbox_y2: float = 1.0,
parent: object | None = None,
):
bbox = SimpleNamespace(x1=bbox_x1, x2=bbox_x2, y1=bbox_y1, y2=bbox_y2)
def to_dict():
out = {"text": text}
if element_type is not None:
out["type"] = element_type
if coordinates is not None:
out["coordinates"] = coordinates
if prob is not None:
out["prob"] = prob
return out
le = SimpleNamespace(
bbox=bbox,
to_dict=to_dict,
parent=parent,
)
return le
def make_page(elements, *, image_metadata=None, image=None):
class ElementsArray:
def __init__(self, elements):
self._elements = elements
self.element_class_id_map = {}
self.element_class_ids = np.array([], dtype=int)
def iter_elements(self):
for el in self._elements:
yield el
elements_array = ElementsArray(elements)
page = SimpleNamespace(
elements_array=elements_array,
image_metadata=image_metadata,
image=image,
)
return page
def test_single_text_element_basic():
coords = ((0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0))
layout_el = make_layout_element_dict_like(
text="Hello World",
element_type=ElementType.TEXT,
coordinates=coords,
bbox_x1=5.0,
)
image_metadata = {"format": "PNG", "width": 100, "height": 200}
page = make_page([layout_el], image_metadata=image_metadata)
document = SimpleNamespace(pages=[page])
codeflash_output = document_to_element_list(
document,
sortable=False,
include_page_breaks=False,
last_modification_date="2022-01-01",
detection_origin="detected-by-model",
); elements = codeflash_output
el = elements[0]
def test_list_items_are_inferred():
list_text = "1. First item\n2. Second item"
layout_el = make_layout_element_dict_like(
text=list_text,
element_type=ElementType.LIST,
coordinates=None,
bbox_x1=5.0,
)
page = make_page([layout_el], image_metadata={"format": "JPEG", "width": 10, "height": 20})
document = SimpleNamespace(pages=[page])
codeflash_output = document_to_element_list(
document,
infer_list_items=True,
last_modification_date="2023-07-07",
); elements = codeflash_output
for item in elements:
pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."""
# Function under optimization lives in the first multi-context file.
function_to_optimize = FunctionToOptimize(
function_name="document_to_element_list",
file_path="unstructured/partition/pdf.py",
parents=[],
starting_line=None,
ending_line=None,
)
module_path = "unstructured.partition.pdf"
python_version = (3, 11)
# Step 1: Validate testgen code (simulates what happens after LLM response)
validated_code = validate_testgen_code(raw_test_code, python_version)
# Step 2: postprocessing_testgen_pipeline (includes add_missing_imports and replace_definition_with_import)
source_code_being_tested = group_code(source_code_blocks)
processed_module = postprocessing_testgen_pipeline(
parse_module_to_cst(validated_code), [], function_to_optimize, module_path, source_code_being_tested
)
assert processed_module.code == expected

View file

@ -137,6 +137,7 @@ dependencies = [
{ name = "sentry-sdk", extra = ["django"] },
{ name = "stamina" },
{ name = "uvicorn" },
{ name = "wcwidth" },
]
[package.dev-dependencies]
@ -176,6 +177,7 @@ requires-dist = [
{ name = "sentry-sdk", extras = ["django"], specifier = ">=2.35.0" },
{ name = "stamina", specifier = ">=25.1.0" },
{ name = "uvicorn", specifier = ">=0.32.0,<0.33" },
{ name = "wcwidth", specifier = ">=0.2.15" },
]
[package.metadata.requires-dev]
@ -1882,11 +1884,11 @@ wheels = [
[[package]]
name = "wcwidth"
version = "0.2.14"
version = "0.5.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/24/30/6b0809f4510673dc723187aeaf24c7f5459922d01e2f794277a3dfb90345/wcwidth-0.2.14.tar.gz", hash = "sha256:4d478375d31bc5395a3c55c40ccdf3354688364cd61c4f6adacaa9215d0b3605", size = 102293, upload-time = "2025-09-22T16:29:53.023Z" }
sdist = { url = "https://files.pythonhosted.org/packages/64/6e/62daec357285b927e82263a81f3b4c1790215bc77c42530ce4a69d501a43/wcwidth-0.5.0.tar.gz", hash = "sha256:f89c103c949a693bf563377b2153082bf58e309919dfb7f27b04d862a0089333", size = 246585, upload-time = "2026-01-27T01:31:44.942Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl", hash = "sha256:a7bb560c8aee30f9957e5f9895805edd20602f2d7f720186dfd906e82b4982e1", size = 37286, upload-time = "2025-09-22T16:29:51.641Z" },
{ url = "https://files.pythonhosted.org/packages/f2/3e/45583b67c2ff08ad5a582d316fcb2f11d6cf0a50c7707ac09d212d25bc98/wcwidth-0.5.0-py3-none-any.whl", hash = "sha256:1efe1361b83b0ff7877b81ba57c8562c99cf812158b778988ce17ec061095695", size = 93772, upload-time = "2026-01-27T01:31:43.432Z" },
]
[[package]]

View file

@ -327,11 +327,6 @@ export async function suggestPrChanges(
}
}
// Check if the owner is roboflow
if (owner === "roboflow") {
logger.info(`Rejecting request for roboflow repository`, req)
throw unauthorized("Unauthorized for roboflow repositories")
}
// No approval required, proceed with PR suggestion
const result = await triggerSuggestPrChanges(
owner,

View file

@ -168,34 +168,6 @@ describe("Suggest PR Changes", () => {
consoleSpy.mockRestore()
})
it("should return 401 for roboflow repositories", async () => {
mockReq.body.owner = "roboflow"
mockDependencies.userNickname.mockResolvedValue("test-user")
const mockInstallationOctokit = {
rest: {
pulls: {
get: jest.fn() as any,
},
},
}
mockInstallationOctokit.rest.pulls.get.mockResolvedValue({
data: { head: { ref: "feature-branch" } },
})
mockDependencies.getInstallationOctokitByOwner.mockResolvedValue(mockInstallationOctokit)
mockDependencies.isUserCollaborator.mockResolvedValue(true)
mockDependencies.requiresApproval.mockReturnValue(false)
const consoleSpy = jest.spyOn(console, "log").mockImplementation(() => {})
await expect(
suggestPrChanges(mockReq as AuthorizedUserReq, mockRes as Response),
).rejects.toMatchObject({
message: expect.stringContaining("Unauthorized for roboflow repositories"),
})
expect(mockRes.status).not.toHaveBeenCalled()
consoleSpy.mockRestore()
})
})
describe("approval workflow", () => {