Server-side instrumentation wrote return values to .bin files, which became corrupted when multiple pytest processes ran concurrently (records from different processes interleaved, leading to UnicodeDecodeError). Client-side instrumentation writes to SQLite, which handles concurrent access safely. The client now ignores the instrumented_behavior_tests and instrumented_perf_tests fields of the AI service response and instead instruments the plain generated_tests locally using inject_profiling_into_existing_test, the same path used for discovered existing tests.
494 lines
16 KiB
Python
494 lines
16 KiB
Python
"""Tests for _testgen — test generation and merging."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import ast
|
|
import textwrap
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import attrs
|
|
import pytest
|
|
|
|
from codeflash_core.exceptions import (
|
|
AIServiceConnectionError,
|
|
AIServiceError,
|
|
)
|
|
from codeflash_python._model import FunctionToOptimize
|
|
from codeflash_python.testing._testgen import (
|
|
GeneratedTests,
|
|
GeneratedTestsList,
|
|
ModifyInspiredTests,
|
|
TestgenPayload,
|
|
delete_multiple_if_name_main,
|
|
generate_regression_tests,
|
|
generate_tests,
|
|
merge_unit_tests,
|
|
repair_generated_tests,
|
|
review_generated_tests,
|
|
)
|
|
|
|
|
|
def make_function(
    name: str = "target_func",
    file_path: str = "module.py",
) -> FunctionToOptimize:
    """Build a minimal FunctionToOptimize fixture.

    *name* is passed as ``function_name``; *file_path* is wrapped in a
    ``Path`` before being handed to the constructor.
    """
    location = Path(file_path)
    return FunctionToOptimize(function_name=name, file_path=location)
|
|
|
|
|
|
def make_mock_client() -> MagicMock:
    """Return a fresh MagicMock to stand in for an AIClient."""
    client = MagicMock()
    return client
|
|
|
|
|
|
def make_payload(**overrides: object) -> TestgenPayload:
    """Build a TestgenPayload from stock field values.

    Any keyword argument replaces the stock value of the same name (or
    adds a new entry) before the payload is constructed.
    """
    stock: dict[str, object] = {
        "source_code_being_tested": "def foo(): pass",
        "function_to_optimize": make_function().to_dict(),
        "helper_function_names": [],
        "module_path": "module",
        "test_module_path": "test_module",
        "test_framework": "pytest",
        "test_timeout": 30,
        "trace_id": "trace-123",
        "test_index": 0,
        "language_version": "3.12.0",
    }
    merged = {**stock, **overrides}
    return TestgenPayload(**merged)  # type: ignore[arg-type]
|
|
|
|
|
|
class TestGeneratedTests:
    """Construction and immutability of GeneratedTests."""

    def test_create_with_all_fields(self, tmp_path: Path) -> None:
        """Each constructor argument can be read back from the instance."""
        behavior_path = tmp_path / "behavior.py"
        perf_path = tmp_path / "perf.py"
        bundle = GeneratedTests(
            generated_original_test_source="original",
            instrumented_behavior_test_source="behavior",
            instrumented_perf_test_source="perf",
            behavior_file_path=behavior_path,
            perf_file_path=perf_path,
            raw_generated_test_source="raw",
        )
        assert "original" == bundle.generated_original_test_source
        assert "behavior" == bundle.instrumented_behavior_test_source
        assert "perf" == bundle.instrumented_perf_test_source
        assert behavior_path == bundle.behavior_file_path
        assert perf_path == bundle.perf_file_path
        assert "raw" == bundle.raw_generated_test_source

    def test_raw_source_defaults_to_none(self, tmp_path: Path) -> None:
        """Omitting raw_generated_test_source leaves it as None."""
        bundle = GeneratedTests(
            generated_original_test_source="original",
            instrumented_behavior_test_source="behavior",
            instrumented_perf_test_source="perf",
            behavior_file_path=tmp_path / "behavior.py",
            perf_file_path=tmp_path / "perf.py",
        )
        raw_source = bundle.raw_generated_test_source
        assert raw_source is None

    def test_frozen(self, tmp_path: Path) -> None:
        """Attribute assignment is rejected with FrozenInstanceError."""
        bundle = GeneratedTests(
            generated_original_test_source="original",
            instrumented_behavior_test_source="behavior",
            instrumented_perf_test_source="perf",
            behavior_file_path=tmp_path / "behavior.py",
            perf_file_path=tmp_path / "perf.py",
        )
        expected_error = attrs.exceptions.FrozenInstanceError
        with pytest.raises(expected_error):
            bundle.generated_original_test_source = "changed"  # type: ignore[misc]
|
|
|
|
|
|
class TestGeneratedTestsList:
    """Contents of the GeneratedTestsList container."""

    def test_default_empty(self) -> None:
        """With no arguments, generated_tests is the empty tuple."""
        collection = GeneratedTestsList()
        stored = collection.generated_tests
        assert () == stored

    def test_with_items(self, tmp_path: Path) -> None:
        """A supplied tuple of GeneratedTests is kept and indexable."""
        entry = GeneratedTests(
            generated_original_test_source="orig",
            instrumented_behavior_test_source="beh",
            instrumented_perf_test_source="perf",
            behavior_file_path=tmp_path / "b.py",
            perf_file_path=tmp_path / "p.py",
        )
        collection = GeneratedTestsList(generated_tests=(entry,))
        stored = collection.generated_tests
        assert 1 == len(stored)
        assert entry is stored[0]
|
|
|
|
|
|
class TestDeleteMultipleIfNameMain:
    """delete_multiple_if_name_main AST cleanup."""

    @staticmethod
    def _name_main_blocks(body: list[ast.stmt]) -> list[ast.If]:
        """Return the ``if`` nodes in *body* whose test compares ``__name__``."""
        return [
            node
            for node in body
            if isinstance(node, ast.If)
            and isinstance(node.test, ast.Compare)
            and isinstance(node.test.left, ast.Name)
            and node.test.left.id == "__name__"
        ]

    @staticmethod
    def _assigns_name(body: list[ast.stmt], name: str) -> bool:
        """Tell whether *body* directly holds an assignment whose first target is *name*."""
        return any(
            isinstance(node, ast.Assign)
            and isinstance(node.targets[0], ast.Name)
            and node.targets[0].id == name
            for node in body
        )

    def test_zero_blocks_unchanged(self) -> None:
        """Body is unchanged when no if __name__ blocks exist."""
        code = textwrap.dedent("""\
            x = 1
            y = 2
        """)
        tree = ast.parse(code)
        # Capture the length first: the result is compared against the
        # statement count of the tree as it was before the call.
        original_len = len(tree.body)
        result = delete_multiple_if_name_main(tree)
        assert original_len == len(result.body)

    def test_one_block_unchanged(self) -> None:
        """Body is unchanged when exactly one if __name__ block exists."""
        code = textwrap.dedent("""\
            x = 1
            if __name__ == "__main__":
                pass
        """)
        tree = ast.parse(code)
        original_len = len(tree.body)
        result = delete_multiple_if_name_main(tree)
        assert original_len == len(result.body)

    def test_two_blocks_keeps_last(self) -> None:
        """First if __name__ block is removed, last is kept."""
        code = textwrap.dedent("""\
            if __name__ == "__main__":
                x = 1
            y = 2
            if __name__ == "__main__":
                z = 3
        """)
        tree = ast.parse(code)
        result = delete_multiple_if_name_main(tree)
        if_name_blocks = self._name_main_blocks(result.body)
        assert 1 == len(if_name_blocks)
        # The surviving block must be the last one (it assigns z, not x).
        assert self._assigns_name(if_name_blocks[0].body, "z")
        assert not self._assigns_name(if_name_blocks[0].body, "x")
        # The unrelated top-level statement between the blocks stays put.
        assert self._assigns_name(result.body, "y")

    def test_three_blocks_only_last_kept(self) -> None:
        """With three blocks, only the last is kept."""
        code = textwrap.dedent("""\
            if __name__ == "__main__":
                a = 1
            if __name__ == "__main__":
                b = 2
            if __name__ == "__main__":
                c = 3
        """)
        tree = ast.parse(code)
        result = delete_multiple_if_name_main(tree)
        if_name_blocks = self._name_main_blocks(result.body)
        assert 1 == len(if_name_blocks)
        # The kept block should contain c = 3
        assert self._assigns_name(if_name_blocks[0].body, "c")
|
|
|
|
|
|
class TestModifyInspiredTests:
    """Import collection performed by the ModifyInspiredTests transformer."""

    def test_extracts_import_nodes(self) -> None:
        """Visiting a module puts its ``import`` statements into import_list."""
        source = textwrap.dedent("""\
            import os
            import sys
            x = 1
        """)
        collected: list[ast.stmt] = []
        ModifyInspiredTests(collected).visit(ast.parse(source))
        assert 2 == len(collected)
        for node in collected:
            assert isinstance(node, ast.Import)

    def test_extracts_import_from_nodes(self) -> None:
        """Visiting a module puts its ``from ... import`` statements into import_list."""
        source = textwrap.dedent("""\
            from os.path import join
            from sys import argv
            x = 1
        """)
        collected: list[ast.stmt] = []
        ModifyInspiredTests(collected).visit(ast.parse(source))
        assert 2 == len(collected)
        for node in collected:
            assert isinstance(node, ast.ImportFrom)
|
|
|
|
class TestMergeUnitTests:
    """merge_unit_tests test merging."""

    # Existing test source passed as the first argument in every case below.
    _ORIGINAL = textwrap.dedent("""\
        def test_foo():
            assert True
    """)

    def test_pytest_inspired_suffix(self) -> None:
        """With pytest, generated test functions get __inspired suffix."""
        generated = textwrap.dedent("""\
            def test_foo():
                assert 1 == 1
        """)
        result = merge_unit_tests(self._ORIGINAL, generated)
        assert "__inspired" in result

    def test_imports_from_generated_prepended(self) -> None:
        """Imports from generated tests are included in merged source."""
        generated = textwrap.dedent("""\
            import math
            def test_bar():
                assert math.pi > 3
        """)
        result = merge_unit_tests(self._ORIGINAL, generated)
        assert "import math" in result

    def test_syntax_error_returns_original(self) -> None:
        """Syntax errors in generated tests return original unchanged."""
        generated = "def test_bar(\n not valid !!!"
        result = merge_unit_tests(self._ORIGINAL, generated)
        # "Unchanged" means the exact original text, not merely text that
        # still contains the original test function.
        assert self._ORIGINAL == result

    def test_empty_generated(self) -> None:
        """Empty generated tests return original or merged cleanly."""
        result = merge_unit_tests(self._ORIGINAL, "")
        assert "def test_foo" in result
|
|
|
|
|
|
class TestGenerateRegressionTests:
    """generate_regression_tests AI service call."""

    def test_successful_response(self) -> None:
        """Successful response returns tuple of generated and raw sources."""
        client = make_mock_client()
        client.post.return_value = {
            "generated_tests": "test code",
            "raw_generated_tests": "raw code",
        }

        payload = make_payload(
            helper_function_names=["helper1"],
        )
        result = generate_regression_tests(
            client=client,
            payload=payload,
        )
        assert result is not None
        assert 2 == len(result)
        assert ("test code", "raw code") == result

    def test_empty_generated_source_returns_none(self) -> None:
        """Empty generated_tests in the response returns None."""
        client = make_mock_client()
        # Use the same response keys as the successful case so that an
        # empty value, rather than a missing key, is what gets exercised.
        client.post.return_value = {
            "generated_tests": "",
            "raw_generated_tests": "",
        }

        result = generate_regression_tests(
            client=client,
            payload=make_payload(),
        )
        assert result is None

    def test_http_error_raises_ai_service_error(self) -> None:
        """HTTP error raises AIServiceError."""
        client = make_mock_client()
        client.post.side_effect = AIServiceError(500, "Internal Server Error")

        with pytest.raises(AIServiceError):
            generate_regression_tests(
                client=client,
                payload=make_payload(),
            )

    def test_connection_error_raises_connection_error(self) -> None:
        """Connection error raises AIServiceConnectionError."""
        client = make_mock_client()
        client.post.side_effect = AIServiceConnectionError(
            "Connection refused",
        )

        with pytest.raises(AIServiceConnectionError):
            generate_regression_tests(
                client=client,
                payload=make_payload(),
            )
|
|
|
|
|
|
class TestReviewGeneratedTests:
    """Results of review_generated_tests for good and failing service calls."""

    def test_successful_response(self) -> None:
        """The ``reviews`` list from the response is what the caller gets back."""
        flagged_function = {"function_name": "test_foo", "reason": "bad"}
        review_entry = {
            "test_index": 0,
            "functions": [
                flagged_function,
            ],
        }
        client = make_mock_client()
        client.post.return_value = {
            "reviews": [
                review_entry,
            ],
        }
        request = {"tests": [], "trace_id": "t1"}

        reviews = review_generated_tests(client, request)

        assert 1 == len(reviews)
        first_review = reviews[0]
        assert 0 == first_review["test_index"]

    def test_failure_returns_empty_list(self) -> None:
        """An AIServiceError from the client yields an empty list."""
        client = make_mock_client()
        client.post.side_effect = AIServiceError(500, "fail")
        request = {"trace_id": "t1"}

        reviews = review_generated_tests(client, request)

        assert [] == reviews
|
|
|
|
|
|
class TestRepairGeneratedTests:
    """Results of repair_generated_tests for good, failing and empty responses."""

    def test_successful_response(self) -> None:
        """The ``generated_tests`` value of the response is returned."""
        client = make_mock_client()
        client.post.return_value = {
            "generated_tests": "fixed tests",
        }
        request = {"test_source": "x", "trace_id": "t1"}

        repaired = repair_generated_tests(client, request)

        assert "fixed tests" == repaired

    def test_failure_returns_none(self) -> None:
        """An AIServiceError from the client yields None."""
        client = make_mock_client()
        client.post.side_effect = AIServiceError(500, "fail")
        request = {"trace_id": "t1"}

        repaired = repair_generated_tests(client, request)

        assert repaired is None

    def test_empty_generated_returns_none(self) -> None:
        """An empty ``generated_tests`` yields None even with other keys set."""
        client = make_mock_client()
        client.post.return_value = {
            "generated_tests": "",
            "instrumented_behavior_tests": "beh",
            "instrumented_perf_tests": "perf",
        }
        request = {"trace_id": "t1"}

        repaired = repair_generated_tests(client, request)

        assert repaired is None
|
|
|
|
|
|
class TestGenerateTests:
    """generate_tests with generate_regression_tests patched out."""

    @patch("codeflash_python.testing._testgen.generate_regression_tests")
    def test_successful_flow(
        self,
        mock_regression: MagicMock,
        tmp_path: Path,
    ) -> None:
        """A (generated, raw) pair comes back with the two paths appended."""
        mock_regression.return_value = ("generated", "raw")
        tests_dir = tmp_path / "tests"
        test_path = tests_dir / "test_behavior.py"
        test_perf_path = tests_dir / "test_perf.py"

        result = generate_tests(
            client=make_mock_client(),
            source_code_being_tested="def foo(): pass",
            function_to_optimize=make_function(),
            helper_function_names=[],
            module_path="module",
            test_framework="pytest",
            test_timeout=30,
            trace_id="trace-123",
            test_index=0,
            test_path=test_path,
            test_perf_path=test_perf_path,
            test_module_path="tests.test_behavior",
            language_version="3.12.0",
        )

        assert result is not None
        assert 4 == len(result)
        generated, raw, behavior_path, perf_path = result
        assert "generated" == generated
        assert "raw" == raw
        assert test_path == behavior_path
        assert test_perf_path == perf_path

    @patch("codeflash_python.testing._testgen.generate_regression_tests")
    def test_none_from_api_returns_none(
        self,
        mock_regression: MagicMock,
        tmp_path: Path,
    ) -> None:
        """When the patched regression call gives None, so does generate_tests."""
        mock_regression.return_value = None
        behavior_target = tmp_path / "test_b.py"
        perf_target = tmp_path / "test_p.py"

        result = generate_tests(
            client=make_mock_client(),
            source_code_being_tested="def foo(): pass",
            function_to_optimize=make_function(),
            helper_function_names=[],
            module_path="module",
            test_framework="pytest",
            test_timeout=30,
            trace_id="trace-123",
            test_index=0,
            test_path=behavior_target,
            test_perf_path=perf_target,
            test_module_path="test_b",
            language_version="3.12.0",
        )

        assert result is None
|