codeflash-agent/packages/codeflash-python/tests/test_testgen.py
Kevin Turcios 434e888571 Move AI-generated test instrumentation from server-side to client-side
Server-side instrumentation wrote return values to .bin files, which
became corrupted under concurrent pytest processes (interleaved records →
UnicodeDecodeError). Client-side instrumentation writes to SQLite,
which handles concurrent access safely.

The client now ignores instrumented_behavior_tests and
instrumented_perf_tests from the AI service response and instruments
the plain generated_tests locally using inject_profiling_into_existing_test,
the same path used for discovered existing tests.
2026-04-21 07:38:48 -05:00

494 lines
16 KiB
Python

"""Tests for _testgen — test generation and merging."""
from __future__ import annotations
import ast
import textwrap
from pathlib import Path
from unittest.mock import MagicMock, patch
import attrs
import pytest
from codeflash_core.exceptions import (
AIServiceConnectionError,
AIServiceError,
)
from codeflash_python._model import FunctionToOptimize
from codeflash_python.testing._testgen import (
GeneratedTests,
GeneratedTestsList,
ModifyInspiredTests,
TestgenPayload,
delete_multiple_if_name_main,
generate_regression_tests,
generate_tests,
merge_unit_tests,
repair_generated_tests,
review_generated_tests,
)
def make_function(
    name: str = "target_func",
    file_path: str = "module.py",
) -> FunctionToOptimize:
    """Build a minimal FunctionToOptimize fixture.

    Args:
        name: Function name recorded on the fixture.
        file_path: Source file path; converted to a ``Path``.

    Returns:
        A FunctionToOptimize with only ``function_name`` and ``file_path`` set.
    """
    source_path = Path(file_path)
    return FunctionToOptimize(function_name=name, file_path=source_path)
def make_mock_client() -> MagicMock:
    """Return a fresh MagicMock standing in for the AI service client.

    Tests script service behaviour by setting ``client.post.return_value``
    or ``client.post.side_effect`` on the returned mock.
    """
    stand_in_client = MagicMock()
    return stand_in_client
def make_payload(**overrides: object) -> TestgenPayload:
    """Build a TestgenPayload populated with sensible defaults.

    Any keyword argument replaces the default field of the same name;
    unknown names are passed straight through to ``TestgenPayload``.
    """
    baseline: dict[str, object] = {
        "source_code_being_tested": "def foo(): pass",
        "function_to_optimize": make_function().to_dict(),
        "helper_function_names": [],
        "module_path": "module",
        "test_module_path": "test_module",
        "test_framework": "pytest",
        "test_timeout": 30,
        "trace_id": "trace-123",
        "test_index": 0,
        "language_version": "3.12.0",
    }
    # Later keys win, so caller overrides replace the baseline values.
    fields = {**baseline, **overrides}
    return TestgenPayload(**fields)  # type: ignore[arg-type]
class TestGeneratedTests:
    """GeneratedTests frozen model."""

    @staticmethod
    def _build(tmp_path: Path, **extra: object) -> GeneratedTests:
        """Create a GeneratedTests with fixed sources and paths under tmp_path."""
        return GeneratedTests(
            generated_original_test_source="original",
            instrumented_behavior_test_source="behavior",
            instrumented_perf_test_source="perf",
            behavior_file_path=tmp_path / "behavior.py",
            perf_file_path=tmp_path / "perf.py",
            **extra,  # type: ignore[arg-type]
        )

    def test_create_with_all_fields(self, tmp_path: Path) -> None:
        """Every constructor argument can be read back from the instance."""
        model = self._build(tmp_path, raw_generated_test_source="raw")
        assert "original" == model.generated_original_test_source
        assert "behavior" == model.instrumented_behavior_test_source
        assert "perf" == model.instrumented_perf_test_source
        assert tmp_path / "behavior.py" == model.behavior_file_path
        assert tmp_path / "perf.py" == model.perf_file_path
        assert "raw" == model.raw_generated_test_source

    def test_raw_source_defaults_to_none(self, tmp_path: Path) -> None:
        """Omitting raw_generated_test_source leaves it as None."""
        model = self._build(tmp_path)
        assert model.raw_generated_test_source is None

    def test_frozen(self, tmp_path: Path) -> None:
        """Field assignment on the frozen model raises FrozenInstanceError."""
        model = self._build(tmp_path)
        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            model.generated_original_test_source = "changed"  # type: ignore[misc]
