Merge pull request #2465 from codeflash-ai/testgen-review-repair

feat: per-function test review + repair endpoints
Kevin Turcios 2026-03-05 22:37:21 +00:00 committed by GitHub
commit 8d1dfd9bdb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 556 additions and 30 deletions

View file

@ -30,6 +30,8 @@ from core.languages.python.optimizer.refinement import refinement_api
from core.log_features.log_features import features_api
from core.shared.optimizer_router import optimize_api
from core.shared.ranker.ranker import ranker_api
from core.shared.testgen_review.repair import testgen_repair_api
from core.shared.testgen_review.review import testgen_review_api
from core.shared.testgen_router import testgen_api
from core.shared.workflow_gen.workflow_gen import workflow_gen_api
@ -37,6 +39,8 @@ urlpatterns = [
path("ai/optimize", optimize_api.urls),
path("ai/optimize-line-profiler", optimize_line_profiler_api.urls),
path("ai/testgen", testgen_api.urls),
path("ai/testgen_review", testgen_review_api.urls),
path("ai/testgen_repair", testgen_repair_api.urls),
path("ai/log_features", features_api.urls),
path("ai/refinement", refinement_api.urls),
path("ai/explain", explanations_api.urls),

View file

@ -5,7 +5,7 @@ import ast
import asyncio
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import libcst as cst
import sentry_sdk
@ -71,7 +71,7 @@ def build_prompt(
is_async: bool,
is_numerical_code: bool | None = None,
model_type: str = "openai",
) -> tuple[list[dict[str, str]], str, str]:
) -> tuple[list[ChatCompletionMessageParam], str, str]:
system_template = "generate_async_system.md.j2" if is_async else "generate_system.md.j2"
system_prompt = _jinja_env.get_template(system_template).render(
@ -89,7 +89,10 @@ def build_prompt(
posthog_event_suffix = "async-" if is_async else ""
error_context = "async " if is_async else ""
execute_messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
execute_messages: list[ChatCompletionMessageParam] = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
return execute_messages, posthog_event_suffix, error_context
@ -163,7 +166,7 @@ async def generate_and_validate_test_code(
is_async: bool = False,
test_index: int | None = None,
) -> tuple[str, str]:
obs_context: dict | None = (
obs_context: dict[str, Any] | None = (
{
"call_sequence": call_sequence,
"test_index": test_index,
@ -237,7 +240,7 @@ async def generate_regression_tests_from_function(
call_sequence: int | None = None,
test_index: int | None = None,
model_type: str = "openai",
) -> tuple[str, str | None, str | None, str]:
) -> tuple[str, str, str, str]:
execute_messages, posthog_event_suffix, error_context = build_prompt(
qualified_name=qualified_name,
source_code=source_code,
@ -251,7 +254,7 @@ async def generate_regression_tests_from_function(
cost_tracker = CostTracker()
try:
validated_code, raw_llm_content = await generate_and_validate_test_code(
validated_code, _raw_llm_content = await generate_and_validate_test_code(
messages=execute_messages,
model=execute_model,
source_code=source_code,
@ -270,7 +273,7 @@ async def generate_regression_tests_from_function(
test_index=test_index,
)
processed_cst = postprocessing_testgen_pipeline(
processed_cst, display_cst = postprocessing_testgen_pipeline(
parse_module_to_cst(validated_code),
data.helper_function_names or [],
data.function_to_optimize,
@ -279,6 +282,7 @@ async def generate_regression_tests_from_function(
)
generated_test_source = processed_cst.code
raw_display_source = display_cst.code
instrumented_behavior_tests, instrumented_perf_tests = instrument_tests(
generated_test_source, data, python_version
@ -317,7 +321,7 @@ async def generate_regression_tests_from_function(
"validation_error": "No test functions found after postprocessing",
},
)
return generated_test_source, instrumented_behavior_tests, instrumented_perf_tests, raw_llm_content # noqa: TRY300
return generated_test_source, instrumented_behavior_tests, instrumented_perf_tests, raw_display_source # noqa: TRY300
except CodeValidationError as e:
msg = f"Failed to generate valid {error_context}test code after {cost_tracker.calls} tries. trace_id={trace_id}"
logging.exception(msg)
@ -380,7 +384,7 @@ async def testgen_python(
generated_test_source,
instrumented_behavior_tests,
instrumented_perf_tests,
raw_llm_content,
raw_display_source,
) = await generate_regression_tests_from_function(
source_code=data.source_code_being_tested,
qualified_name=data.function_to_optimize.qualified_name,
@ -389,7 +393,7 @@ async def testgen_python(
python_version=python_version,
data=data,
unit_test_package=data.test_framework,
is_async=data.function_to_optimize.is_async or data.is_async,
is_async=data.function_to_optimize.is_async or data.is_async or False,
trace_id=data.trace_id,
call_sequence=data.call_sequence,
execute_model=execute_model,
@ -403,7 +407,7 @@ async def testgen_python(
await log_features(
trace_id=data.trace_id,
user_id=request.user,
generated_tests=[raw_llm_content],
generated_tests=[raw_display_source],
instrumented_generated_tests=[instrumented_behavior_tests],
instrumented_perf_tests=[instrumented_perf_tests],
test_index=test_index,
@ -419,6 +423,7 @@ async def testgen_python(
generated_tests=generated_test_source,
instrumented_behavior_tests=instrumented_behavior_tests,
instrumented_perf_tests=instrumented_perf_tests,
raw_generated_tests=raw_display_source,
)
except TestGenerationFailedError as e:

View file

@ -31,16 +31,19 @@ def postprocessing_testgen_pipeline(
function_to_optimize: FunctionToOptimize,
module_path: str,
source_code_being_tested: str,
) -> Module:
) -> tuple[Module, Module]:
"""Full postprocessing pipeline for generated test code.
Applies all CST transformations in sequence:
1. Clean up definitions (remove helper functions, unused definitions)
2. Modify constructs (large loops, tensors)
3. Remove asserts
4. Add missing imports
5. Replace function definition with import
3. Add missing imports, replace function definition with import
4. Remove asserts (for instrumentation only)
Returns (processed_module_with_asserts_removed, processed_module_with_asserts_kept).
The second module is the "display" version fully cleaned but with original asserts intact.
"""
source_cst: Module | dict[str, Module]
if is_multi_context(source_code_being_tested):
source_cst = {
path: cst.parse_module(code) for path, code in split_markdown_code(source_code_being_tested).items()
@ -48,24 +51,27 @@ def postprocessing_testgen_pipeline(
else:
source_cst = cst.parse_module(source_code_being_tested)
pipeline: list[tuple[Callable[..., Module], dict[str, Any]]] = [
cleanup_pipeline: list[tuple[Callable[..., Module], dict[str, Any]]] = [
(delete_top_def_nodes, {"deletable_list": helper_function_names}),
(remove_unused_definitions_from_pytest_file, {}),
(modify_large_loops, {}),
(modify_tensors, {}),
(
remove_asserts_from_test,
{
"function_to_optimize": function_to_optimize,
"helper_function_names": helper_function_names,
"module_path": module_path,
},
),
(add_missing_imports, {"source_cst": source_cst, "module_path": module_path}),
(replace_definition_with_import, {"function": function_to_optimize, "module_path": module_path}),
]
for func, kwargs in pipeline:
for func, kwargs in cleanup_pipeline:
module = func(module, **kwargs)
return module
# Capture the display version (with asserts) before removing them
display_module = module
# Remove asserts for the instrumentation version
module = remove_asserts_from_test(
module,
function_to_optimize=function_to_optimize,
helper_function_names=helper_function_names,
module_path=module_path,
)
return module, display_module
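
For readers unfamiliar with libcst, a self-contained toy sketch of why capturing `display_module` before the final transform is enough: libcst trees are immutable, so `visit()` returns a new module rather than mutating the one you kept a reference to. The `_NeutralizeAsserts` transformer below is a stand-in for illustration, not the real `remove_asserts_from_test`:

```python
import libcst as cst

class _NeutralizeAsserts(cst.CSTTransformer):
    """Toy stand-in for remove_asserts_from_test: swap each assert for a pass statement."""

    def leave_Assert(self, original_node: cst.Assert, updated_node: cst.Assert) -> cst.BaseSmallStatement:
        return cst.Pass()

module = cst.parse_module("def test_add():\n    result = add(1, 2)\n    assert result == 3\n")
display_module = module                      # reference to the fully cleaned version, asserts intact
module = module.visit(_NeutralizeAsserts())  # visit() builds a new tree; display_module is unchanged
print("assert" in display_module.code)       # True  — display copy keeps the original asserts
print("assert" in module.code)               # False — instrumentation copy has them neutralized
```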

View file

@ -44,6 +44,7 @@ class TestGenResponseSchema(Schema):
generated_tests: str
instrumented_behavior_tests: str
instrumented_perf_tests: str
raw_generated_tests: str | None = None
class TestGenDebugInfo(Schema):

View file

@ -0,0 +1,79 @@
"""Request/response schemas for the testgen review and repair endpoints."""
from ninja import Schema
from aiservice.models.functions_to_optimize import FunctionToOptimize
# ── Review ──────────────────────────────────────────────────────────
class TestSourceWithFailures(Schema):
test_source: str
test_index: int
failed_test_functions: list[str] = [] # functions that failed behavioral tests
failure_messages: dict[str, str] = {} # function_name → error message from test run
class TestgenReviewSchema(Schema):
tests: list[TestSourceWithFailures]
function_source_code: str
function_name: str
trace_id: str
language: str = "python"
codeflash_version: str | None = None
call_sequence: int | None = None
coverage_summary: str = ""
class FunctionVerdict(Schema):
function_name: str
verdict: str # "pass" or "repair"
reason: str = ""
class TestReview(Schema):
test_index: int
functions: list[FunctionVerdict]
class TestgenReviewResponseSchema(Schema):
reviews: list[TestReview]
class TestgenReviewErrorSchema(Schema):
error: str
# ── Repair ──────────────────────────────────────────────────────────
class FunctionToRepair(Schema):
function_name: str
reason: str
class TestRepairSchema(Schema):
test_source: str
functions_to_repair: list[FunctionToRepair]
function_source_code: str
function_to_optimize: FunctionToOptimize
helper_function_names: list[str] | None = None
module_path: str
test_module_path: str
test_framework: str
test_timeout: int
trace_id: str
python_version: str | None = None
language: str = "python"
codeflash_version: str | None = None
call_sequence: int | None = None
class TestRepairResponseSchema(Schema):
generated_tests: str
instrumented_behavior_tests: str
instrumented_perf_tests: str
class TestRepairErrorSchema(Schema):
error: str
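
Taken together, these schemas define the request and response bodies for the new review endpoint. A hypothetical client call is sketched below; the host, URL prefix, auth header, and payload values are illustrative assumptions, not part of this change:

```python
import httpx

payload = {
    "tests": [
        {
            "test_source": "def test_sort_basic():\n    assert my_sort([2, 1]) == [1, 2]\n",
            "test_index": 0,
            "failed_test_functions": [],
            "failure_messages": {},
        }
    ],
    "function_source_code": "def my_sort(xs):\n    return sorted(xs)\n",
    "function_name": "my_sort",
    "trace_id": "00000000-0000-0000-0000-000000000000",
}
resp = httpx.post(
    "https://aiservice.example.com/ai/testgen_review/",  # host and path prefix are assumptions
    json=payload,
    headers={"Authorization": "Bearer <token>"},          # auth scheme depends on authapp configuration
    timeout=60.0,
)
resp.raise_for_status()
print(resp.json())  # e.g. {"reviews": [{"test_index": 0, "functions": [...]}]}
```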

View file

@ -0,0 +1,17 @@
You are a test repair assistant for an AI code optimizer.
You will receive a test file with specific test functions flagged as problematic, along with the reason each was flagged. Rewrite ONLY the flagged functions. All other code in the file must remain exactly as-is — do not modify imports, class definitions, helper functions, or any unflagged test functions.
Constraints for repaired functions:
- Use diverse, realistic inputs that reflect production usage patterns
- Only use the function's public API — do not access internal implementation details
- Exercise meaningful code paths
- Keep the same function name so test discovery still works
- Maintain the same testing framework conventions (pytest, unittest, etc.)
- Do NOT add comments, docstrings, or type annotations that weren't in the original
- The reason provided for each function describes the specific problem — fix that problem
Return the complete file in a single Python code block:
```python
# complete test file here
```

View file

@ -0,0 +1,19 @@
<function_under_test name="{{ function_name }}">
```python
{{ function_source_code }}
```
</function_under_test>
<test_file>
```python
{{ test_source }}
```
</test_file>
<functions_to_repair>
{% for fn in functions_to_repair -%}
- `{{ fn.function_name }}`: {{ fn.reason }}
{% endfor %}
</functions_to_repair>
Rewrite only the functions listed above. Return the complete file.

View file

@ -0,0 +1,38 @@
You are a test quality reviewer for CodeFlash, an AI code optimizer. CodeFlash uses generated tests to:
1. **Verify correctness** — run the same tests on both original and optimized code, comparing outputs to catch regressions
2. **Benchmark performance** — time the function under test to measure speedup
Review each test function and flag ones that would compromise correctness verification or produce misleading benchmarks.
Flag a test function if it has any of these problems:
1. **Doesn't exercise the function**: The test never calls the target function, or only triggers trivial/no-op code paths (empty input, immediate error return). If coverage information is provided, use it to identify tests that miss the function entirely.
2. **Non-deterministic behavior**: The test produces different outputs across runs due to randomness, timestamps, or external state. This causes false correctness failures when comparing original vs optimized outputs.
3. **Internal state coupling**: The test asserts on internal implementation details (private attributes, internal data structures) rather than observable outputs. Optimized code may change internals while preserving behavior, causing false failures.
4. **Identical-input repetition**: Calling the function with the same arguments multiple times inflates cache hit rates and produces speedups that vanish in production.
5. **Unrealistic data patterns**: Only a single type or size of input, missing the diversity of real-world usage that would reveal performance regressions.
Only flag clear problems. If a test function looks reasonable, omit it from the output.
Respond with a JSON code block containing only the functions that need repair:
```json
{
"functions": [
{"function_name": "test_example", "verdict": "repair", "reason": "One sentence explaining the specific problem"}
]
}
```
If all functions are acceptable:
```json
{"functions": []}
```
Keep reasons to one sentence — they are passed directly to the repair LLM as instructions.
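
As a concrete (made-up) illustration of criterion 4, the first test below would be flagged because every iteration after the first is a cache hit, while the second exercises varied inputs:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def fib(n: int) -> int:
    return n if n < 2 else fib(n - 1) + fib(n - 2)

# Likely flagged as "identical-input repetition": the loop mostly measures the memo cache.
def test_fib_cached_repetition():
    for _ in range(1000):
        assert fib(30) == 832040

# More realistic: varied inputs exercise different code paths and cache states.
def test_fib_varied_inputs():
    expected = {0: 0, 1: 1, 10: 55, 20: 6765, 30: 832040}
    for n, value in expected.items():
        assert fib(n) == value
```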

View file

@ -0,0 +1,25 @@
<function_under_test name="{{ function_name }}">
```python
{{ function_source_code }}
```
</function_under_test>
<generated_test>
```python
{{ test_source }}
```
</generated_test>
{% if coverage_summary %}
<coverage>
Test coverage of {{ function_name }}: {{ coverage_summary }}
</coverage>
{% endif %}
{% if failed_note %}
<note>
{{ failed_note }}
</note>
{% endif %}
Review each test function in the generated test. Return your verdict as JSON.

View file

@ -0,0 +1,158 @@
"""Test repair endpoint.
Takes a generated test file with specific flagged functions and returns a
repaired version with those functions rewritten.
Currently enabled for Python only.
"""
from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from typing import Any
import sentry_sdk
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from ninja import NinjaAPI
from ninja.errors import HttpError
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
)
from aiservice.analytics.posthog import ph
from aiservice.common.markdown_utils import extract_code_block_with_context
from aiservice.llm import EXECUTE_MODEL, calculate_llm_cost, call_llm
from authapp.auth import AuthenticatedRequest
from core.shared.testgen_review.models import TestRepairErrorSchema, TestRepairResponseSchema, TestRepairSchema
testgen_repair_api = NinjaAPI(urls_namespace="testgen_repair")
_prompts_dir = Path(__file__).parent / "prompts"
_jinja_env = Environment(loader=FileSystemLoader(_prompts_dir), keep_trailing_newline=True, undefined=StrictUndefined) # noqa: S701
@testgen_repair_api.post(
"/", response={200: TestRepairResponseSchema, 400: TestRepairErrorSchema, 500: TestRepairErrorSchema}
)
async def testgen_repair(
request: AuthenticatedRequest, data: TestRepairSchema
) -> tuple[int, TestRepairResponseSchema | TestRepairErrorSchema]:
if data.language != "python":
return 400, TestRepairErrorSchema(error="Test repair is only supported for Python")
await asyncio.to_thread(ph, request.user, "aiservice-testgen-repair-called")
try:
system_prompt = _jinja_env.get_template("testgen_repair_system_prompt.md.j2").render()
user_prompt = _jinja_env.get_template("testgen_repair_user_prompt.md.j2").render(
function_name=data.function_to_optimize.function_name,
function_source_code=data.function_source_code,
test_source=data.test_source,
functions_to_repair=data.functions_to_repair,
)
messages: list[ChatCompletionMessageParam] = [
ChatCompletionSystemMessageParam(role="system", content=system_prompt),
ChatCompletionUserMessageParam(role="user", content=user_prompt),
]
obs_context: dict[str, Any] = {}
if data.call_sequence is not None:
obs_context["call_sequence"] = data.call_sequence
response = await call_llm(
llm=EXECUTE_MODEL,
messages=messages,
call_type="testgen_repair",
trace_id=data.trace_id,
user_id=request.user,
context=obs_context,
)
cost = calculate_llm_cost(response.raw_response, EXECUTE_MODEL)
logging.debug(f"testgen_repair LLM cost: {cost}")
repair_text = response.content.strip()
if result := extract_code_block_with_context(repair_text, language="python"):
_before, repaired_code, _after = result
else:
return 500, TestRepairErrorSchema(error="Could not extract repaired code from LLM response")
# Splice only the flagged functions from the LLM output into the original test source,
# keeping all unflagged functions untouched
from core.languages.python.cst_utils import parse_module_to_cst # noqa: PLC0415
from core.languages.python.testgen.instrumentation.edit_generated_test import find_and_replace_function
from core.languages.python.testgen.postprocessing.postprocess_pipeline import postprocessing_testgen_pipeline
from core.languages.python.testgen.validate import instrument_tests, validate_request_data # noqa: PLC0415
from core.shared.testgen_models import TestGenSchema # noqa: PLC0415
repaired_cst = parse_module_to_cst(repaired_code)
original_cst = parse_module_to_cst(data.test_source)
# Extract repaired function nodes by name
repaired_functions: dict[str, Any] = {}
for node in repaired_cst.body:
if hasattr(node, "body") and hasattr(node.body, "__iter__"):
for child in node.body:
if hasattr(child, "name") and hasattr(child.name, "value"):
repaired_functions[child.name.value] = child
if hasattr(node, "name") and hasattr(node.name, "value"):
repaired_functions[node.name.value] = node
# Replace only flagged functions in the original
merged_cst = original_cst
for func_to_repair in data.functions_to_repair:
if func_to_repair.function_name in repaired_functions:
merged_cst = find_and_replace_function(
merged_cst, func_to_repair.function_name, repaired_functions[func_to_repair.function_name]
)
# Run postprocessing and instrumentation on the merged result
testgen_data = TestGenSchema(
source_code_being_tested=data.function_source_code,
function_to_optimize=data.function_to_optimize,
helper_function_names=data.helper_function_names or [],
module_path=data.module_path,
test_module_path=data.test_module_path,
test_framework=data.test_framework,
test_timeout=data.test_timeout,
trace_id=data.trace_id,
python_version=data.python_version,
language=data.language,
codeflash_version=data.codeflash_version,
)
python_version = validate_request_data(testgen_data)
processed_cst, display_cst = postprocessing_testgen_pipeline(
merged_cst,
data.helper_function_names or [],
data.function_to_optimize,
data.module_path,
data.function_source_code,
)
processed_code = processed_cst.code
display_code = display_cst.code
instrumented_behavior, instrumented_perf = instrument_tests(processed_code, testgen_data, python_version)
if instrumented_behavior is None or instrumented_perf is None:
return 500, TestRepairErrorSchema(error="Failed to instrument repaired tests")
await asyncio.to_thread(ph, request.user, "aiservice-testgen-repair-completed")
return 200, TestRepairResponseSchema(
generated_tests=display_code,
instrumented_behavior_tests=instrumented_behavior,
instrumented_perf_tests=instrumented_perf,
)
except HttpError:
raise
except Exception as e:
logging.exception("Error in testgen_repair")
sentry_sdk.capture_exception(e)
return 500, TestRepairErrorSchema(error="Internal server error")
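
The splice step delegates the actual node swap to `find_and_replace_function` from the instrumentation package, whose internals are not shown in this diff. For context, a minimal standalone sketch of the same replace-by-name idea with libcst (toy sources only, not the production helper):

```python
import libcst as cst

class _ReplaceFunction(cst.CSTTransformer):
    """Replace a top-level function definition, matched by name, with a repaired node."""

    def __init__(self, name: str, replacement: cst.FunctionDef) -> None:
        self.name = name
        self.replacement = replacement

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
        return self.replacement if original_node.name.value == self.name else updated_node

original = cst.parse_module("def test_a():\n    assert 1\n\ndef test_b():\n    assert 2\n")
repaired = cst.parse_module("def test_b():\n    assert 2 + 2 == 4\n")
new_b = next(n for n in repaired.body if isinstance(n, cst.FunctionDef) and n.name.value == "test_b")
merged = original.visit(_ReplaceFunction("test_b", new_b))
print(merged.code)  # test_a is untouched; only test_b is replaced
```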

View file

@ -0,0 +1,174 @@
"""Test quality review endpoint.
Reviews AI-generated test sources for unrealistic patterns (cache warm-up,
internal state manipulation, identical inputs, etc.) and returns per-function
verdicts of "pass" or "repair" with reasons.
Currently enabled for Python only; other languages return all-pass immediately.
"""
from __future__ import annotations
import asyncio
import json
import logging
from pathlib import Path
from typing import Any
import sentry_sdk
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from ninja import NinjaAPI
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
)
from aiservice.analytics.posthog import ph
from aiservice.common.markdown_utils import extract_code_block_with_context
from aiservice.llm import HAIKU_MODEL, calculate_llm_cost, call_llm
from authapp.auth import AuthenticatedRequest
from core.shared.testgen_review.models import (
FunctionVerdict,
TestgenReviewErrorSchema,
TestgenReviewResponseSchema,
TestgenReviewSchema,
TestReview,
)
testgen_review_api = NinjaAPI(urls_namespace="testgen_review")
_prompts_dir = Path(__file__).parent / "prompts"
_jinja_env = Environment(loader=FileSystemLoader(_prompts_dir), keep_trailing_newline=True, undefined=StrictUndefined) # noqa: S701
@testgen_review_api.post(
"/", response={200: TestgenReviewResponseSchema, 400: TestgenReviewErrorSchema, 500: TestgenReviewErrorSchema}
)
async def testgen_review(
request: AuthenticatedRequest, data: TestgenReviewSchema
) -> tuple[int, TestgenReviewResponseSchema | TestgenReviewErrorSchema]:
# Only enabled for Python for now
if data.language != "python":
return 200, TestgenReviewResponseSchema(reviews=[])
await asyncio.to_thread(ph, request.user, "aiservice-testgen-review-called")
try:
reviews: list[TestReview] = []
for test_entry in data.tests:
review = await _review_single_test(
test_source=test_entry.test_source,
test_index=test_entry.test_index,
failed_test_functions=test_entry.failed_test_functions,
failure_messages=test_entry.failure_messages,
function_source_code=data.function_source_code,
function_name=data.function_name,
trace_id=data.trace_id,
user_id=request.user,
call_sequence=data.call_sequence,
coverage_summary=data.coverage_summary,
)
reviews.append(review)
await asyncio.to_thread(
ph, request.user, "aiservice-testgen-review-completed", properties={"trace_id": data.trace_id}
)
return 200, TestgenReviewResponseSchema(reviews=reviews)
except Exception as e:
logging.exception("Error in testgen_review")
sentry_sdk.capture_exception(e)
return 500, TestgenReviewErrorSchema(error="Internal server error")
async def _review_single_test(
test_source: str,
test_index: int,
failed_test_functions: list[str],
failure_messages: dict[str, str],
function_source_code: str,
function_name: str,
trace_id: str,
user_id: str,
call_sequence: int | None,
coverage_summary: str = "",
) -> TestReview:
# Functions that failed behavioral tests are definitively bad — include error details in reason
failed_verdicts = [
FunctionVerdict(
function_name=fn,
verdict="repair",
reason=failure_messages.get(fn, "Failed behavioral test against original code"),
)
for fn in failed_test_functions
]
# Build prompt for AI review of passing functions — include error context so the AI can see patterns
failed_note = ""
if failed_test_functions:
failed_lines = []
for fn in failed_test_functions:
error_msg = failure_messages.get(fn, "")
if error_msg:
failed_lines.append(f"- `{fn}`: {error_msg}")
else:
failed_lines.append(f"- `{fn}`")
failed_note = (
"Note: The following functions failed behavioral tests and are already flagged for repair. "
"Do NOT include them in your review — only review the remaining functions.\n" + "\n".join(failed_lines)
)
system_prompt = _jinja_env.get_template("testgen_review_system_prompt.md.j2").render()
user_prompt = _jinja_env.get_template("testgen_review_user_prompt.md.j2").render(
function_name=function_name,
function_source_code=function_source_code,
test_source=test_source,
failed_note=failed_note,
coverage_summary=coverage_summary,
)
messages: list[ChatCompletionMessageParam] = [
ChatCompletionSystemMessageParam(role="system", content=system_prompt),
ChatCompletionUserMessageParam(role="user", content=user_prompt),
]
obs_context: dict[str, Any] = {}
if call_sequence is not None:
obs_context["call_sequence"] = call_sequence
response = await call_llm(
llm=HAIKU_MODEL,
messages=messages,
call_type="testgen_review",
trace_id=trace_id,
user_id=user_id,
context=obs_context,
)
cost = calculate_llm_cost(response.raw_response, HAIKU_MODEL)
logging.debug(f"testgen_review LLM cost: {cost}")
ai_verdicts = _parse_review_response(response.content.strip())
return TestReview(test_index=test_index, functions=failed_verdicts + ai_verdicts)
def _parse_review_response(review_text: str) -> list[FunctionVerdict]:
if result := extract_code_block_with_context(review_text, language="json"):
_before, json_content, _after = result
try:
data = json.loads(json_content.strip())
# Support both formats: {"functions": [...]} or [...]
fn_list = data if isinstance(data, list) else data.get("functions", [])
return [
FunctionVerdict(
function_name=fn.get("function_name", ""),
verdict=fn.get("verdict", "pass"),
reason=fn.get("reason", ""),
)
for fn in fn_list
if fn.get("verdict") == "repair"
]
except (json.JSONDecodeError, AttributeError):
logging.warning("testgen_review: failed to parse AI response JSON")
return []
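
Assuming the module is importable and `extract_code_block_with_context` pulls the fenced JSON as its use above suggests, the parser accepts both response shapes and drops "pass" verdicts; the inputs below are made up for illustration:

```python
wrapped = (
    "```json\n"
    '{"functions": [{"function_name": "test_cache_hits", "verdict": "repair",'
    ' "reason": "Calls the function with identical inputs in a loop"}]}\n'
    "```"
)
bare_list = '```json\n[{"function_name": "test_ok", "verdict": "pass"}]\n```'

print(_parse_review_response(wrapped))    # expected: one FunctionVerdict for test_cache_hits
print(_parse_review_response(bare_list))  # expected: [] — only "repair" verdicts are returned
```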

View file

@ -43,7 +43,7 @@ def test_list_comparison():
)
source_code = "def comparator(orig, new):\n return orig == new"
result = postprocessing_testgen_pipeline(
result, _ = postprocessing_testgen_pipeline(
module=cst.parse_module(generated_code),
helper_function_names=[],
function_to_optimize=function_to_optimize,

View file

@ -84,7 +84,7 @@ def extract_input_variables(nodes):
ending_line=None,
)
module_path = "test_validate_pipeline"
result = postprocessing_testgen_pipeline(
result, _ = postprocessing_testgen_pipeline(
module, ["function_to_remove"], function_to_optimize, module_path, source_code_being_tested
)
@ -262,7 +262,7 @@ def test_document_to_element_list_single_element():
# Step 2: postprocessing_testgen_pipeline (includes add_missing_imports and replace_definition_with_import)
source_code_being_tested = group_code(source_code_blocks)
processed_module = postprocessing_testgen_pipeline(
processed_module, _ = postprocessing_testgen_pipeline(
parse_module_to_cst(validated_code), [], function_to_optimize, module_path, source_code_being_tested
)
@ -570,7 +570,7 @@ def test_list_items_are_inferred():
# Step 2: postprocessing_testgen_pipeline (includes add_missing_imports and replace_definition_with_import)
source_code_being_tested = group_code(source_code_blocks)
processed_module = postprocessing_testgen_pipeline(
processed_module, _ = postprocessing_testgen_pipeline(
parse_module_to_cst(validated_code), [], function_to_optimize, module_path, source_code_being_tested
)