mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
Merge branch 'main' into multi-language
This commit is contained in:
commit
764a3f8899
23 changed files with 1032 additions and 185 deletions
15
.github/workflows/codeflash-aiservice.yaml
vendored
15
.github/workflows/codeflash-aiservice.yaml
vendored
|
|
@ -89,12 +89,19 @@ jobs:
|
|||
- name: Install Project Dependencies
|
||||
run: |
|
||||
uv sync --refresh --isolated --active
|
||||
uv pip uninstall codeflash
|
||||
uv pip install pytest-asyncio
|
||||
uv pip install black
|
||||
|
||||
- name: Install Codeflash in separate venv
|
||||
working-directory: .
|
||||
run: |
|
||||
uv venv .codeflash-venv
|
||||
source .codeflash-venv/bin/activate
|
||||
uv pip install pytest-asyncio black
|
||||
uv pip install git+https://github.com/codeflash-ai/codeflash@main
|
||||
|
||||
- name: Run Codeflash to optimize code
|
||||
id: optimize_code
|
||||
working-directory: .
|
||||
run: |
|
||||
uv run --active codeflash --async --verbose
|
||||
source .codeflash-venv/bin/activate
|
||||
cd django/aiservice
|
||||
codeflash --async --verbose
|
||||
|
|
|
|||
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -162,6 +162,7 @@ cython_debug/
|
|||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
.aider*
|
||||
.serena/
|
||||
/js/common/node_modules/
|
||||
/node_modules/
|
||||
*.xml
|
||||
|
|
|
|||
|
|
@ -156,7 +156,7 @@ def make_number_node(value: int) -> cst.Integer | cst.UnaryOperation:
|
|||
def evaluate_expression(node: cst.BaseExpression) -> int | None:
|
||||
"""Evaluate a CST expression node to get its integer value, or None if not evaluable."""
|
||||
if isinstance(node, cst.Integer):
|
||||
return int(node.value)
|
||||
return int(node.value, 0)
|
||||
if isinstance(node, cst.Float):
|
||||
return int(float(node.value))
|
||||
if isinstance(node, cst.BinaryOperation):
|
||||
|
|
|
|||
|
|
@ -70,7 +70,9 @@ def split_markdown_code(markdown: str, language: str = "python") -> dict[str, st
|
|||
matches = pattern.findall(markdown)
|
||||
result: dict[str, str] = {}
|
||||
for file_path, code in matches:
|
||||
result.setdefault(file_path.strip(), code)
|
||||
stripped_path = file_path.strip()
|
||||
if stripped_path not in result:
|
||||
result[stripped_path] = code
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -79,7 +81,18 @@ def extract_code_block_with_context(text: str, language: str = "python") -> tupl
|
|||
|
||||
Returns (before_text, code_content, after_text) or None if not found.
|
||||
Handles optional filepath syntax (e.g., ```python:filepath.py).
|
||||
|
||||
Prefers code blocks with filepath annotations over plain code blocks,
|
||||
as these are more likely to contain the actual code in multi-block responses
|
||||
where the LLM may include example or explanation blocks before the real code.
|
||||
"""
|
||||
# First try to find a code block with filepath (e.g., ```python:file.py)
|
||||
# These are more likely to contain the actual optimized code
|
||||
pattern_with_path = rf"(.*?)```{language}:[^\n]+(?:\n|\\n)(.*?)```(.*)"
|
||||
if match := re.match(pattern_with_path, text, re.DOTALL | re.MULTILINE):
|
||||
return match.group(1).strip(), match.group(2), match.group(3).strip()
|
||||
|
||||
# Fall back to any code block (with or without filepath)
|
||||
pattern = rf"(.*?)```{language}(?::[^\n]*)?(?:\n|\\n)(.*?)```(.*)"
|
||||
if match := re.match(pattern, text, re.DOTALL | re.MULTILINE):
|
||||
return match.group(1).strip(), match.group(2), match.group(3).strip()
|
||||
|
|
|
|||
|
|
@ -235,6 +235,14 @@ async def jit_rewrite(
|
|||
if len(jit_rewrite_response_items) == 0:
|
||||
await asyncio.to_thread(ph, request.user, "aiservice-jit-rewrite-no-optimizations-found")
|
||||
debug_log_sensitive_data(f"No optimizations found for source:\n{data.source_code}")
|
||||
logging.error(
|
||||
"Could not generate any optimizations (jit_rewrite). trace_id=%s, repo=%s/%s, n_candidates=%d, source_len=%d",
|
||||
data.trace_id,
|
||||
data.repo_owner,
|
||||
data.repo_name,
|
||||
data.n_candidates,
|
||||
len(data.source_code) if data.source_code else 0,
|
||||
)
|
||||
return 500, OptimizeErrorResponseSchema(error="Could not generate any optimizations. Please try again.")
|
||||
await asyncio.to_thread(
|
||||
ph,
|
||||
|
|
|
|||
|
|
@ -19,12 +19,7 @@ from optimizer.context_utils.constants import (
|
|||
LINE_PROF_CONTEXT_PROMPT,
|
||||
MARKDOWN_CONTEXT_PROMPT,
|
||||
)
|
||||
from optimizer.context_utils.context_helpers import (
|
||||
group_code,
|
||||
is_markdown_structure_changed,
|
||||
is_multi_context,
|
||||
split_markdown_code,
|
||||
)
|
||||
from optimizer.context_utils.context_helpers import group_code, is_multi_context, split_markdown_code
|
||||
from optimizer.diff_patches_utils.diff import SEARCH_AND_REPLACE_FORMAT_PROMPT, V4A_DIFF_FORMAT_PROMPT, DiffMethod
|
||||
from optimizer.diff_patches_utils.seach_and_replace import SearchAndReplaceDiff
|
||||
from optimizer.diff_patches_utils.v4a_diff import V4ADiff
|
||||
|
|
@ -206,6 +201,8 @@ class MultiOptimizerContext(BaseOptimizerContext):
|
|||
super().__init__(base_system_prompt, base_user_prompt, source_code)
|
||||
self.original_file_to_code = split_markdown_code(source_code)
|
||||
|
||||
self._original_files_set = set(self.original_file_to_code.keys())
|
||||
|
||||
def get_system_prompt(self, python_version_str: str) -> str:
|
||||
# critical instructions should be bullet points, short and direct, further instructions should be a full paragraph which is added after the main system prompt
|
||||
critical_instructions = ""
|
||||
|
|
@ -274,8 +271,8 @@ class MultiOptimizerContext(BaseOptimizerContext):
|
|||
if markdown_code == "":
|
||||
return None
|
||||
|
||||
original_code_files = self.original_file_to_code.items()
|
||||
extracted_code_files = split_markdown_code(markdown_code).items()
|
||||
original_code_files = dict(self.original_file_to_code.items())
|
||||
extracted_code_files = dict(split_markdown_code(markdown_code).items())
|
||||
|
||||
op_id = str(uuid.uuid4())
|
||||
new_post_processed_code: dict[str, str] = {}
|
||||
|
|
@ -285,26 +282,34 @@ class MultiOptimizerContext(BaseOptimizerContext):
|
|||
code=extracted.code, explanation=extracted.explanation
|
||||
)
|
||||
|
||||
if len(original_code_files) != len(extracted_code_files):
|
||||
# Check if any extracted files exist in the original context
|
||||
# LLMs may only return files they actually modified
|
||||
valid_extracted_files = {f: c for f, c in extracted_code_files.items() if f in original_code_files}
|
||||
if not valid_extracted_files:
|
||||
debug_log_sensitive_data(
|
||||
f"original files {original_code_files} are not the same as extracted files {extracted_code_files}"
|
||||
f"No matching files found. Original: {list(original_code_files.keys())}, "
|
||||
f"Extracted: {list(extracted_code_files.keys())}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Log when LLM returns fewer files (common with large contexts)
|
||||
if len(valid_extracted_files) < len(original_code_files):
|
||||
debug_log_sensitive_data(
|
||||
f"LLM returned {len(valid_extracted_files)} of {len(original_code_files)} files. "
|
||||
f"Processing returned files: {list(valid_extracted_files.keys())}"
|
||||
)
|
||||
|
||||
# optimization is ignored when **all** files in the context are rejected by postprocessing pipeline
|
||||
optimization_ignored = True
|
||||
|
||||
# post process each file
|
||||
for (original_file, original_code), (new_file, new_code) in zip(
|
||||
original_code_files, extracted_code_files, strict=True
|
||||
):
|
||||
if original_file != new_file:
|
||||
# shouldn't happen but just in case
|
||||
continue
|
||||
# post process each file that was returned by the LLM
|
||||
for original_file, new_code in valid_extracted_files.items():
|
||||
original_code = original_code_files[original_file]
|
||||
try:
|
||||
new_cst_module = parse_module_to_cst(new_code)
|
||||
|
||||
code_and_explanation = CodeExplanationAndID(
|
||||
cst_module=new_cst_module, id=f"{new_file}:post-processing", explanation=extracted.explanation
|
||||
cst_module=new_cst_module, id=f"{original_file}:post-processing", explanation=extracted.explanation
|
||||
)
|
||||
postprocessed_list: list[CodeExplanationAndID] = optimizations_postprocessing_pipeline(
|
||||
original_code, [code_and_explanation]
|
||||
|
|
@ -344,12 +349,18 @@ class MultiOptimizerContext(BaseOptimizerContext):
|
|||
def is_valid_code(self) -> bool:
|
||||
if self.extracted_code_and_expl is None:
|
||||
return False
|
||||
changed = is_markdown_structure_changed(self.source_code, self.extracted_code_and_expl.code)
|
||||
if changed:
|
||||
extracted_files = set(split_markdown_code(self.extracted_code_and_expl.code).keys())
|
||||
|
||||
# All extracted files must be in the original set (no new files allowed)
|
||||
invalid_files = extracted_files - self._original_files_set
|
||||
if invalid_files:
|
||||
debug_log_sensitive_data(
|
||||
f"code markdown have been changed:=======\n{self.source_code}\\=======\n{self.extracted_code_and_expl.code}=======\n"
|
||||
f"LLM returned files not in original context: {invalid_files}. "
|
||||
f"Original: {self._original_files_set}, Extracted: {extracted_files}"
|
||||
)
|
||||
return super().is_valid_code() and (not changed)
|
||||
return False
|
||||
|
||||
return super().is_valid_code()
|
||||
|
||||
def validate_and_parse_source_code(self, code: str, feature_version: tuple) -> None:
|
||||
final_code = code or self.source_code
|
||||
|
|
|
|||
|
|
@ -377,7 +377,14 @@ async def optimize_python(
|
|||
if len(optimization_response_items) == 0:
|
||||
await asyncio.to_thread(ph, request.user, "aiservice-optimize-no-optimizations-found")
|
||||
debug_log_sensitive_data(f"No optimizations found for source:\n{data.source_code}")
|
||||
logging.error(f"Could not generate any optimizations. trace_id={data.trace_id}")
|
||||
logging.error(
|
||||
"Could not generate any optimizations. trace_id=%s, repo=%s/%s, n_candidates=%d, source_len=%d",
|
||||
data.trace_id,
|
||||
data.repo_owner,
|
||||
data.repo_name,
|
||||
data.n_candidates,
|
||||
len(data.source_code) if data.source_code else 0,
|
||||
)
|
||||
return 500, OptimizeErrorResponseSchema(error="Could not generate any optimizations. Please try again.")
|
||||
await asyncio.to_thread(
|
||||
ph,
|
||||
|
|
|
|||
|
|
@ -660,7 +660,13 @@ async def optimize(request, data: OptimizeSchemaLP) -> tuple[int, OptimizeRespon
|
|||
if len(optimization_response_items) == 0:
|
||||
await asyncio.to_thread(ph, request.user, "aiservice-optimize-no-optimizations-found")
|
||||
debug_log_sensitive_data(f"No optimizations found for source:\n{data.source_code}")
|
||||
logging.error(f"Could not generate any optimizations (line profiler). trace_id={data.trace_id}")
|
||||
logging.error(
|
||||
"Could not generate any optimizations (line profiler). trace_id=%s, n_candidates=%d, source_len=%d, has_line_profiler=%s",
|
||||
data.trace_id,
|
||||
data.n_candidates,
|
||||
len(data.source_code) if data.source_code else 0,
|
||||
bool(data.line_profiler_results),
|
||||
)
|
||||
return 500, OptimizeErrorResponseSchema(error="Error generating optimizations. Internal server error.")
|
||||
await asyncio.to_thread(
|
||||
ph,
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ dependencies = [
|
|||
"stamina>=25.1.0",
|
||||
"jedi>=0.19.2",
|
||||
"anthropic>=0.75.0",
|
||||
"wcwidth>=0.2.15",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from libcst import CSTTransformer, ImportAlias, ImportFrom, MetadataWrapper, Name, SimpleStatementLine, parse_module
|
||||
from libcst import CSTTransformer, ImportAlias, ImportFrom, MetadataWrapper, Name, SimpleStatementLine
|
||||
|
||||
from aiservice.common.cst_utils import DepthTrackingMixin, build_module_path, parse_module_to_cst
|
||||
|
||||
|
|
@ -57,52 +57,56 @@ def replace_first_function_definition(module_str: str, function_name: str, stmt:
|
|||
return modified_module_cst.code
|
||||
|
||||
|
||||
def replace_definition_with_import(source_code: str, function: FunctionToOptimize, module_path: str) -> str: # noqa: D417
|
||||
"""Replace a function or class definition with an import statement in the given CST tree.
|
||||
class _ImportInserter(DepthTrackingMixin, CSTTransformer):
|
||||
"""Transformer that replaces a function or class definition with an import statement."""
|
||||
|
||||
def __init__(self, target_name: str, import_stmt: SimpleStatementLine) -> None:
|
||||
DepthTrackingMixin.__init__(self)
|
||||
CSTTransformer.__init__(self)
|
||||
self.target_name = target_name
|
||||
self.import_stmt = import_stmt
|
||||
self.replaced_count = 0
|
||||
|
||||
def visit_ClassDef(self, node: ClassDef) -> None: # noqa: ARG002
|
||||
self._visit_class()
|
||||
|
||||
def leave_ClassDef(
|
||||
self, original_node: ClassDef, updated_node: ClassDef
|
||||
) -> BaseStatement | FlattenSentinel[BaseStatement] | RemovalSentinel:
|
||||
if original_node.name.value == self.target_name and self._class_depth == 1 and self.replaced_count == 0:
|
||||
self.replaced_count += 1
|
||||
self._leave_class()
|
||||
return self.import_stmt
|
||||
self._leave_class()
|
||||
return updated_node
|
||||
|
||||
def leave_FunctionDef(
|
||||
self, original_node: FunctionDef, updated_node: FunctionDef
|
||||
) -> BaseStatement | FlattenSentinel[BaseStatement] | RemovalSentinel:
|
||||
if original_node.name.value == self.target_name and self._class_depth == 0 and self.replaced_count == 0:
|
||||
self.replaced_count += 1
|
||||
return self.import_stmt
|
||||
return updated_node
|
||||
|
||||
|
||||
def replace_definition_with_import(module: Module, function: FunctionToOptimize, module_path: str) -> Module:
|
||||
"""Replace a function or class definition with an import statement.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
- tree (CSTModule): The CST tree representing the module.
|
||||
- function (FunctionToOptimize): The function to replace with an import.
|
||||
- module_path (str): Full module path of the function or class being replaced
|
||||
module : Module
|
||||
The CST module representing the test code.
|
||||
function : FunctionToOptimize
|
||||
The function to replace with an import.
|
||||
module_path : str
|
||||
Full module path of the function or class being replaced.
|
||||
|
||||
Returns
|
||||
-------
|
||||
- CSTModule: The modified CST tree.
|
||||
Module
|
||||
The modified CST module.
|
||||
|
||||
"""
|
||||
|
||||
class ImportInserter(DepthTrackingMixin, CSTTransformer):
|
||||
def __init__(self, target_name: str, import_stmt: SimpleStatementLine) -> None:
|
||||
DepthTrackingMixin.__init__(self)
|
||||
CSTTransformer.__init__(self)
|
||||
self.target_name = target_name
|
||||
self.import_stmt = import_stmt
|
||||
self.replaced_count = 0
|
||||
|
||||
def visit_ClassDef(self, node: ClassDef) -> None: # noqa: ARG002
|
||||
self._visit_class()
|
||||
|
||||
def leave_ClassDef(
|
||||
self, original_node: ClassDef, updated_node: ClassDef
|
||||
) -> BaseStatement | FlattenSentinel[BaseStatement] | RemovalSentinel:
|
||||
if original_node.name.value == self.target_name and self._class_depth == 1 and self.replaced_count == 0:
|
||||
self.replaced_count += 1
|
||||
self._leave_class()
|
||||
return self.import_stmt
|
||||
self._leave_class()
|
||||
return updated_node
|
||||
|
||||
def leave_FunctionDef(
|
||||
self, original_node: FunctionDef, updated_node: FunctionDef
|
||||
) -> BaseStatement | FlattenSentinel[BaseStatement] | RemovalSentinel:
|
||||
if original_node.name.value == self.target_name and self._class_depth == 0 and self.replaced_count == 0:
|
||||
self.replaced_count += 1
|
||||
return self.import_stmt
|
||||
return updated_node
|
||||
|
||||
tree = parse_module(source_code)
|
||||
|
||||
if not function.parents:
|
||||
new_import = SimpleStatementLine(
|
||||
body=[
|
||||
|
|
@ -111,8 +115,8 @@ def replace_definition_with_import(source_code: str, function: FunctionToOptimiz
|
|||
)
|
||||
]
|
||||
)
|
||||
transformer = ImportInserter(function.function_name, new_import)
|
||||
return tree.visit(transformer).code
|
||||
transformer = _ImportInserter(function.function_name, new_import)
|
||||
return module.visit(transformer)
|
||||
if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
|
||||
new_import = SimpleStatementLine(
|
||||
body=[
|
||||
|
|
@ -122,6 +126,6 @@ def replace_definition_with_import(source_code: str, function: FunctionToOptimiz
|
|||
)
|
||||
]
|
||||
)
|
||||
transformer = ImportInserter(function.parents[0].name, new_import)
|
||||
return tree.visit(transformer).code
|
||||
return tree.code
|
||||
transformer = _ImportInserter(function.parents[0].name, new_import)
|
||||
return module.visit(transformer)
|
||||
return module
|
||||
|
|
|
|||
|
|
@ -555,7 +555,15 @@ def add_missing_imports_from_multi_context_source(
|
|||
return AddImportsVisitor(context).transform_module(module)
|
||||
|
||||
except Exception as e: # noqa: BLE001
|
||||
logging.warning("add_missing_imports_from_multi_context_source failed: %s", e)
|
||||
source_files = list(source_code_blocks.keys()) if source_code_blocks else []
|
||||
logging.warning(
|
||||
"add_missing_imports_from_multi_context_source failed: %s, "
|
||||
"default_module_path=%s, source_files=%s, module_is_none=%s",
|
||||
e,
|
||||
default_module_path,
|
||||
source_files,
|
||||
module is None,
|
||||
)
|
||||
sentry_sdk.capture_exception(e)
|
||||
return module
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,13 @@
|
|||
from typing import TYPE_CHECKING, Any
|
||||
from __future__ import annotations
|
||||
|
||||
from libcst import Module
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from aiservice.models.functions_to_optimize import FunctionToOptimize
|
||||
from optimizer.context_utils.context_helpers import is_multi_context, split_markdown_code
|
||||
from testgen.instrumentation.edit_generated_test import replace_definition_with_import
|
||||
from testgen.postprocessing.add_missing_imports import (
|
||||
add_missing_imports_from_multi_context_source,
|
||||
add_missing_imports_from_source,
|
||||
)
|
||||
from testgen.postprocessing.range_modifier import modify_large_loops
|
||||
from testgen.postprocessing.remove_unused_definitions import remove_unused_definitions_from_pytest_file
|
||||
from testgen.postprocessing.removeassert_transformer import remove_asserts_from_test
|
||||
|
|
@ -11,27 +16,56 @@ from testgen.postprocessing.topdef_terminator import delete_top_def_nodes
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from libcst import Module
|
||||
|
||||
from aiservice.models.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
|
||||
def postprocessing_testgen_pipeline(
|
||||
module: Module, helper_function_names: list[str], function_to_optimize: FunctionToOptimize, module_path: str
|
||||
module: Module,
|
||||
helper_function_names: list[str],
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
module_path: str,
|
||||
source_code_being_tested: str,
|
||||
) -> Module:
|
||||
"""Apply postprocessing functions in a pipeline, a list of (function, kwargs)."""
|
||||
"""Full postprocessing pipeline for generated test code.
|
||||
|
||||
Applies all CST transformations in sequence:
|
||||
1. Clean up definitions (remove helper functions, unused definitions)
|
||||
2. Modify constructs (large loops, tensors)
|
||||
3. Remove asserts
|
||||
4. Add missing imports
|
||||
5. Replace function definition with import
|
||||
"""
|
||||
add_imports_func, add_imports_kwargs = (
|
||||
(
|
||||
add_missing_imports_from_multi_context_source,
|
||||
{"source_code_blocks": split_markdown_code(source_code_being_tested), "default_module_path": module_path},
|
||||
)
|
||||
if is_multi_context(source_code_being_tested)
|
||||
else (add_missing_imports_from_source, {"source_code": source_code_being_tested, "module_path": module_path})
|
||||
)
|
||||
|
||||
pipeline: list[tuple[Callable[..., Module], dict[str, Any]]] = [
|
||||
(delete_top_def_nodes, {"deletable_list": helper_function_names}),
|
||||
(remove_unused_definitions_from_pytest_file, {}),
|
||||
(modify_large_loops, {}),
|
||||
(modify_tensors, {}),
|
||||
(
|
||||
remove_asserts_from_test, # should be the last function in the pipeline since it transforms the code from result = extract_input_variables(nodes) to codeflash_output = extract_input_variables(nodes); result = codeflash_output and raises an error, probably because of the semicolon
|
||||
remove_asserts_from_test,
|
||||
{
|
||||
"function_to_optimize": function_to_optimize,
|
||||
"helper_function_names": helper_function_names,
|
||||
"module_path": module_path,
|
||||
},
|
||||
),
|
||||
(add_imports_func, add_imports_kwargs),
|
||||
(replace_definition_with_import, {"function": function_to_optimize, "module_path": module_path}),
|
||||
]
|
||||
|
||||
for func, kwargs in pipeline:
|
||||
module = func(module, **kwargs)
|
||||
|
||||
return module
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from libcst import (
|
|||
SimpleStatementSuite,
|
||||
UnaryOperation,
|
||||
parse_expression,
|
||||
parse_module,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -44,7 +45,11 @@ class StatementHandler:
|
|||
self, updated_node: SimpleStatementLine | SimpleStatementSuite
|
||||
) -> SimpleStatementLine | SimpleStatementSuite | RemovalSentinel | FlattenSentinel[BaseSmallStatement] | None:
|
||||
if not updated_node.body:
|
||||
return updated_node
|
||||
# Empty body is malformed - handle it properly
|
||||
if isinstance(updated_node, SimpleStatementSuite):
|
||||
return updated_node.with_changes(body=[Pass()])
|
||||
# SimpleStatementLine with empty body should be removed
|
||||
return RemoveFromParent()
|
||||
|
||||
if len(updated_node.body) == 1:
|
||||
return self._handle_single_statement(updated_node)
|
||||
|
|
@ -90,8 +95,12 @@ class StatementHandler:
|
|||
if new_body and hasattr(new_body[-1], "semicolon"):
|
||||
new_body[-1] = new_body[-1].with_changes(semicolon=None)
|
||||
|
||||
if isinstance(updated_node, SimpleStatementSuite) and not new_body:
|
||||
new_body = [Pass()]
|
||||
if not new_body:
|
||||
if isinstance(updated_node, SimpleStatementSuite):
|
||||
new_body = [Pass()]
|
||||
else:
|
||||
# SimpleStatementLine with empty body is malformed - remove it entirely
|
||||
return RemoveFromParent()
|
||||
|
||||
return updated_node.with_changes(body=new_body)
|
||||
|
||||
|
|
@ -373,8 +382,27 @@ class RemoveAssertTransformer(CSTTransformer):
|
|||
) -> Any: # noqa: ANN401
|
||||
return self.statement_handler.handle_statement(updated_node)
|
||||
|
||||
def leave_FunctionDef(self, original_node: FunctionDef, updated_node: FunctionDef) -> RemovalSentinel | FunctionDef: # noqa: ARG002
|
||||
return updated_node if updated_node.body.body else RemoveFromParent()
|
||||
def leave_FunctionDef(self, original_node: FunctionDef, updated_node: FunctionDef) -> RemovalSentinel | FunctionDef:
|
||||
body_block = updated_node.body
|
||||
|
||||
# For IndentedBlock (regular multi-line functions)
|
||||
if hasattr(body_block, "body") and isinstance(body_block.body, (list, tuple)):
|
||||
body = body_block.body
|
||||
if not body:
|
||||
return RemoveFromParent()
|
||||
|
||||
# Only remove test functions that only contain pass (useless test stubs)
|
||||
# Non-test functions with pass are legitimate and should be kept
|
||||
if original_node.name.value.startswith("test_") and len(body) == 1:
|
||||
first = body[0]
|
||||
# Check if it's a SimpleStatementLine with just Pass()
|
||||
if isinstance(first, SimpleStatementLine) and len(first.body) == 1 and isinstance(first.body[0], Pass):
|
||||
return RemoveFromParent()
|
||||
# Check if it's a SimpleStatementSuite (one-liner) with just Pass()
|
||||
if isinstance(first, Pass):
|
||||
return RemoveFromParent()
|
||||
|
||||
return updated_node
|
||||
|
||||
def is_target_function_node(self, node: Call) -> bool:
|
||||
return (isinstance(node.func, Name) and node.func.value == self.only_function_name) or (
|
||||
|
|
@ -463,4 +491,7 @@ def remove_asserts_from_test(
|
|||
]
|
||||
)
|
||||
new_body = [add_import, *list(modified_tree.body)]
|
||||
return modified_tree.with_changes(body=tuple(new_body))
|
||||
result = modified_tree.with_changes(body=tuple(new_body))
|
||||
# Re-parse to ensure valid CST structure. The transformer may create nodes with
|
||||
# semicolon=None instead of MaybeSentinel, which causes errors during subsequent visits.
|
||||
return parse_module(result.code)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,8 @@ from aiservice.models.functions_to_optimize import FunctionToOptimize
|
|||
from authapp.auth import AuthenticatedRequest
|
||||
from log_features.log_event import update_optimization_cost
|
||||
from log_features.log_features import log_features
|
||||
from optimizer.context_utils.context_helpers import is_multi_context, split_markdown_code
|
||||
from libcst import parse_module
|
||||
|
||||
from testgen.instrumentation.edit_generated_test import replace_definition_with_import
|
||||
from testgen.instrumentation.instrument_new_tests import instrument_test_source
|
||||
from testgen.models import (
|
||||
|
|
@ -34,10 +35,6 @@ from testgen.models import (
|
|||
TestGenSchema,
|
||||
TestingMode,
|
||||
)
|
||||
from testgen.postprocessing.add_missing_imports import (
|
||||
add_missing_imports_from_multi_context_source,
|
||||
add_missing_imports_from_source,
|
||||
)
|
||||
from testgen.postprocessing.code_validator import CodeValidationError, has_test_functions, validate_testgen_code
|
||||
from testgen.postprocessing.postprocess_pipeline import postprocessing_testgen_pipeline
|
||||
from testgen.testgen_context import BaseTestGenContext, TestGenContextData
|
||||
|
|
@ -205,7 +202,7 @@ def parse_and_validate_llm_output(
|
|||
|
||||
if function_to_optimize is not None and module_path is not None:
|
||||
try:
|
||||
code = replace_definition_with_import(code, function_to_optimize, module_path)
|
||||
code = replace_definition_with_import(parse_module(code), function_to_optimize, module_path).code
|
||||
except Exception: # noqa: BLE001
|
||||
# If replacement fails (e.g., parsing error), continue with original code
|
||||
logging.warning("replace_definition_with_import failed, continuing with original code")
|
||||
|
|
@ -330,24 +327,10 @@ async def generate_regression_tests_from_function(
|
|||
data.helper_function_names or [],
|
||||
data.function_to_optimize,
|
||||
data.module_path,
|
||||
data.source_code_being_tested,
|
||||
)
|
||||
|
||||
# Add missing imports for symbols defined in source module but not imported in test.
|
||||
# This handles cases where the LLM redefines some classes locally but forgets others.
|
||||
if is_multi_context(data.source_code_being_tested):
|
||||
# Multi-context: use per-file module mapping for correct imports
|
||||
source_code_blocks = split_markdown_code(data.source_code_being_tested)
|
||||
processed_cst = add_missing_imports_from_multi_context_source(
|
||||
processed_cst, source_code_blocks, data.module_path
|
||||
)
|
||||
else:
|
||||
processed_cst = add_missing_imports_from_source(
|
||||
processed_cst, data.source_code_being_tested, data.module_path
|
||||
)
|
||||
|
||||
generated_test_source = replace_definition_with_import(
|
||||
processed_cst.code, data.function_to_optimize, data.module_path
|
||||
)
|
||||
generated_test_source = processed_cst.code
|
||||
|
||||
instrumented_behavior_tests, instrumented_perf_tests = instrument_tests(
|
||||
generated_test_source, data, python_version
|
||||
|
|
|
|||
|
|
@ -50,6 +50,27 @@ def test_evaluate_expression_integer() -> None:
|
|||
assert evaluate_expression(node) == 42
|
||||
|
||||
|
||||
def test_evaluate_expression_hexadecimal() -> None:
|
||||
"""Test that hexadecimal integer literals (0x prefix) are correctly evaluated."""
|
||||
assert evaluate_expression(cst.Integer("0x00")) == 0
|
||||
assert evaluate_expression(cst.Integer("0xFF")) == 255
|
||||
assert evaluate_expression(cst.Integer("0xDEADBEEF")) == 0xDEADBEEF
|
||||
|
||||
|
||||
def test_evaluate_expression_octal() -> None:
|
||||
"""Test that octal integer literals (0o prefix) are correctly evaluated."""
|
||||
assert evaluate_expression(cst.Integer("0o0")) == 0
|
||||
assert evaluate_expression(cst.Integer("0o77")) == 63
|
||||
assert evaluate_expression(cst.Integer("0o755")) == 493
|
||||
|
||||
|
||||
def test_evaluate_expression_binary() -> None:
|
||||
"""Test that binary integer literals (0b prefix) are correctly evaluated."""
|
||||
assert evaluate_expression(cst.Integer("0b0")) == 0
|
||||
assert evaluate_expression(cst.Integer("0b1010")) == 10
|
||||
assert evaluate_expression(cst.Integer("0b11111111")) == 255
|
||||
|
||||
|
||||
def test_evaluate_expression_float() -> None:
|
||||
node = cst.Float("3.14")
|
||||
assert evaluate_expression(node) == 3
|
||||
|
|
|
|||
|
|
@ -248,3 +248,86 @@ def test_wrap_and_extract_roundtrip() -> None:
|
|||
wrapped = wrap_code_in_markdown(original_code)
|
||||
extracted = extract_code_block(wrapped)
|
||||
assert extracted == original_code
|
||||
|
||||
|
||||
def test_extract_code_block_with_context_prefers_filepath_block() -> None:
|
||||
"""Test that filepath-annotated blocks are preferred over plain blocks.
|
||||
|
||||
This handles LLM responses that include explanation code blocks before
|
||||
the actual optimized code with a filepath annotation. This test is based on
|
||||
an actual failing JIT rewrite response (trace 59be9279-dd59-4fad-aaf4-6bc541e44cb8)
|
||||
where the LLM returned an explanation with example code before the real optimized code.
|
||||
"""
|
||||
# Actual response pattern from gpt-5-mini JIT rewrite that caused extraction failure
|
||||
text = """**Optimization Explanation:**
|
||||
I applied targeted micro-optimizations that reduce attribute lookups.
|
||||
|
||||
```python
|
||||
# Example of the original slow pattern:
|
||||
def slow_convert(sequence):
|
||||
result = []
|
||||
for item in sequence:
|
||||
result.append(tuple(item))
|
||||
return tuple(result)
|
||||
```
|
||||
|
||||
The optimized code uses list comprehension for better performance:
|
||||
|
||||
```python:unstructured/documents/elements.py
|
||||
from __future__ import annotations
|
||||
import dataclasses as dc
|
||||
|
||||
@dc.dataclass
|
||||
class DataSourceMetadata:
|
||||
url: str | None = None
|
||||
```
|
||||
|
||||
These yield measurable runtime improvements."""
|
||||
result = extract_code_block_with_context(text)
|
||||
assert result is not None
|
||||
before, code, after = result
|
||||
# Should extract the filepath-annotated block, not the example block
|
||||
assert (
|
||||
before
|
||||
== """**Optimization Explanation:**
|
||||
I applied targeted micro-optimizations that reduce attribute lookups.
|
||||
|
||||
```python
|
||||
# Example of the original slow pattern:
|
||||
def slow_convert(sequence):
|
||||
result = []
|
||||
for item in sequence:
|
||||
result.append(tuple(item))
|
||||
return tuple(result)
|
||||
```
|
||||
|
||||
The optimized code uses list comprehension for better performance:"""
|
||||
)
|
||||
assert (
|
||||
code
|
||||
== """from __future__ import annotations
|
||||
import dataclasses as dc
|
||||
|
||||
@dc.dataclass
|
||||
class DataSourceMetadata:
|
||||
url: str | None = None
|
||||
"""
|
||||
)
|
||||
assert after == "These yield measurable runtime improvements."
|
||||
|
||||
|
||||
def test_extract_code_block_with_context_falls_back_to_plain_block() -> None:
|
||||
"""Test fallback to plain block when no filepath-annotated block exists."""
|
||||
text = """Here's the code:
|
||||
|
||||
```python
|
||||
def simple_function():
|
||||
return 42
|
||||
```
|
||||
"""
|
||||
result = extract_code_block_with_context(text)
|
||||
assert result is not None
|
||||
before, code, after = result
|
||||
assert before == "Here's the code:"
|
||||
assert code == "def simple_function():\n return 42\n"
|
||||
assert after == ""
|
||||
|
|
|
|||
|
|
@ -703,3 +703,126 @@ def gcd(a: int, b: int) -> int:
|
|||
assert ctx.extracted_code_and_expl.code == expected_code
|
||||
assert ctx.extracted_code_and_expl.explanation == "I replaced the recursive implementation with an iterative loop."
|
||||
assert ctx.is_valid_code()
|
||||
|
||||
|
||||
def test_multi_optimizer_partial_file_return() -> None:
|
||||
"""Test that LLM returning only modified files (not all original files) is handled correctly.
|
||||
|
||||
This tests the fix for trace_id=5e4306f4-a909-48b1-aeb4-932c1b451129 where the optimizer
|
||||
received 7 files but LLMs only returned 1-2 files they actually modified. The strict
|
||||
validation previously required exact file count matching, causing all candidates to fail.
|
||||
"""
|
||||
# Original code has 3 files
|
||||
original_code = """```python:app/utils.py
|
||||
def helper():
|
||||
return "helper"
|
||||
```
|
||||
```python:app/main.py
|
||||
from app.utils import helper
|
||||
|
||||
def process(data):
|
||||
for i in range(len(data)):
|
||||
data[i] = data[i] * 2
|
||||
return data
|
||||
```
|
||||
```python:app/constants.py
|
||||
MAX_SIZE = 100
|
||||
DEFAULT_VALUE = 0
|
||||
```
|
||||
"""
|
||||
# LLM only returns 1 file (the one it actually modified)
|
||||
llm_response = """I optimized the loop using list comprehension for better performance.
|
||||
|
||||
```python:app/main.py
|
||||
from app.utils import helper
|
||||
|
||||
def process(data):
|
||||
return [x * 2 for x in data]
|
||||
```
|
||||
"""
|
||||
ctx = create_optimizer_context(original_code)
|
||||
assert isinstance(ctx, MultiOptimizerContext)
|
||||
ctx.extract_code_and_explanation_from_llm_res(llm_response)
|
||||
|
||||
# Verify extraction worked correctly
|
||||
assert ctx.extracted_code_and_expl is not None
|
||||
assert (
|
||||
ctx.extracted_code_and_expl.code
|
||||
== """```python:app/main.py
|
||||
from app.utils import helper
|
||||
|
||||
def process(data):
|
||||
return [x * 2 for x in data]
|
||||
```"""
|
||||
)
|
||||
assert (
|
||||
ctx.extracted_code_and_expl.explanation
|
||||
== "I optimized the loop using list comprehension for better performance."
|
||||
)
|
||||
|
||||
# Should succeed - partial return is valid
|
||||
assert ctx.is_valid_code() is True
|
||||
|
||||
candidate = ctx.parse_and_generate_candidate_schema()
|
||||
assert candidate is not None
|
||||
assert candidate.optimization_id is not None
|
||||
|
||||
# The returned code should only contain the modified file (with post-processing applied)
|
||||
assert (
|
||||
candidate.source_code
|
||||
== """```python:app/main.py
|
||||
from app.utils import helper
|
||||
|
||||
|
||||
def process(data):
|
||||
return [x * 2 for x in data]
|
||||
```"""
|
||||
)
|
||||
assert candidate.explanation == "I optimized the loop using list comprehension for better performance."
|
||||
|
||||
|
||||
def test_multi_optimizer_rejects_new_files_not_in_original() -> None:
|
||||
"""Test that LLM returning files NOT in the original context is rejected.
|
||||
|
||||
While partial returns (subset of original files) are allowed, returning
|
||||
new files that weren't in the original context should be rejected.
|
||||
"""
|
||||
# Need at least 2 files to get a MultiOptimizerContext
|
||||
original_code = """```python:app/main.py
|
||||
def process(data):
|
||||
return data * 2
|
||||
```
|
||||
```python:app/utils.py
|
||||
def helper():
|
||||
return "helper"
|
||||
```
|
||||
"""
|
||||
# LLM returns a file that wasn't in the original context
|
||||
llm_response = """I added a new helper module.
|
||||
|
||||
```python:app/new_helpers.py
|
||||
def double(x):
|
||||
return x * 2
|
||||
```
|
||||
"""
|
||||
ctx = create_optimizer_context(original_code)
|
||||
assert isinstance(ctx, MultiOptimizerContext)
|
||||
ctx.extract_code_and_explanation_from_llm_res(llm_response)
|
||||
|
||||
# Verify extraction worked (the code was extracted, just not valid)
|
||||
assert ctx.extracted_code_and_expl is not None
|
||||
assert (
|
||||
ctx.extracted_code_and_expl.code
|
||||
== """```python:app/new_helpers.py
|
||||
def double(x):
|
||||
return x * 2
|
||||
```"""
|
||||
)
|
||||
assert ctx.extracted_code_and_expl.explanation == "I added a new helper module."
|
||||
|
||||
# Should fail - new file not in original context
|
||||
assert ctx.is_valid_code() is False
|
||||
|
||||
# parse_and_generate_candidate_schema should return None for invalid files
|
||||
candidate = ctx.parse_and_generate_candidate_schema()
|
||||
assert candidate is None
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import ast
|
||||
|
||||
from libcst import parse_module
|
||||
|
||||
from aiservice.models.functions_to_optimize import FunctionParent, FunctionToOptimize
|
||||
from testgen.instrumentation.edit_generated_test import replace_definition_with_import
|
||||
|
||||
|
|
@ -24,7 +26,7 @@ def test_something():
|
|||
)
|
||||
module_path = "app.api.routes.agent"
|
||||
|
||||
result = replace_definition_with_import(source_code, function, module_path)
|
||||
result = replace_definition_with_import(parse_module(source_code), function, module_path).code
|
||||
|
||||
expected = """
|
||||
import pytest
|
||||
|
|
@ -68,7 +70,7 @@ def test_something():
|
|||
)
|
||||
module_path = "app.api.routes.agent"
|
||||
|
||||
result = replace_definition_with_import(source_code, function, module_path)
|
||||
result = replace_definition_with_import(parse_module(source_code), function, module_path).code
|
||||
|
||||
expected = '''
|
||||
import pytest
|
||||
|
|
@ -118,7 +120,7 @@ async def clone_agent_template(request, template_id):
|
|||
)
|
||||
module_path = "app.api.routes.agent"
|
||||
|
||||
result = replace_definition_with_import(source_code, function, module_path)
|
||||
result = replace_definition_with_import(parse_module(source_code), function, module_path).code
|
||||
|
||||
expected = '''
|
||||
import pytest
|
||||
|
|
@ -163,7 +165,7 @@ class OuterClass:
|
|||
)
|
||||
module_path = "app.api.routes.agent"
|
||||
|
||||
result = replace_definition_with_import(source_code, function, module_path)
|
||||
result = replace_definition_with_import(parse_module(source_code), function, module_path).code
|
||||
|
||||
expected = '''
|
||||
import pytest
|
||||
|
|
@ -207,7 +209,7 @@ def test_something():
|
|||
)
|
||||
module_path = "app.api.routes.agent"
|
||||
|
||||
result = replace_definition_with_import(source_code, function, module_path)
|
||||
result = replace_definition_with_import(parse_module(source_code), function, module_path).code
|
||||
|
||||
expected = """
|
||||
import pytest
|
||||
|
|
@ -241,7 +243,7 @@ class SomeClass:
|
|||
)
|
||||
module_path = "app.api.routes.agent"
|
||||
|
||||
result = replace_definition_with_import(source_code, function, module_path)
|
||||
result = replace_definition_with_import(parse_module(source_code), function, module_path).code
|
||||
|
||||
expected = """
|
||||
import pytest
|
||||
|
|
@ -286,7 +288,7 @@ class MockLogger:
|
|||
)
|
||||
module_path = "app.api.routes.agent"
|
||||
|
||||
result = replace_definition_with_import(source_code, function, module_path)
|
||||
result = replace_definition_with_import(parse_module(source_code), function, module_path).code
|
||||
|
||||
expected = """
|
||||
import pytest
|
||||
|
|
|
|||
|
|
@ -1,7 +1,12 @@
|
|||
from libcst import Pass, RemoveFromParent, SimpleStatementLine, SimpleStatementSuite
|
||||
from libcst import parse_module as parse_module_to_cst
|
||||
|
||||
from aiservice.models.functions_to_optimize import FunctionParent, FunctionToOptimize
|
||||
from testgen.postprocessing.removeassert_transformer import remove_asserts_from_test
|
||||
from testgen.postprocessing.removeassert_transformer import (
|
||||
RemoveAssertTransformer,
|
||||
StatementHandler,
|
||||
remove_asserts_from_test,
|
||||
)
|
||||
|
||||
|
||||
def test_remove_asserts() -> None:
|
||||
|
|
@ -796,3 +801,90 @@ class TestPigLatin(unittest.TestCase):
|
|||
module_path="code_to_optimize.bubble_sort",
|
||||
).code
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_remove_asserts_empty_simple_statement_line_removed() -> None:
|
||||
"""Test that a SimpleStatementLine with all statements removed is properly removed.
|
||||
|
||||
This tests the fix for the bug where an empty SimpleStatementLine would create
|
||||
a malformed CST that could cause 'NoneType' object has no attribute 'visit' errors.
|
||||
When all asserts on a line are irrelevant (not related to the function being optimized),
|
||||
the entire line should be removed rather than leaving an empty body.
|
||||
"""
|
||||
original_test = """def test_multiple_irrelevant_asserts():
|
||||
x = 5
|
||||
assert 1 == 1; assert True; assert 2 > 1
|
||||
y = 10
|
||||
"""
|
||||
|
||||
expected = """from some_file import some_function
|
||||
def test_multiple_irrelevant_asserts():
|
||||
x = 5
|
||||
y = 10
|
||||
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
|
||||
"""
|
||||
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="some_function", file_path="some_file.py", parents=[], starting_line=None, ending_line=None
|
||||
)
|
||||
result = remove_asserts_from_test(
|
||||
module=parse_module_to_cst(original_test),
|
||||
function_to_optimize=function_to_optimize,
|
||||
helper_function_names=[],
|
||||
module_path="some_file",
|
||||
).code
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_remove_asserts_one_line_function_all_removed() -> None:
|
||||
"""Test that one-line functions with only irrelevant asserts are removed entirely.
|
||||
|
||||
When a one-line function body (like `def test(): assert x`) has all its statements
|
||||
removed (because they're irrelevant asserts), the entire function should be removed
|
||||
rather than leaving a useless `def test(): pass` stub.
|
||||
"""
|
||||
original_test = """def test_irrelevant(): assert 1 == 1"""
|
||||
|
||||
expected = """from some_file import some_function
|
||||
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."""
|
||||
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="some_function", file_path="some_file.py", parents=[], starting_line=None, ending_line=None
|
||||
)
|
||||
result = remove_asserts_from_test(
|
||||
module=parse_module_to_cst(original_test),
|
||||
function_to_optimize=function_to_optimize,
|
||||
helper_function_names=[],
|
||||
module_path="some_file",
|
||||
).code
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_handle_statement_with_empty_body() -> None:
|
||||
"""Test that handle_statement properly handles nodes with empty body.
|
||||
|
||||
This tests the fix for the bug where a SimpleStatementLine with an empty body
|
||||
(from prior transformations or malformed input) would be returned unchanged,
|
||||
causing CST tree corruption and 'NoneType' object has no attribute 'visit' errors
|
||||
in subsequent transforms.
|
||||
|
||||
- SimpleStatementLine with empty body should return RemoveFromParent()
|
||||
- SimpleStatementSuite with empty body should return a node with Pass()
|
||||
"""
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="some_function", file_path="some_file.py", parents=[], starting_line=None, ending_line=None
|
||||
)
|
||||
transformer = RemoveAssertTransformer(function_to_optimize, [])
|
||||
handler = StatementHandler(transformer)
|
||||
|
||||
# Test SimpleStatementLine with empty body returns RemoveFromParent()
|
||||
empty_line = SimpleStatementLine(body=[])
|
||||
result = handler.handle_statement(empty_line)
|
||||
assert result == RemoveFromParent()
|
||||
|
||||
# Test SimpleStatementSuite with empty body returns node with Pass()
|
||||
empty_suite = SimpleStatementSuite(body=[])
|
||||
result = handler.handle_statement(empty_suite)
|
||||
assert isinstance(result, SimpleStatementSuite)
|
||||
assert len(result.body) == 1
|
||||
assert isinstance(result.body[0], Pass)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from libcst import parse_module as parse_module_to_cst
|
||||
|
||||
from aiservice.models.functions_to_optimize import FunctionToOptimize
|
||||
from optimizer.context_utils.context_helpers import group_code
|
||||
from testgen.postprocessing.code_validator import validate_testgen_code
|
||||
from testgen.postprocessing.postprocess_pipeline import postprocessing_testgen_pipeline
|
||||
|
||||
|
||||
|
|
@ -67,6 +69,13 @@ def test_few_shot_variables():
|
|||
'''
|
||||
module = parse_module_to_cst(code)
|
||||
|
||||
# Source code that defines the function being tested
|
||||
source_code_being_tested = '''
|
||||
def extract_input_variables(nodes):
|
||||
"""Extracts input variables from the template."""
|
||||
pass
|
||||
'''
|
||||
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="extract_input_variables",
|
||||
file_path="testgen/postprocessing/tests/test_validate_pipeline.py",
|
||||
|
|
@ -76,44 +85,19 @@ def test_few_shot_variables():
|
|||
)
|
||||
module_path = "test_validate_pipeline"
|
||||
result = postprocessing_testgen_pipeline(
|
||||
module, ["function_to_remove"], function_to_optimize, module_path
|
||||
) # function_to_remove is in the deletable_list, execute delete_top_def_nodes in the pipeline
|
||||
module, ["function_to_remove"], function_to_optimize, module_path, source_code_being_tested
|
||||
)
|
||||
|
||||
expected = r'''from test_validate_pipeline import extract_input_variables
|
||||
# After consolidation, the function definition is removed by add_missing_imports_from_source
|
||||
# (which detects local redefinitions of public symbols) and replaced with an import.
|
||||
# The import is added at the top by AddImportsVisitor.
|
||||
expected = r"""from test_validate_pipeline import extract_input_variables
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
# imports
|
||||
import pytest # used for our unit tests
|
||||
|
||||
def extract_input_variables(nodes: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Extracts input variables from the template and adds them to the input_variables field."""
|
||||
prompt_pattern = re.compile(r"\{(.*?)\}")
|
||||
|
||||
for node in nodes:
|
||||
try:
|
||||
data_node = node["data"]["node"]
|
||||
template_info = data_node["template"]
|
||||
template_type = template_info["_type"]
|
||||
|
||||
if "input_variables" in template_info:
|
||||
if template_type == "prompt":
|
||||
value = template_info["template"]["value"]
|
||||
variables = prompt_pattern.findall(value)
|
||||
elif template_type == "few_shot":
|
||||
prefix = template_info["prefix"]["value"]
|
||||
suffix = template_info["suffix"]["value"]
|
||||
variables = prompt_pattern.findall(prefix + suffix)
|
||||
else:
|
||||
variables = []
|
||||
|
||||
template_info["input_variables"]["value"] = variables
|
||||
except (KeyError, TypeError):
|
||||
# Exception suppressed as in the original code
|
||||
pass
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
# unit tests
|
||||
|
||||
|
|
@ -129,6 +113,465 @@ def test_few_shot_variables():
|
|||
nodes = [{"data": {"node": {"template": {"_type": "few_shot", "prefix": {"value": "{var1}"}, "suffix": {"value": "{var2}"}, "input_variables": {}}}}}]
|
||||
codeflash_output = extract_input_variables(nodes); result = codeflash_output
|
||||
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
|
||||
'''
|
||||
"""
|
||||
|
||||
assert result.code == expected
|
||||
|
||||
|
||||
def test_postprocessing_pipeline_with_multi_context_imports() -> None:
|
||||
"""Test the full testgen pipeline including all processing stages.
|
||||
|
||||
This tests the complete flow that generated test code goes through:
|
||||
1. validate_testgen_code - validates and cleans raw LLM output
|
||||
2. postprocessing_testgen_pipeline - transforms the code (includes all stages)
|
||||
|
||||
This also tests the fix for the bug where CST tree corruption could cause
|
||||
'NoneType' object has no attribute 'visit' errors.
|
||||
"""
|
||||
# Raw test code (as if from LLM output)
|
||||
raw_test_code = """
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
from typing import Any, Optional
|
||||
import numpy as np
|
||||
|
||||
from unstructured.documents.elements import Text, ListItem, PageBreak
|
||||
from unstructured.partition.pdf import document_to_element_list
|
||||
|
||||
|
||||
def test_document_to_element_list_empty_document():
|
||||
\"\"\"Test that an empty document returns an empty list of elements.\"\"\"
|
||||
mock_document = Mock()
|
||||
mock_document.pages = []
|
||||
result = document_to_element_list(mock_document)
|
||||
assert result == []
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
def test_document_to_element_list_single_element():
|
||||
\"\"\"Test basic conversion with one element.\"\"\"
|
||||
mock_layout_element = Mock()
|
||||
mock_layout_element.text = "Hello World"
|
||||
mock_layout_element.type = "Text"
|
||||
mock_layout_element.bbox = Mock()
|
||||
mock_layout_element.bbox.x1 = np.nan
|
||||
mock_layout_element.parent = None
|
||||
|
||||
mock_page = Mock()
|
||||
mock_page.elements_array = Mock()
|
||||
mock_page.elements_array.element_class_id_map = {}
|
||||
mock_page.elements_array.element_class_ids = np.array([])
|
||||
mock_page.elements_array.iter_elements = Mock(return_value=[mock_layout_element])
|
||||
|
||||
mock_document = Mock()
|
||||
mock_document.pages = [mock_page]
|
||||
|
||||
with patch('unstructured.partition.pdf.normalize_layout_element') as mock_normalize:
|
||||
text_element = Text(text="Hello World")
|
||||
mock_normalize.return_value = text_element
|
||||
result = document_to_element_list(mock_document)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], Text)
|
||||
"""
|
||||
|
||||
# Source code being tested (what the LLM saw)
|
||||
source_code = """
|
||||
def document_to_element_list(document, sortable=False, include_page_breaks=False,
|
||||
last_modification_date=None, detection_origin=None):
|
||||
\"\"\"Convert a document to a list of elements.\"\"\"
|
||||
elements = []
|
||||
for page in document.pages:
|
||||
for layout_element in page.elements_array.iter_elements():
|
||||
element = normalize_layout_element(layout_element)
|
||||
elements.append(element)
|
||||
return elements
|
||||
"""
|
||||
|
||||
# Source code blocks simulating multi-context
|
||||
source_code_blocks = {
|
||||
"unstructured/partition/pdf.py": source_code,
|
||||
"unstructured/documents/elements.py": """
|
||||
class Element:
|
||||
def __init__(self, text=""):
|
||||
self.text = text
|
||||
|
||||
class Text(Element):
|
||||
pass
|
||||
|
||||
class ListItem(Element):
|
||||
pass
|
||||
|
||||
class PageBreak(Element):
|
||||
pass
|
||||
""",
|
||||
}
|
||||
|
||||
expected = """from unstructured.partition.pdf import document_to_element_list
|
||||
from typing import Any, Optional
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from unstructured.documents.elements import ListItem, PageBreak, Text
|
||||
|
||||
|
||||
def test_document_to_element_list_empty_document():
|
||||
\"\"\"Test that an empty document returns an empty list of elements.\"\"\"
|
||||
mock_document = Mock()
|
||||
mock_document.pages = []
|
||||
codeflash_output = document_to_element_list(mock_document); result = codeflash_output
|
||||
|
||||
|
||||
def test_document_to_element_list_single_element():
|
||||
\"\"\"Test basic conversion with one element.\"\"\"
|
||||
mock_layout_element = Mock()
|
||||
mock_layout_element.text = "Hello World"
|
||||
mock_layout_element.type = "Text"
|
||||
mock_layout_element.bbox = Mock()
|
||||
mock_layout_element.bbox.x1 = np.nan
|
||||
mock_layout_element.parent = None
|
||||
|
||||
mock_page = Mock()
|
||||
mock_page.elements_array = Mock()
|
||||
mock_page.elements_array.element_class_id_map = {}
|
||||
mock_page.elements_array.element_class_ids = np.array([])
|
||||
mock_page.elements_array.iter_elements = Mock(return_value=[mock_layout_element])
|
||||
|
||||
mock_document = Mock()
|
||||
mock_document.pages = [mock_page]
|
||||
|
||||
with patch('unstructured.partition.pdf.normalize_layout_element') as mock_normalize:
|
||||
text_element = Text(text="Hello World")
|
||||
mock_normalize.return_value = text_element
|
||||
codeflash_output = document_to_element_list(mock_document); result = codeflash_output
|
||||
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."""
|
||||
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="document_to_element_list",
|
||||
file_path="unstructured/partition/pdf.py",
|
||||
parents=[],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
module_path = "unstructured.partition.pdf"
|
||||
python_version = (3, 11)
|
||||
|
||||
# Step 1: Validate testgen code (simulates what happens after LLM response)
|
||||
validated_code = validate_testgen_code(raw_test_code, python_version)
|
||||
|
||||
# Step 2: postprocessing_testgen_pipeline (includes add_missing_imports and replace_definition_with_import)
|
||||
source_code_being_tested = group_code(source_code_blocks)
|
||||
processed_module = postprocessing_testgen_pipeline(
|
||||
parse_module_to_cst(validated_code), [], function_to_optimize, module_path, source_code_being_tested
|
||||
)
|
||||
|
||||
assert processed_module.code == expected
|
||||
|
||||
|
||||
def test_postprocessing_pipeline_with_unstructured_test_code() -> None:
|
||||
"""Test the full testgen pipeline with complex test code from unstructured codebase.
|
||||
|
||||
This tests the complete flow with helper functions and multiple test cases:
|
||||
1. validate_testgen_code - validates and cleans raw LLM output
|
||||
2. postprocessing_testgen_pipeline - transforms the code (includes all stages)
|
||||
|
||||
Uses a simplified version of actual generated test code that triggered
|
||||
'NoneType' object has no attribute 'visit' error.
|
||||
"""
|
||||
# Raw test code with helper functions (as if from LLM output)
|
||||
raw_test_code = """
|
||||
import math
|
||||
from types import SimpleNamespace
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from unstructured.partition.pdf import document_to_element_list
|
||||
from unstructured.documents.elements import Text, ListItem, PageBreak, Title, Element
|
||||
from unstructured.documents.elements import ElementType
|
||||
|
||||
|
||||
def make_layout_element_dict_like(
|
||||
*,
|
||||
text: str = "",
|
||||
element_type: str | None = None,
|
||||
coordinates: tuple | None = None,
|
||||
prob: float | None = None,
|
||||
bbox_x1: float = 0.0,
|
||||
bbox_x2: float = 1.0,
|
||||
bbox_y1: float = 0.0,
|
||||
bbox_y2: float = 1.0,
|
||||
parent: object | None = None,
|
||||
):
|
||||
bbox = SimpleNamespace(x1=bbox_x1, x2=bbox_x2, y1=bbox_y1, y2=bbox_y2)
|
||||
def to_dict():
|
||||
out = {"text": text}
|
||||
if element_type is not None:
|
||||
out["type"] = element_type
|
||||
if coordinates is not None:
|
||||
out["coordinates"] = coordinates
|
||||
if prob is not None:
|
||||
out["prob"] = prob
|
||||
return out
|
||||
le = SimpleNamespace(
|
||||
bbox=bbox,
|
||||
to_dict=to_dict,
|
||||
parent=parent,
|
||||
)
|
||||
return le
|
||||
|
||||
|
||||
def make_page(elements, *, image_metadata=None, image=None):
|
||||
class ElementsArray:
|
||||
def __init__(self, elements):
|
||||
self._elements = elements
|
||||
self.element_class_id_map = {}
|
||||
self.element_class_ids = np.array([], dtype=int)
|
||||
|
||||
def iter_elements(self):
|
||||
for el in self._elements:
|
||||
yield el
|
||||
|
||||
elements_array = ElementsArray(elements)
|
||||
page = SimpleNamespace(
|
||||
elements_array=elements_array,
|
||||
image_metadata=image_metadata,
|
||||
image=image,
|
||||
)
|
||||
return page
|
||||
|
||||
|
||||
def test_single_text_element_basic():
|
||||
coords = ((0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0))
|
||||
layout_el = make_layout_element_dict_like(
|
||||
text="Hello World",
|
||||
element_type=ElementType.TEXT,
|
||||
coordinates=coords,
|
||||
bbox_x1=5.0,
|
||||
)
|
||||
image_metadata = {"format": "PNG", "width": 100, "height": 200}
|
||||
page = make_page([layout_el], image_metadata=image_metadata)
|
||||
document = SimpleNamespace(pages=[page])
|
||||
|
||||
elements = document_to_element_list(
|
||||
document,
|
||||
sortable=False,
|
||||
include_page_breaks=False,
|
||||
last_modification_date="2022-01-01",
|
||||
detection_origin="detected-by-model",
|
||||
)
|
||||
|
||||
assert len(elements) == 1
|
||||
el = elements[0]
|
||||
assert isinstance(el, Text)
|
||||
assert str(el) == "Hello World"
|
||||
assert el.metadata.page_number == 1
|
||||
|
||||
|
||||
def test_list_items_are_inferred():
|
||||
list_text = "1. First item\\n2. Second item"
|
||||
layout_el = make_layout_element_dict_like(
|
||||
text=list_text,
|
||||
element_type=ElementType.LIST,
|
||||
coordinates=None,
|
||||
bbox_x1=5.0,
|
||||
)
|
||||
page = make_page([layout_el], image_metadata={"format": "JPEG", "width": 10, "height": 20})
|
||||
document = SimpleNamespace(pages=[page])
|
||||
|
||||
elements = document_to_element_list(
|
||||
document,
|
||||
infer_list_items=True,
|
||||
last_modification_date="2023-07-07",
|
||||
)
|
||||
|
||||
assert len(elements) == 2
|
||||
for item in elements:
|
||||
assert isinstance(item, ListItem)
|
||||
"""
|
||||
|
||||
# Source code being tested (multi-context format)
|
||||
source_code = """
|
||||
from unstructured.documents.elements import Element, Text, ListItem, PageBreak
|
||||
from unstructured.partition.common.common import normalize_layout_element
|
||||
|
||||
def document_to_element_list(document, sortable=False, include_page_breaks=False,
|
||||
last_modification_date=None, detection_origin=None,
|
||||
starting_page_number=1, infer_list_items=True, **kwargs):
|
||||
elements = []
|
||||
for page_idx, page in enumerate(document.pages):
|
||||
page_number = starting_page_number + page_idx
|
||||
for layout_element in page.elements_array.iter_elements():
|
||||
element = normalize_layout_element(layout_element)
|
||||
if isinstance(element, list):
|
||||
elements.extend(element)
|
||||
else:
|
||||
elements.append(element)
|
||||
if include_page_breaks and page_idx < len(document.pages):
|
||||
elements.append(PageBreak())
|
||||
return elements
|
||||
"""
|
||||
|
||||
# Source code blocks from unstructured codebase
|
||||
source_code_blocks = {
|
||||
"unstructured/partition/pdf.py": source_code,
|
||||
"unstructured/documents/elements.py": """
|
||||
from typing import Any, Optional
|
||||
|
||||
class ElementMetadata:
|
||||
def __init__(self):
|
||||
self.page_number = None
|
||||
self.parent_id = None
|
||||
self.coordinates = None
|
||||
self.last_modified = None
|
||||
|
||||
class Element:
|
||||
def __init__(self, text=""):
|
||||
self.text = text
|
||||
self.metadata = ElementMetadata()
|
||||
self.id = id(self)
|
||||
|
||||
def __str__(self):
|
||||
return self.text
|
||||
|
||||
class Text(Element):
|
||||
pass
|
||||
|
||||
class ListItem(Element):
|
||||
pass
|
||||
|
||||
class PageBreak(Element):
|
||||
pass
|
||||
|
||||
class Title(Element):
|
||||
pass
|
||||
|
||||
class ElementType:
|
||||
TEXT = "Text"
|
||||
LIST = "List"
|
||||
TITLE = "Title"
|
||||
""",
|
||||
}
|
||||
|
||||
expected = r"""from unstructured.partition.pdf import document_to_element_list
|
||||
import math
|
||||
from types import SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from unstructured.documents.elements import (Element, ElementType, ListItem,
|
||||
PageBreak, Text, Title)
|
||||
|
||||
|
||||
def make_layout_element_dict_like(
|
||||
*,
|
||||
text: str = "",
|
||||
element_type: str | None = None,
|
||||
coordinates: tuple | None = None,
|
||||
prob: float | None = None,
|
||||
bbox_x1: float = 0.0,
|
||||
bbox_x2: float = 1.0,
|
||||
bbox_y1: float = 0.0,
|
||||
bbox_y2: float = 1.0,
|
||||
parent: object | None = None,
|
||||
):
|
||||
bbox = SimpleNamespace(x1=bbox_x1, x2=bbox_x2, y1=bbox_y1, y2=bbox_y2)
|
||||
def to_dict():
|
||||
out = {"text": text}
|
||||
if element_type is not None:
|
||||
out["type"] = element_type
|
||||
if coordinates is not None:
|
||||
out["coordinates"] = coordinates
|
||||
if prob is not None:
|
||||
out["prob"] = prob
|
||||
return out
|
||||
le = SimpleNamespace(
|
||||
bbox=bbox,
|
||||
to_dict=to_dict,
|
||||
parent=parent,
|
||||
)
|
||||
return le
|
||||
|
||||
|
||||
def make_page(elements, *, image_metadata=None, image=None):
|
||||
class ElementsArray:
|
||||
def __init__(self, elements):
|
||||
self._elements = elements
|
||||
self.element_class_id_map = {}
|
||||
self.element_class_ids = np.array([], dtype=int)
|
||||
|
||||
def iter_elements(self):
|
||||
for el in self._elements:
|
||||
yield el
|
||||
|
||||
elements_array = ElementsArray(elements)
|
||||
page = SimpleNamespace(
|
||||
elements_array=elements_array,
|
||||
image_metadata=image_metadata,
|
||||
image=image,
|
||||
)
|
||||
return page
|
||||
|
||||
|
||||
def test_single_text_element_basic():
|
||||
coords = ((0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0))
|
||||
layout_el = make_layout_element_dict_like(
|
||||
text="Hello World",
|
||||
element_type=ElementType.TEXT,
|
||||
coordinates=coords,
|
||||
bbox_x1=5.0,
|
||||
)
|
||||
image_metadata = {"format": "PNG", "width": 100, "height": 200}
|
||||
page = make_page([layout_el], image_metadata=image_metadata)
|
||||
document = SimpleNamespace(pages=[page])
|
||||
|
||||
codeflash_output = document_to_element_list(
|
||||
document,
|
||||
sortable=False,
|
||||
include_page_breaks=False,
|
||||
last_modification_date="2022-01-01",
|
||||
detection_origin="detected-by-model",
|
||||
); elements = codeflash_output
|
||||
el = elements[0]
|
||||
|
||||
|
||||
def test_list_items_are_inferred():
|
||||
list_text = "1. First item\n2. Second item"
|
||||
layout_el = make_layout_element_dict_like(
|
||||
text=list_text,
|
||||
element_type=ElementType.LIST,
|
||||
coordinates=None,
|
||||
bbox_x1=5.0,
|
||||
)
|
||||
page = make_page([layout_el], image_metadata={"format": "JPEG", "width": 10, "height": 20})
|
||||
document = SimpleNamespace(pages=[page])
|
||||
|
||||
codeflash_output = document_to_element_list(
|
||||
document,
|
||||
infer_list_items=True,
|
||||
last_modification_date="2023-07-07",
|
||||
); elements = codeflash_output
|
||||
for item in elements:
|
||||
pass
|
||||
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."""
|
||||
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="document_to_element_list",
|
||||
file_path="unstructured/partition/pdf.py",
|
||||
parents=[],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
module_path = "unstructured.partition.pdf"
|
||||
python_version = (3, 11)
|
||||
|
||||
# Step 1: Validate testgen code (simulates what happens after LLM response)
|
||||
validated_code = validate_testgen_code(raw_test_code, python_version)
|
||||
|
||||
# Step 2: postprocessing_testgen_pipeline (includes add_missing_imports and replace_definition_with_import)
|
||||
source_code_being_tested = group_code(source_code_blocks)
|
||||
processed_module = postprocessing_testgen_pipeline(
|
||||
parse_module_to_cst(validated_code), [], function_to_optimize, module_path, source_code_being_tested
|
||||
)
|
||||
|
||||
assert processed_module.code == expected
|
||||
|
|
|
|||
|
|
@ -137,6 +137,7 @@ dependencies = [
|
|||
{ name = "sentry-sdk", extra = ["django"] },
|
||||
{ name = "stamina" },
|
||||
{ name = "uvicorn" },
|
||||
{ name = "wcwidth" },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
|
|
@ -176,6 +177,7 @@ requires-dist = [
|
|||
{ name = "sentry-sdk", extras = ["django"], specifier = ">=2.35.0" },
|
||||
{ name = "stamina", specifier = ">=25.1.0" },
|
||||
{ name = "uvicorn", specifier = ">=0.32.0,<0.33" },
|
||||
{ name = "wcwidth", specifier = ">=0.2.15" },
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
|
|
@ -1882,11 +1884,11 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "wcwidth"
|
||||
version = "0.2.14"
|
||||
version = "0.5.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/24/30/6b0809f4510673dc723187aeaf24c7f5459922d01e2f794277a3dfb90345/wcwidth-0.2.14.tar.gz", hash = "sha256:4d478375d31bc5395a3c55c40ccdf3354688364cd61c4f6adacaa9215d0b3605", size = 102293, upload-time = "2025-09-22T16:29:53.023Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/64/6e/62daec357285b927e82263a81f3b4c1790215bc77c42530ce4a69d501a43/wcwidth-0.5.0.tar.gz", hash = "sha256:f89c103c949a693bf563377b2153082bf58e309919dfb7f27b04d862a0089333", size = 246585, upload-time = "2026-01-27T01:31:44.942Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl", hash = "sha256:a7bb560c8aee30f9957e5f9895805edd20602f2d7f720186dfd906e82b4982e1", size = 37286, upload-time = "2025-09-22T16:29:51.641Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f2/3e/45583b67c2ff08ad5a582d316fcb2f11d6cf0a50c7707ac09d212d25bc98/wcwidth-0.5.0-py3-none-any.whl", hash = "sha256:1efe1361b83b0ff7877b81ba57c8562c99cf812158b778988ce17ec061095695", size = 93772, upload-time = "2026-01-27T01:31:43.432Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -327,11 +327,6 @@ export async function suggestPrChanges(
|
|||
}
|
||||
}
|
||||
|
||||
// Check if the owner is roboflow
|
||||
if (owner === "roboflow") {
|
||||
logger.info(`Rejecting request for roboflow repository`, req)
|
||||
throw unauthorized("Unauthorized for roboflow repositories")
|
||||
}
|
||||
// No approval required, proceed with PR suggestion
|
||||
const result = await triggerSuggestPrChanges(
|
||||
owner,
|
||||
|
|
|
|||
|
|
@ -168,34 +168,6 @@ describe("Suggest PR Changes", () => {
|
|||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it("should return 401 for roboflow repositories", async () => {
|
||||
mockReq.body.owner = "roboflow"
|
||||
mockDependencies.userNickname.mockResolvedValue("test-user")
|
||||
const mockInstallationOctokit = {
|
||||
rest: {
|
||||
pulls: {
|
||||
get: jest.fn() as any,
|
||||
},
|
||||
},
|
||||
}
|
||||
mockInstallationOctokit.rest.pulls.get.mockResolvedValue({
|
||||
data: { head: { ref: "feature-branch" } },
|
||||
})
|
||||
mockDependencies.getInstallationOctokitByOwner.mockResolvedValue(mockInstallationOctokit)
|
||||
mockDependencies.isUserCollaborator.mockResolvedValue(true)
|
||||
mockDependencies.requiresApproval.mockReturnValue(false)
|
||||
|
||||
const consoleSpy = jest.spyOn(console, "log").mockImplementation(() => {})
|
||||
|
||||
await expect(
|
||||
suggestPrChanges(mockReq as AuthorizedUserReq, mockRes as Response),
|
||||
).rejects.toMatchObject({
|
||||
message: expect.stringContaining("Unauthorized for roboflow repositories"),
|
||||
})
|
||||
expect(mockRes.status).not.toHaveBeenCalled()
|
||||
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
})
|
||||
|
||||
describe("approval workflow", () => {
|
||||
|
|
|
|||
Loading…
Reference in a new issue