class TestGeneratedTestsList:
    """GeneratedTestsList frozen collection."""

    def test_default_empty(self) -> None:
        """Constructed with no arguments, the collection is an empty tuple."""
        assert () == GeneratedTestsList().generated_tests

    def test_with_items(self, tmp_path: Path) -> None:
        """A supplied tuple of GeneratedTests is stored unchanged."""
        entry = GeneratedTests(
            generated_original_test_source="orig",
            instrumented_behavior_test_source="beh",
            instrumented_perf_test_source="perf",
            behavior_file_path=tmp_path / "b.py",
            perf_file_path=tmp_path / "p.py",
        )
        stored = GeneratedTestsList(generated_tests=(entry,)).generated_tests
        assert 1 == len(stored)
        assert entry is stored[0]
class TestDeleteMultipleIfNameMain:
    """delete_multiple_if_name_main AST cleanup."""

    @staticmethod
    def _if_name_main_blocks(tree: ast.Module) -> list[ast.If]:
        """Return the top-level ``if __name__ ...`` statements of *tree*."""
        return [
            node
            for node in tree.body
            if isinstance(node, ast.If)
            and isinstance(node.test, ast.Compare)
            and isinstance(node.test.left, ast.Name)
            and node.test.left.id == "__name__"
        ]

    @staticmethod
    def _assigned_names(stmts: list[ast.stmt]) -> set[str]:
        """Return names bound by simple ``name = value`` statements in *stmts*."""
        return {
            stmt.targets[0].id
            for stmt in stmts
            if isinstance(stmt, ast.Assign)
            and isinstance(stmt.targets[0], ast.Name)
        }

    def test_zero_blocks_unchanged(self) -> None:
        """Body is unchanged when no if __name__ blocks exist."""
        code = textwrap.dedent("""\
            x = 1
            y = 2
        """)
        tree = ast.parse(code)
        original_len = len(tree.body)
        result = delete_multiple_if_name_main(tree)
        assert original_len == len(result.body)

    def test_one_block_unchanged(self) -> None:
        """Body is unchanged when exactly one if __name__ block exists."""
        code = textwrap.dedent("""\
            x = 1
            if __name__ == "__main__":
                pass
        """)
        tree = ast.parse(code)
        original_len = len(tree.body)
        result = delete_multiple_if_name_main(tree)
        assert original_len == len(result.body)
        assert 1 == len(self._if_name_main_blocks(result))

    def test_two_blocks_keeps_last(self) -> None:
        """First if __name__ block is removed, last is kept."""
        code = textwrap.dedent("""\
            if __name__ == "__main__":
                x = 1
            y = 2
            if __name__ == "__main__":
                z = 3
        """)
        result = delete_multiple_if_name_main(ast.parse(code))
        if_name_blocks = self._if_name_main_blocks(result)
        assert 1 == len(if_name_blocks)
        # The survivor must be the *last* block (binds z, not x) ...
        assert {"z"} == self._assigned_names(if_name_blocks[0].body)
        # ... and the ordinary statement between the blocks is preserved.
        assert {"y"} == self._assigned_names(result.body)

    def test_three_blocks_only_last_kept(self) -> None:
        """With three blocks, only the last is kept."""
        code = textwrap.dedent("""\
            if __name__ == "__main__":
                a = 1
            if __name__ == "__main__":
                b = 2
            if __name__ == "__main__":
                c = 3
        """)
        result = delete_multiple_if_name_main(ast.parse(code))
        if_name_blocks = self._if_name_main_blocks(result)
        assert 1 == len(if_name_blocks)
        # The kept block must be the final one, which binds only c.
        assert {"c"} == self._assigned_names(if_name_blocks[0].body)
