Move AI-generated test instrumentation from server-side to client-side
Server-side instrumentation wrote return values to .bin files, which became corrupted under concurrent pytest processes (interleaved records → UnicodeDecodeError). Client-side instrumentation writes to SQLite, which handles concurrent access safely. The client now ignores instrumented_behavior_tests and instrumented_perf_tests from the AI service response and instead instruments the plain generated_tests locally using inject_profiling_into_existing_test — the same path used for discovered existing tests.
This commit is contained in:
parent
bc0323a46c
commit
434e888571
5 changed files with 68 additions and 75 deletions
|
|
@ -30,13 +30,9 @@ if TYPE_CHECKING:
|
|||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Type alias for pending test entries produced during AI test
|
||||
# generation. Previously a class variable on the optimizer.
|
||||
PendingTest = tuple[
|
||||
int, # test_index
|
||||
str, # generated_source
|
||||
str, # behavior_source
|
||||
str, # perf_source
|
||||
"Path", # test_path
|
||||
"Path", # test_perf_path
|
||||
]
|
||||
|
|
@ -353,9 +349,10 @@ def generate_ai_tests( # noqa: PLR0913
|
|||
) -> list[TestFile]:
|
||||
"""Generate regression tests via the AI service.
|
||||
|
||||
Creates test files with pre-instrumented behavior and
|
||||
performance variants. Returns a list of *TestFile* objects
|
||||
ready to be appended to ``test_files``.
|
||||
Creates test files and instruments them client-side for
|
||||
behavior capture and performance measurement. Returns a
|
||||
list of *TestFile* objects ready to be appended to
|
||||
``test_files``.
|
||||
"""
|
||||
import tempfile # noqa: PLC0415
|
||||
from pathlib import Path as _Path # noqa: PLC0415
|
||||
|
|
@ -442,9 +439,9 @@ def generate_ai_tests( # noqa: PLR0913
|
|||
if result is None:
|
||||
continue
|
||||
|
||||
gen_src, beh_src, perf_src, _raw, tp, tpp = result
|
||||
gen_src, _raw, tp, tpp = result
|
||||
pending.append(
|
||||
(test_index, gen_src, beh_src, perf_src, tp, tpp),
|
||||
(test_index, gen_src, tp, tpp),
|
||||
)
|
||||
|
||||
if not pending:
|
||||
|
|
@ -461,47 +458,55 @@ def generate_ai_tests( # noqa: PLR0913
|
|||
fn_input=fn_input,
|
||||
)
|
||||
|
||||
# Phase 4: write files and create TestFile objects.
|
||||
import pathlib # noqa: PLC0415
|
||||
|
||||
from ..runtime._codeflash_wrap_decorator import ( # noqa: PLC0415
|
||||
get_run_tmp_file,
|
||||
# Phase 4: write files, instrument client-side, create TestFile objects.
|
||||
from .._model import TestingMode # noqa: PLC0415
|
||||
from ..testing._instrumentation import ( # noqa: PLC0415
|
||||
inject_profiling_into_existing_test,
|
||||
)
|
||||
|
||||
tmp_dir = get_run_tmp_file(pathlib.Path()).as_posix()
|
||||
|
||||
test_file_objects: list[TestFile] = []
|
||||
for (
|
||||
_idx,
|
||||
generated_source,
|
||||
behavior_source,
|
||||
perf_source,
|
||||
test_path,
|
||||
test_perf_path,
|
||||
) in pending:
|
||||
resolved_behavior = behavior_source.replace(
|
||||
"{codeflash_run_tmp_dir_client_side}", tmp_dir,
|
||||
)
|
||||
resolved_perf = perf_source.replace(
|
||||
"{codeflash_run_tmp_dir_client_side}", tmp_dir,
|
||||
)
|
||||
test_path.write_text(generated_source, encoding="utf-8")
|
||||
|
||||
beh_path = test_path.parent / (
|
||||
ok_beh, beh_src = inject_profiling_into_existing_test(
|
||||
test_path=test_path,
|
||||
call_positions=[],
|
||||
function_to_optimize=func,
|
||||
tests_project_root=tests_rootdir,
|
||||
mode=TestingMode.BEHAVIOR,
|
||||
)
|
||||
ok_perf, perf_src = inject_profiling_into_existing_test(
|
||||
test_path=test_path,
|
||||
call_positions=[],
|
||||
function_to_optimize=func,
|
||||
tests_project_root=tests_rootdir,
|
||||
mode=TestingMode.PERFORMANCE,
|
||||
)
|
||||
|
||||
beh_path: _Path | None = test_path.parent / (
|
||||
test_path.stem + "__perfinstrumented" + test_path.suffix
|
||||
)
|
||||
beh_path.write_text(resolved_behavior, encoding="utf-8")
|
||||
perf_path: _Path | None = test_perf_path
|
||||
|
||||
test_perf_path.write_text(
|
||||
resolved_perf,
|
||||
encoding="utf-8",
|
||||
)
|
||||
if ok_beh and beh_src is not None:
|
||||
beh_path.write_text(beh_src, encoding="utf-8") # type: ignore[union-attr]
|
||||
else:
|
||||
beh_path = None
|
||||
if ok_perf and perf_src is not None:
|
||||
perf_path.write_text(perf_src, encoding="utf-8") # type: ignore[union-attr]
|
||||
else:
|
||||
perf_path = None
|
||||
|
||||
test_file_objects.append(
|
||||
TestFile(
|
||||
original_file_path=test_path,
|
||||
instrumented_behavior_file_path=beh_path,
|
||||
benchmarking_file_path=test_perf_path,
|
||||
benchmarking_file_path=perf_path,
|
||||
test_type=TestType.GENERATED_REGRESSION,
|
||||
),
|
||||
)
|
||||
|
|
@ -593,8 +598,6 @@ def review_and_repair_tests( # noqa: C901, PLR0913
|
|||
(
|
||||
tidx,
|
||||
gen_src,
|
||||
_beh,
|
||||
_perf,
|
||||
test_path,
|
||||
test_perf_path,
|
||||
) = entry
|
||||
|
|
@ -645,12 +648,9 @@ def review_and_repair_tests( # noqa: C901, PLR0913
|
|||
if repair_result is None:
|
||||
continue
|
||||
|
||||
repaired_gen, repaired_beh, repaired_perf = repair_result
|
||||
pending[pos] = (
|
||||
tidx,
|
||||
repaired_gen,
|
||||
repaired_beh,
|
||||
repaired_perf,
|
||||
repair_result,
|
||||
test_path,
|
||||
test_perf_path,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -45,10 +45,12 @@ def get_call_arguments(
|
|||
def node_in_call_position(
|
||||
node: ast.AST, call_positions: list[CodePosition]
|
||||
) -> bool:
|
||||
"""Return True if the AST node overlaps any of the given call positions."""
|
||||
# Reduce attribute lookup and localize call_positions
|
||||
# if not empty for a meaningful speedup.
|
||||
# Small optimizations for tight loop:
|
||||
"""Return True if the AST node overlaps any of the given call positions.
|
||||
|
||||
When *call_positions* is empty, every ``ast.Call`` node matches.
|
||||
"""
|
||||
if not call_positions:
|
||||
return isinstance(node, ast.Call)
|
||||
if isinstance(node, ast.Call):
|
||||
node_lineno = getattr(node, "lineno", None)
|
||||
node_col_offset = getattr(node, "col_offset", None)
|
||||
|
|
|
|||
|
|
@ -104,24 +104,21 @@ class TestgenPayload:
|
|||
def generate_regression_tests(
|
||||
client: AIClient,
|
||||
payload: TestgenPayload,
|
||||
) -> tuple[str, str, str, str | None] | None:
|
||||
) -> tuple[str, str | None] | None:
|
||||
"""
|
||||
Call the AI service ``/ai/testgen`` endpoint to generate regression tests.
|
||||
|
||||
Returns *(generated_tests, instrumented_behavior, instrumented_perf,
|
||||
raw_generated)* or ``None`` on failure.
|
||||
Returns *(generated_tests, raw_generated)* or ``None`` on failure.
|
||||
"""
|
||||
data = client.post("/testgen", payload.to_dict())
|
||||
|
||||
generated = data.get("generated_tests", "")
|
||||
behavior = data.get("instrumented_behavior_tests", "")
|
||||
perf = data.get("instrumented_perf_tests", "")
|
||||
raw = data.get("raw_generated_tests")
|
||||
|
||||
if not generated:
|
||||
return None
|
||||
|
||||
return (generated, behavior, perf, raw)
|
||||
return (generated, raw)
|
||||
|
||||
|
||||
def generate_tests( # noqa: PLR0913
|
||||
|
|
@ -139,12 +136,12 @@ def generate_tests( # noqa: PLR0913
|
|||
test_module_path: str,
|
||||
language_version: str,
|
||||
is_numerical_code: bool | None = None, # noqa: FBT001
|
||||
) -> tuple[str, str, str, str | None, Path, Path] | None:
|
||||
) -> tuple[str, str | None, Path, Path] | None:
|
||||
"""
|
||||
Generate regression tests for a function via the AI service.
|
||||
|
||||
Returns *(generated_source, behavior_source, perf_source, raw_source,
|
||||
test_path, test_perf_path)* or ``None`` on failure.
|
||||
Returns *(generated_source, raw_source, test_path, test_perf_path)*
|
||||
or ``None`` on failure.
|
||||
"""
|
||||
payload = TestgenPayload(
|
||||
source_code_being_tested=source_code_being_tested,
|
||||
|
|
@ -171,8 +168,8 @@ def generate_tests( # noqa: PLR0913
|
|||
if response is None:
|
||||
return None
|
||||
|
||||
generated, behavior, perf, raw = response
|
||||
return (generated, behavior, perf, raw, test_path, test_perf_path)
|
||||
generated, raw = response
|
||||
return (generated, raw, test_path, test_perf_path)
|
||||
|
||||
|
||||
def review_generated_tests(
|
||||
|
|
@ -195,23 +192,20 @@ def review_generated_tests(
|
|||
def repair_generated_tests(
|
||||
client: AIClient,
|
||||
payload: dict[str, Any],
|
||||
) -> tuple[str, str, str] | None:
|
||||
) -> str | None:
|
||||
"""
|
||||
Repair generated tests via the AI service.
|
||||
|
||||
Returns *(generated_tests, instrumented_behavior, instrumented_perf)*
|
||||
or ``None`` on failure.
|
||||
Returns the repaired test source or ``None`` on failure.
|
||||
"""
|
||||
try:
|
||||
data = client.post("/testgen_repair", payload)
|
||||
except (AIServiceError, AIServiceConnectionError):
|
||||
return None
|
||||
generated = data.get("generated_tests", "")
|
||||
behavior = data.get("instrumented_behavior_tests", "")
|
||||
perf = data.get("instrumented_perf_tests", "")
|
||||
if not generated:
|
||||
return None
|
||||
return (generated, behavior, perf)
|
||||
return generated
|
||||
|
||||
|
||||
class ModifyInspiredTests(ast.NodeTransformer):
|
||||
|
|
|
|||
|
|
@ -139,12 +139,12 @@ class TestNodeInCallPosition:
|
|||
positions = [CodePosition(line_no=99, col_no=99)]
|
||||
assert node_in_call_position(call_node, positions) is False
|
||||
|
||||
def test_empty_positions(self) -> None:
|
||||
"""Returns False when positions list is empty."""
|
||||
def test_empty_positions_matches_all(self) -> None:
|
||||
"""Returns True for any Call when positions list is empty."""
|
||||
code = "target_func()\n"
|
||||
tree = ast.parse(code)
|
||||
call_node = tree.body[0].value # type: ignore[attr-defined]
|
||||
assert node_in_call_position(call_node, []) is False
|
||||
assert node_in_call_position(call_node, []) is True
|
||||
|
||||
def test_multiple_positions_one_match(self) -> None:
|
||||
"""Returns True when one of several positions matches."""
|
||||
|
|
|
|||
|
|
@ -293,12 +293,10 @@ class TestGenerateRegressionTests:
|
|||
"""generate_regression_tests AI service call."""
|
||||
|
||||
def test_successful_response(self) -> None:
|
||||
"""Successful response returns tuple of 4 strings."""
|
||||
"""Successful response returns tuple of generated and raw sources."""
|
||||
client = make_mock_client()
|
||||
client.post.return_value = {
|
||||
"generated_tests": "test code",
|
||||
"instrumented_behavior_tests": "behavior code",
|
||||
"instrumented_perf_tests": "perf code",
|
||||
"raw_generated_tests": "raw code",
|
||||
}
|
||||
|
||||
|
|
@ -310,8 +308,8 @@ class TestGenerateRegressionTests:
|
|||
payload=payload,
|
||||
)
|
||||
assert result is not None
|
||||
assert 4 == len(result)
|
||||
assert "test code" == result[0]
|
||||
assert 2 == len(result)
|
||||
assert ("test code", "raw code") == result
|
||||
|
||||
def test_empty_generated_source_returns_none(self) -> None:
|
||||
"""Empty generated_test_source returns None."""
|
||||
|
|
@ -391,20 +389,17 @@ class TestRepairGeneratedTests:
|
|||
"""repair_generated_tests AI service call."""
|
||||
|
||||
def test_successful_response(self) -> None:
|
||||
"""Successful response returns tuple of 3 strings."""
|
||||
"""Successful response returns the repaired test source."""
|
||||
client = make_mock_client()
|
||||
client.post.return_value = {
|
||||
"generated_tests": "fixed tests",
|
||||
"instrumented_behavior_tests": "fixed behavior",
|
||||
"instrumented_perf_tests": "fixed perf",
|
||||
}
|
||||
|
||||
result = repair_generated_tests(
|
||||
client, {"test_source": "x", "trace_id": "t1"}
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert ("fixed tests", "fixed behavior", "fixed perf") == result
|
||||
assert "fixed tests" == result
|
||||
|
||||
def test_failure_returns_none(self) -> None:
|
||||
"""API error returns None."""
|
||||
|
|
@ -438,11 +433,9 @@ class TestGenerateTests:
|
|||
mock_regression: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Successful flow returns tuple of 6 items."""
|
||||
"""Successful flow returns tuple of 4 items."""
|
||||
mock_regression.return_value = (
|
||||
"generated",
|
||||
"behavior",
|
||||
"perf",
|
||||
"raw",
|
||||
)
|
||||
client = make_mock_client()
|
||||
|
|
@ -466,7 +459,11 @@ class TestGenerateTests:
|
|||
language_version="3.12.0",
|
||||
)
|
||||
assert result is not None
|
||||
assert 6 == len(result)
|
||||
assert 4 == len(result)
|
||||
assert "generated" == result[0]
|
||||
assert "raw" == result[1]
|
||||
assert test_path == result[2]
|
||||
assert test_perf_path == result[3]
|
||||
|
||||
@patch("codeflash_python.testing._testgen.generate_regression_tests")
|
||||
def test_none_from_api_returns_none(
|
||||
|
|
|
|||
Loading…
Reference in a new issue