Move AI-generated test instrumentation from server-side to client-side

Server-side instrumentation wrote return values to .bin files, which
corrupted under concurrent pytest processes (interleaved records →
UnicodeDecodeError). Client-side instrumentation writes to SQLite,
which handles concurrent access safely.

The client now ignores instrumented_behavior_tests and
instrumented_perf_tests from the AI service response and instruments
the plain generated_tests locally using inject_profiling_into_existing_test,
the same path used for discovered existing tests.
This commit is contained in:
Kevin Turcios 2026-04-21 07:38:48 -05:00
parent bc0323a46c
commit 434e888571
5 changed files with 68 additions and 75 deletions

View file

@@ -30,13 +30,9 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
# Type alias for pending test entries produced during AI test
# generation. Previously a class variable on the optimizer.
PendingTest = tuple[
int, # test_index
str, # generated_source
str, # behavior_source
str, # perf_source
"Path", # test_path
"Path", # test_perf_path
]
@@ -353,9 +349,10 @@ def generate_ai_tests( # noqa: PLR0913
) -> list[TestFile]:
"""Generate regression tests via the AI service.
Creates test files with pre-instrumented behavior and
performance variants. Returns a list of *TestFile* objects
ready to be appended to ``test_files``.
Creates test files and instruments them client-side for
behavior capture and performance measurement. Returns a
list of *TestFile* objects ready to be appended to
``test_files``.
"""
import tempfile # noqa: PLC0415
from pathlib import Path as _Path # noqa: PLC0415
@@ -442,9 +439,9 @@ def generate_ai_tests( # noqa: PLR0913
if result is None:
continue
gen_src, beh_src, perf_src, _raw, tp, tpp = result
gen_src, _raw, tp, tpp = result
pending.append(
(test_index, gen_src, beh_src, perf_src, tp, tpp),
(test_index, gen_src, tp, tpp),
)
if not pending:
@@ -461,47 +458,55 @@ def generate_ai_tests( # noqa: PLR0913
fn_input=fn_input,
)
# Phase 4: write files and create TestFile objects.
import pathlib # noqa: PLC0415
from ..runtime._codeflash_wrap_decorator import ( # noqa: PLC0415
get_run_tmp_file,
# Phase 4: write files, instrument client-side, create TestFile objects.
from .._model import TestingMode # noqa: PLC0415
from ..testing._instrumentation import ( # noqa: PLC0415
inject_profiling_into_existing_test,
)
tmp_dir = get_run_tmp_file(pathlib.Path()).as_posix()
test_file_objects: list[TestFile] = []
for (
_idx,
generated_source,
behavior_source,
perf_source,
test_path,
test_perf_path,
) in pending:
resolved_behavior = behavior_source.replace(
"{codeflash_run_tmp_dir_client_side}", tmp_dir,
)
resolved_perf = perf_source.replace(
"{codeflash_run_tmp_dir_client_side}", tmp_dir,
)
test_path.write_text(generated_source, encoding="utf-8")
beh_path = test_path.parent / (
ok_beh, beh_src = inject_profiling_into_existing_test(
test_path=test_path,
call_positions=[],
function_to_optimize=func,
tests_project_root=tests_rootdir,
mode=TestingMode.BEHAVIOR,
)
ok_perf, perf_src = inject_profiling_into_existing_test(
test_path=test_path,
call_positions=[],
function_to_optimize=func,
tests_project_root=tests_rootdir,
mode=TestingMode.PERFORMANCE,
)
beh_path: _Path | None = test_path.parent / (
test_path.stem + "__perfinstrumented" + test_path.suffix
)
beh_path.write_text(resolved_behavior, encoding="utf-8")
perf_path: _Path | None = test_perf_path
test_perf_path.write_text(
resolved_perf,
encoding="utf-8",
)
if ok_beh and beh_src is not None:
beh_path.write_text(beh_src, encoding="utf-8") # type: ignore[union-attr]
else:
beh_path = None
if ok_perf and perf_src is not None:
perf_path.write_text(perf_src, encoding="utf-8") # type: ignore[union-attr]
else:
perf_path = None
test_file_objects.append(
TestFile(
original_file_path=test_path,
instrumented_behavior_file_path=beh_path,
benchmarking_file_path=test_perf_path,
benchmarking_file_path=perf_path,
test_type=TestType.GENERATED_REGRESSION,
),
)
@@ -593,8 +598,6 @@ def review_and_repair_tests( # noqa: C901, PLR0913
(
tidx,
gen_src,
_beh,
_perf,
test_path,
test_perf_path,
) = entry
@@ -645,12 +648,9 @@ def review_and_repair_tests( # noqa: C901, PLR0913
if repair_result is None:
continue
repaired_gen, repaired_beh, repaired_perf = repair_result
pending[pos] = (
tidx,
repaired_gen,
repaired_beh,
repaired_perf,
repair_result,
test_path,
test_perf_path,
)

View file

@@ -45,10 +45,12 @@ def get_call_arguments(
def node_in_call_position(
node: ast.AST, call_positions: list[CodePosition]
) -> bool:
"""Return True if the AST node overlaps any of the given call positions."""
# Reduce attribute lookup and localize call_positions
# if not empty for a meaningful speedup.
# Small optimizations for tight loop:
"""Return True if the AST node overlaps any of the given call positions.
When *call_positions* is empty, every ``ast.Call`` node matches.
"""
if not call_positions:
return isinstance(node, ast.Call)
if isinstance(node, ast.Call):
node_lineno = getattr(node, "lineno", None)
node_col_offset = getattr(node, "col_offset", None)

View file

@@ -104,24 +104,21 @@ class TestgenPayload:
def generate_regression_tests(
client: AIClient,
payload: TestgenPayload,
) -> tuple[str, str, str, str | None] | None:
) -> tuple[str, str | None] | None:
"""
Call the AI service ``/ai/testgen`` endpoint to generate regression tests.
Returns *(generated_tests, instrumented_behavior, instrumented_perf,
raw_generated)* or ``None`` on failure.
Returns *(generated_tests, raw_generated)* or ``None`` on failure.
"""
data = client.post("/testgen", payload.to_dict())
generated = data.get("generated_tests", "")
behavior = data.get("instrumented_behavior_tests", "")
perf = data.get("instrumented_perf_tests", "")
raw = data.get("raw_generated_tests")
if not generated:
return None
return (generated, behavior, perf, raw)
return (generated, raw)
def generate_tests( # noqa: PLR0913
@@ -139,12 +136,12 @@ def generate_tests( # noqa: PLR0913
test_module_path: str,
language_version: str,
is_numerical_code: bool | None = None, # noqa: FBT001
) -> tuple[str, str, str, str | None, Path, Path] | None:
) -> tuple[str, str | None, Path, Path] | None:
"""
Generate regression tests for a function via the AI service.
Returns *(generated_source, behavior_source, perf_source, raw_source,
test_path, test_perf_path)* or ``None`` on failure.
Returns *(generated_source, raw_source, test_path, test_perf_path)*
or ``None`` on failure.
"""
payload = TestgenPayload(
source_code_being_tested=source_code_being_tested,
@@ -171,8 +168,8 @@ def generate_tests( # noqa: PLR0913
if response is None:
return None
generated, behavior, perf, raw = response
return (generated, behavior, perf, raw, test_path, test_perf_path)
generated, raw = response
return (generated, raw, test_path, test_perf_path)
def review_generated_tests(
@@ -195,23 +192,20 @@ def review_generated_tests(
def repair_generated_tests(
client: AIClient,
payload: dict[str, Any],
) -> tuple[str, str, str] | None:
) -> str | None:
"""
Repair generated tests via the AI service.
Returns *(generated_tests, instrumented_behavior, instrumented_perf)*
or ``None`` on failure.
Returns the repaired test source or ``None`` on failure.
"""
try:
data = client.post("/testgen_repair", payload)
except (AIServiceError, AIServiceConnectionError):
return None
generated = data.get("generated_tests", "")
behavior = data.get("instrumented_behavior_tests", "")
perf = data.get("instrumented_perf_tests", "")
if not generated:
return None
return (generated, behavior, perf)
return generated
class ModifyInspiredTests(ast.NodeTransformer):

View file

@@ -139,12 +139,12 @@ class TestNodeInCallPosition:
positions = [CodePosition(line_no=99, col_no=99)]
assert node_in_call_position(call_node, positions) is False
def test_empty_positions(self) -> None:
"""Returns False when positions list is empty."""
def test_empty_positions_matches_all(self) -> None:
"""Returns True for any Call when positions list is empty."""
code = "target_func()\n"
tree = ast.parse(code)
call_node = tree.body[0].value # type: ignore[attr-defined]
assert node_in_call_position(call_node, []) is False
assert node_in_call_position(call_node, []) is True
def test_multiple_positions_one_match(self) -> None:
"""Returns True when one of several positions matches."""

View file

@@ -293,12 +293,10 @@ class TestGenerateRegressionTests:
"""generate_regression_tests AI service call."""
def test_successful_response(self) -> None:
"""Successful response returns tuple of 4 strings."""
"""Successful response returns tuple of generated and raw sources."""
client = make_mock_client()
client.post.return_value = {
"generated_tests": "test code",
"instrumented_behavior_tests": "behavior code",
"instrumented_perf_tests": "perf code",
"raw_generated_tests": "raw code",
}
@@ -310,8 +308,8 @@ class TestGenerateRegressionTests:
payload=payload,
)
assert result is not None
assert 4 == len(result)
assert "test code" == result[0]
assert 2 == len(result)
assert ("test code", "raw code") == result
def test_empty_generated_source_returns_none(self) -> None:
"""Empty generated_test_source returns None."""
@@ -391,20 +389,17 @@ class TestRepairGeneratedTests:
"""repair_generated_tests AI service call."""
def test_successful_response(self) -> None:
"""Successful response returns tuple of 3 strings."""
"""Successful response returns the repaired test source."""
client = make_mock_client()
client.post.return_value = {
"generated_tests": "fixed tests",
"instrumented_behavior_tests": "fixed behavior",
"instrumented_perf_tests": "fixed perf",
}
result = repair_generated_tests(
client, {"test_source": "x", "trace_id": "t1"}
)
assert result is not None
assert ("fixed tests", "fixed behavior", "fixed perf") == result
assert "fixed tests" == result
def test_failure_returns_none(self) -> None:
"""API error returns None."""
@@ -438,11 +433,9 @@ class TestGenerateTests:
mock_regression: MagicMock,
tmp_path: Path,
) -> None:
"""Successful flow returns tuple of 6 items."""
"""Successful flow returns tuple of 4 items."""
mock_regression.return_value = (
"generated",
"behavior",
"perf",
"raw",
)
client = make_mock_client()
@@ -466,7 +459,11 @@
language_version="3.12.0",
)
assert result is not None
assert 6 == len(result)
assert 4 == len(result)
assert "generated" == result[0]
assert "raw" == result[1]
assert test_path == result[2]
assert test_perf_path == result[3]
@patch("codeflash_python.testing._testgen.generate_regression_tests")
def test_none_from_api_returns_none(