class TestModifyInspiredTests:
    """ModifyInspiredTests AST transformer."""

    @staticmethod
    def _collect_imports(code: str) -> list[ast.stmt]:
        """Visit *code* with ModifyInspiredTests and return what it collected."""
        collected: list[ast.stmt] = []
        tree = ast.parse(code)
        ModifyInspiredTests(collected).visit(tree)
        return collected

    def test_extracts_import_nodes(self) -> None:
        """Plain ``import`` statements end up in the import list."""
        found = self._collect_imports(
            textwrap.dedent("""\
                import os
                import sys
                x = 1
            """)
        )
        assert 2 == len(found)
        assert all(isinstance(stmt, ast.Import) for stmt in found)

    def test_extracts_import_from_nodes(self) -> None:
        """``from ... import`` statements end up in the import list."""
        found = self._collect_imports(
            textwrap.dedent("""\
                from os.path import join
                from sys import argv
                x = 1
            """)
        )
        assert 2 == len(found)
        assert all(isinstance(stmt, ast.ImportFrom) for stmt in found)
class TestMergeUnitTests:
    """merge_unit_tests test merging."""

    def test_pytest_inspired_suffix(self) -> None:
        """With pytest, generated test functions get __inspired suffix."""
        original = textwrap.dedent("""\
            def test_foo():
                assert True
        """)
        generated = textwrap.dedent("""\
            def test_foo():
                assert 1 == 1
        """)
        result = merge_unit_tests(original, generated)
        assert "__inspired" in result

    def test_imports_from_generated_prepended(self) -> None:
        """Imports from generated tests are included in merged source."""
        original = textwrap.dedent("""\
            def test_foo():
                assert True
        """)
        generated = textwrap.dedent("""\
            import math
            def test_bar():
                assert math.pi > 3
        """)
        result = merge_unit_tests(original, generated)
        assert "import math" in result

    def test_syntax_error_returns_original(self) -> None:
        """Syntax errors in generated tests return original unchanged."""
        original = textwrap.dedent("""\
            def test_foo():
                assert True
        """)
        generated = "def test_bar(\n not valid !!!"
        result = merge_unit_tests(original, generated)
        # "Unchanged" means byte-identical, not merely still containing test_foo.
        assert original == result

    def test_empty_generated(self) -> None:
        """Empty generated tests return original or merged cleanly."""
        original = textwrap.dedent("""\
            def test_foo():
                assert True
        """)
        result = merge_unit_tests(original, "")
        assert "def test_foo" in result
class TestGenerateRegressionTests:
    """generate_regression_tests AI service call."""

    def test_successful_response(self) -> None:
        """Successful response returns tuple of generated and raw sources."""
        client = make_mock_client()
        client.post.return_value = {
            "generated_tests": "test code",
            "raw_generated_tests": "raw code",
        }
        payload = make_payload(
            helper_function_names=["helper1"],
        )
        result = generate_regression_tests(
            client=client,
            payload=payload,
        )
        # Tuple equality also pins the length and rules out None.
        assert ("test code", "raw code") == result

    @pytest.mark.parametrize(
        "response",
        [
            # The service's real key (see test_successful_response) is
            # "generated_tests"; an empty value must be treated as no result.
            pytest.param(
                {"generated_tests": "", "raw_generated_tests": ""},
                id="empty-generated-tests",
            ),
            pytest.param({}, id="missing-generated-tests"),
        ],
    )
    def test_empty_generated_source_returns_none(
        self,
        response: dict[str, str],
    ) -> None:
        """Empty or missing generated_tests returns None."""
        client = make_mock_client()
        client.post.return_value = response
        result = generate_regression_tests(
            client=client,
            payload=make_payload(),
        )
        assert result is None

    def test_http_error_raises_ai_service_error(self) -> None:
        """HTTP error raises AIServiceError."""
        client = make_mock_client()
        client.post.side_effect = AIServiceError(500, "Internal Server Error")
        with pytest.raises(AIServiceError):
            generate_regression_tests(
                client=client,
                payload=make_payload(),
            )

    def test_connection_error_raises_connection_error(self) -> None:
        """Connection error raises AIServiceConnectionError."""
        client = make_mock_client()
        client.post.side_effect = AIServiceConnectionError(
            "Connection refused",
        )
        with pytest.raises(AIServiceConnectionError):
            generate_regression_tests(
                client=client,
                payload=make_payload(),
            )
class TestReviewGeneratedTests:
    """review_generated_tests AI service call."""

    def test_successful_response(self) -> None:
        """A well-formed reply yields its list of per-test review dicts."""
        flagged_function = {"function_name": "test_foo", "reason": "bad"}
        single_review = {"test_index": 0, "functions": [flagged_function]}
        client = make_mock_client()
        client.post.return_value = {"reviews": [single_review]}

        reviews = review_generated_tests(client, {"tests": [], "trace_id": "t1"})

        assert 1 == len(reviews)
        assert 0 == reviews[0]["test_index"]

    def test_failure_returns_empty_list(self) -> None:
        """An AIServiceError from the client degrades to an empty list."""
        client = make_mock_client()
        client.post.side_effect = AIServiceError(500, "fail")
        assert [] == review_generated_tests(client, {"trace_id": "t1"})
class TestRepairGeneratedTests:
    """repair_generated_tests AI service call."""

    def test_successful_response(self) -> None:
        """The repaired source is taken from the generated_tests field."""
        client = make_mock_client()
        client.post.return_value = {"generated_tests": "fixed tests"}
        repaired = repair_generated_tests(
            client, {"test_source": "x", "trace_id": "t1"}
        )
        assert "fixed tests" == repaired

    def test_failure_returns_none(self) -> None:
        """An AIServiceError from the client degrades to None."""
        client = make_mock_client()
        client.post.side_effect = AIServiceError(500, "fail")
        assert repair_generated_tests(client, {"trace_id": "t1"}) is None

    def test_empty_generated_returns_none(self) -> None:
        """Empty generated_tests yields None even if instrumented fields exist."""
        client = make_mock_client()
        # The instrumented_* fields are present but must not be used.
        client.post.return_value = {
            "generated_tests": "",
            "instrumented_behavior_tests": "beh",
            "instrumented_perf_tests": "perf",
        }
        assert repair_generated_tests(client, {"trace_id": "t1"}) is None
class TestGenerateTests:
    """generate_tests orchestration."""

    @staticmethod
    def _invoke(test_path: Path, test_perf_path: Path, test_module_path: str):
        """Call generate_tests with fixed arguments and the given paths."""
        return generate_tests(
            client=make_mock_client(),
            source_code_being_tested="def foo(): pass",
            function_to_optimize=make_function(),
            helper_function_names=[],
            module_path="module",
            test_framework="pytest",
            test_timeout=30,
            trace_id="trace-123",
            test_index=0,
            test_path=test_path,
            test_perf_path=test_perf_path,
            test_module_path=test_module_path,
            language_version="3.12.0",
        )

    @patch("codeflash_python.testing._testgen.generate_regression_tests")
    def test_successful_flow(
        self,
        mock_regression: MagicMock,
        tmp_path: Path,
    ) -> None:
        """A mocked AI reply is passed back along with both test paths."""
        mock_regression.return_value = ("generated", "raw")
        tests_dir = tmp_path / "tests"
        behavior_path = tests_dir / "test_behavior.py"
        perf_path = tests_dir / "test_perf.py"

        result = self._invoke(behavior_path, perf_path, "tests.test_behavior")

        assert result is not None
        # Unpacking also enforces that exactly four items come back.
        generated, raw, returned_behavior, returned_perf = result
        assert "generated" == generated
        assert "raw" == raw
        assert behavior_path == returned_behavior
        assert perf_path == returned_perf

    @patch("codeflash_python.testing._testgen.generate_regression_tests")
    def test_none_from_api_returns_none(
        self,
        mock_regression: MagicMock,
        tmp_path: Path,
    ) -> None:
        """When the regression call yields None, so does generate_tests."""
        mock_regression.return_value = None
        result = self._invoke(
            tmp_path / "test_b.py",
            tmp_path / "test_p.py",
            "test_b",
        )
        assert result is None