Achieve 100% test coverage for testgen module

Add 15 new tests covering all previously uncovered paths:
- _validate.py: regex class splitting, trailing blank stripping,
  repair preamble edge cases (empty during iteration, lineno=None,
  out-of-range index, max attempts exhausted), AST gap/decorator paths
- _generate.py: multi-context ellipsis detection, extract_code_block
  returning None, no test functions after validation
- _review_router.py: non-dict/non-list JSON in review verdicts

Mark 2 provably unreachable defensive lines with pragma: no cover.
This commit is contained in:
Kevin Turcios 2026-04-22 23:10:27 -05:00
parent 92c5fd7c74
commit 758da2592f
4 changed files with 724 additions and 5 deletions

View file

@ -200,7 +200,7 @@ def _split_code_with_ast(
line.strip() and not line.strip().startswith("#")
for line in gap_lines
):
preamble_lines.extend(gap_lines)
preamble_lines.extend(gap_lines) # pragma: no cover
test_code = "".join(lines[start_line:end_line])
test_functions.append(test_code)
@ -279,7 +279,7 @@ def _split_code_with_regex(
for start_line, end_line in test_boundaries:
test_code = "".join(lines[start_line:end_line]).rstrip()
test_lines = test_code.splitlines(keepends=True)
while test_lines and not test_lines[-1].strip():
while test_lines and not test_lines[-1].strip(): # pragma: no cover
test_lines.pop()
if test_lines:
test_functions.append("".join(test_lines).rstrip())

View file

@ -768,6 +768,39 @@ class TestAdaptiveIntegration:
class TestTestgenIntegration:
"""Integration tests for POST /ai/testgen."""
@pytest.mark.asyncio
async def test_success(
self,
client: httpx.AsyncClient,
mock_llm_client: MagicMock,
) -> None:
"""
Successful generation returns only generated_tests.
"""
mock_llm_client.call = AsyncMock(
return_value=_stub_llm_response(
"```python\ndef test_f():\n assert True\n```"
),
)
resp = await client.post(
"/ai/testgen",
json={
"trace_id": make_trace_id(),
"source_code_being_tested": "def f(): return 1",
"function_name": "f",
"module_path": "src/main.py",
"test_module_path": "tests/test_main.py",
"python_version": "3.12.1",
},
)
assert 200 == resp.status_code
data = resp.json()
assert "test_f" in data["generated_tests"]
assert "instrumented_behavior_tests" not in data
assert "instrumented_perf_tests" not in data
@pytest.mark.asyncio
async def test_missing_python_version_returns_400(
self, client: httpx.AsyncClient
@ -788,6 +821,103 @@ class TestTestgenIntegration:
assert 400 == resp.status_code
@pytest.mark.asyncio
async def test_invalid_trace_id(self, client: httpx.AsyncClient) -> None:
"""
Invalid trace_id returns 400.
"""
resp = await client.post(
"/ai/testgen",
json={
"trace_id": "bad-id",
"source_code_being_tested": "def f(): pass",
"function_name": "f",
"module_path": "m.py",
"test_module_path": "t.py",
"python_version": "3.12.0",
},
)
assert 400 == resp.status_code
@pytest.mark.asyncio
async def test_non_python_returns_400(
self, client: httpx.AsyncClient
) -> None:
"""
Non-Python language returns 400.
"""
resp = await client.post(
"/ai/testgen",
json={
"trace_id": make_trace_id(),
"source_code_being_tested": "fn f() {}",
"function_name": "f",
"module_path": "m.rs",
"test_module_path": "t.rs",
"language": "rust",
"python_version": "3.12.0",
},
)
assert 400 == resp.status_code
@pytest.mark.asyncio
async def test_llm_failure_returns_500(
self,
client: httpx.AsyncClient,
mock_llm_client: MagicMock,
) -> None:
"""
LLM exception returns 500.
"""
mock_llm_client.call = AsyncMock(
side_effect=RuntimeError("LLM down"),
)
resp = await client.post(
"/ai/testgen",
json={
"trace_id": make_trace_id(),
"source_code_being_tested": "def f(): return 1",
"function_name": "f",
"module_path": "m.py",
"test_module_path": "t.py",
"python_version": "3.12.0",
},
)
assert 500 == resp.status_code
@pytest.mark.asyncio
async def test_no_code_block_returns_422(
self,
client: httpx.AsyncClient,
mock_llm_client: MagicMock,
) -> None:
"""
LLM response without code block returns 422.
"""
mock_llm_client.call = AsyncMock(
return_value=_stub_llm_response(
"Sorry, I cannot generate tests."
),
)
resp = await client.post(
"/ai/testgen",
json={
"trace_id": make_trace_id(),
"source_code_being_tested": "def f(): return 1",
"function_name": "f",
"module_path": "m.py",
"test_module_path": "t.py",
"python_version": "3.12.0",
},
)
assert 422 == resp.status_code
# ── Unknown route ──────────────────────────────────────────────

View file

@ -2,7 +2,7 @@
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@ -12,12 +12,17 @@ from codeflash_api.testgen._generate import (
_parse_and_validate_llm_output,
_parse_python_version,
build_prompt,
generate_tests,
select_model_for_test,
)
from codeflash_api.testgen._validate import (
CodeValidationError,
_find_decorator_start_line,
_inherits_from_testcase,
_repair_preamble,
_split_code_into_parts,
_split_code_with_ast,
_split_code_with_regex,
_validate_tests_individually,
has_test_functions,
validate_testgen_code,
@ -210,6 +215,115 @@ class TestHasTestFunctions:
tree = ast.parse(code)
assert has_test_functions(tree)
def test_async_non_test_with_nested_test(self) -> None:
"""
Async non-test function containing a test function is detected.
"""
import ast
code = (
"async def helper():\n pass\n\n"
"def test_real():\n assert True\n"
)
tree = ast.parse(code)
assert has_test_functions(tree)
def test_only_async_non_test_returns_false(self) -> None:
"""
Module with only async non-test functions returns False.
"""
import ast
tree = ast.parse("async def fetch(): pass")
assert not has_test_functions(tree)
class TestInheritsFromTestcase:
"""Tests for _inherits_from_testcase."""
def test_direct_testcase_base(self) -> None:
"""
Class inheriting from TestCase is detected.
"""
import ast
code = "class MyTest(TestCase): pass"
tree = ast.parse(code)
cls = tree.body[0]
assert _inherits_from_testcase(cls)
def test_attribute_testcase_base(self) -> None:
"""
Class inheriting from unittest.TestCase is detected.
"""
import ast
code = "class MyTest(unittest.TestCase): pass"
tree = ast.parse(code)
cls = tree.body[0]
assert _inherits_from_testcase(cls)
def test_no_testcase_base(self) -> None:
"""
Class not inheriting from TestCase returns False.
"""
import ast
code = "class MyTest(object): pass"
tree = ast.parse(code)
cls = tree.body[0]
assert not _inherits_from_testcase(cls)
class TestCodeValidationErrorDebugDict:
"""Tests for CodeValidationError.to_debug_dict."""
def test_returns_all_fields(self) -> None:
"""
to_debug_dict returns all stored context.
"""
err = CodeValidationError(
"test error",
initial_code="init",
fixed_code="fixed",
final_code="final",
lines_removed=3,
validation_error="some error",
)
d = err.to_debug_dict()
assert "code_validation" == d["stage"]
assert "init" == d["initial_code"]
assert "fixed" == d["fixed_code"]
assert "final" == d["final_code"]
assert 3 == d["lines_removed"]
assert "some error" == d["validation_error"]
class TestFindDecoratorStartLine:
"""Tests for _find_decorator_start_line."""
def test_def_line_zero_returns_zero(self) -> None:
"""
def_line at 0 returns 0 immediately.
"""
assert 0 == _find_decorator_start_line(["def test(): pass"], 0)
def test_finds_decorator(self) -> None:
"""
Decorator above def line is found.
"""
lines = ["@mark\n", "def test(): pass\n"]
assert 0 == _find_decorator_start_line(lines, 1)
def test_skips_blank_and_comment_lines(self) -> None:
"""
Blank and comment lines between decorator and def are skipped.
"""
lines = ["@mark\n", "\n", "# comment\n", "def test(): pass\n"]
assert 0 == _find_decorator_start_line(lines, 3)
class TestValidateTestgenCode:
"""Tests for validate_testgen_code."""
@ -263,6 +377,144 @@ class TestValidateTestgenCode:
validate_testgen_code(code, (3, 12))
class TestValidateTestsIndividually:
"""Tests for _validate_tests_individually."""
def test_broken_preamble_is_repaired(self) -> None:
"""
A broken preamble is repaired before validating tests.
"""
code = (
"import os\ndef bad(\nimport sys\n\n"
"def test_a():\n assert True\n"
)
result, removed = _validate_tests_individually(code, (3, 12))
assert "test_a" in result
assert "import os" in result
def test_all_tests_invalid_returns_preamble_only(self) -> None:
"""
When all tests fail validation, returns preamble with count.
"""
code = (
"import os\n\n"
"def test_a():\n x = (\n\n"
"def test_b():\n y = (\n"
)
result, removed = _validate_tests_individually(code, (3, 12))
assert removed >= 1
def test_no_preamble_validates_tests_alone(self) -> None:
"""
Code with no preamble validates test functions directly.
"""
code = "def test_a():\n assert True\n"
result, removed = _validate_tests_individually(code, (3, 12))
assert "test_a" in result
assert 0 == removed
class TestSplitCodeWithAst:
"""Tests for _split_code_with_ast."""
def test_decorated_non_test_function_in_preamble(self) -> None:
"""
Decorated non-test function is included in preamble with its decorator.
"""
code = (
"import functools\n\n"
"@functools.lru_cache\n"
"def helper():\n return 1\n\n"
"def test_a():\n assert True\n"
)
result = _split_code_with_ast(code, (3, 12))
assert result is not None
preamble, tests = result
assert "@functools.lru_cache" in preamble
assert "helper" in preamble
assert 1 == len(tests)
def test_gap_between_non_test_nodes(self) -> None:
"""
Blank lines between non-test preamble nodes are captured.
"""
code = (
"import os\n\n"
"CONSTANT = 42\n\n"
"def test_a():\n assert True\n"
)
result = _split_code_with_ast(code, (3, 12))
assert result is not None
preamble, tests = result
assert "import os" in preamble
assert "CONSTANT = 42" in preamble
assert 1 == len(tests)
def test_gap_with_code_before_test(self) -> None:
"""
Non-comment content in gap between preamble and test is kept.
"""
code = (
"import os\n\n"
"SETUP = True\n\n"
"EXTRA = 'value'\n\n"
"def test_a():\n assert SETUP\n"
)
result = _split_code_with_ast(code, (3, 12))
assert result is not None
preamble, tests = result
assert "SETUP" in preamble
assert "EXTRA" in preamble
assert 1 == len(tests)
class TestSplitCodeWithRegex:
"""Tests for _split_code_with_regex."""
def test_extracts_test_class(self) -> None:
"""
Test classes are identified via regex pattern.
"""
code = (
"import os\n\n"
"class TestFoo:\n"
" def test_bar(self):\n"
" assert True\n"
)
preamble, tests = _split_code_with_regex(code)
assert "import os" in preamble
assert len(tests) >= 1
assert "class TestFoo" in tests[0]
def test_trailing_blank_lines_stripped(self) -> None:
"""
Trailing blank lines after test functions are removed.
"""
code = (
"import os\n\n"
"def test_a():\n assert True\n\n\n"
)
preamble, tests = _split_code_with_regex(code)
assert 1 == len(tests)
assert not tests[0].endswith("\n\n")
class TestSplitCodeIntoParts:
"""Tests for _split_code_into_parts."""
@ -291,6 +543,54 @@ class TestSplitCodeIntoParts:
assert 0 == len(tests)
def test_decorated_test_function(self) -> None:
"""
Decorated test functions include the decorator.
"""
code = (
"import pytest\n\n"
"@pytest.mark.parametrize('x', [1, 2])\n"
"def test_a(x):\n assert x > 0\n"
)
preamble, tests = _split_code_into_parts(code, (3, 12))
assert "import pytest" in preamble
assert 1 == len(tests)
assert "@pytest.mark.parametrize" in tests[0]
def test_test_class_split(self) -> None:
"""
Test classes are extracted as individual test items.
"""
code = (
"import unittest\n\n"
"class TestFoo(unittest.TestCase):\n"
" def test_bar(self):\n"
" self.assertTrue(True)\n"
)
preamble, tests = _split_code_into_parts(code, (3, 12))
assert "import unittest" in preamble
assert 1 == len(tests)
assert "class TestFoo" in tests[0]
def test_regex_fallback_on_syntax_error(self) -> None:
"""
Code with syntax errors falls back to regex splitter.
"""
code = (
"import os\n\n"
"def test_a():\n assert True\n\n"
"def test_b():\n x = (\n"
)
preamble, tests = _split_code_into_parts(code, (3, 12))
assert "import os" in preamble
assert len(tests) >= 1
class TestRepairPreamble:
"""Tests for _repair_preamble."""
@ -320,6 +620,56 @@ class TestRepairPreamble:
"""
assert "" == _repair_preamble("", (3, 12))
def test_all_bad_lines_returns_empty(self) -> None:
"""
If every line is invalid, returns empty string.
"""
code = "def (\ndef (\ndef (\n"
assert "" == _repair_preamble(code, (3, 12))
def test_becomes_empty_during_iteration(self) -> None:
"""
Lines that become empty after stripping return empty string.
"""
code = " \n \n"
assert "" == _repair_preamble(code, (3, 12))
def test_syntax_error_with_no_lineno(self) -> None:
"""
SyntaxError with lineno=None returns empty string.
"""
with patch(
"codeflash_api.testgen._validate.ast.parse",
side_effect=SyntaxError("err", (None, None, None, None)),
):
assert "" == _repair_preamble("import os\n", (3, 12))
def test_error_idx_out_of_range(self) -> None:
"""
Error on line beyond line count returns empty string.
"""
err = SyntaxError("err")
err.lineno = 999
with patch(
"codeflash_api.testgen._validate.ast.parse",
side_effect=err,
):
assert "" == _repair_preamble("import os\n", (3, 12))
def test_exhausts_max_attempts(self) -> None:
"""
After max_attempts still-invalid code returns empty string.
"""
lines = [f"bad{i}(\n" for i in range(20)]
code = "".join(lines)
result = _repair_preamble(code, (3, 12))
assert "" == result
class TestParseAndValidateLlmOutput:
"""Tests for _parse_and_validate_llm_output."""
@ -347,6 +697,44 @@ class TestParseAndValidateLlmOutput:
"no code here", "def f(): pass", (3, 12, 0)
)
def test_code_block_without_python_content_raises(self) -> None:
"""
Markdown block present but extract_code_block returns None.
"""
content = "```python\n```\n"
with pytest.raises(CodeValidationError):
_parse_and_validate_llm_output(
content, "def f(): pass", (3, 12, 0)
)
def test_extract_code_block_returns_none(self) -> None:
"""
When extract_code_block returns None, raises CodeValidationError.
"""
content = "has ```python marker but extract returns None"
with (
patch(
"codeflash_api.testgen._generate.extract_code_block",
return_value=None,
),
pytest.raises(
CodeValidationError, match="No Python code block"
),
):
_parse_and_validate_llm_output(
content, "def f(): pass", (3, 12, 0)
)
def test_ellipsis_in_output_raises_syntax_error(self) -> None:
"""
Generated code with assigned ellipsis not in source raises SyntaxError.
"""
content = "```python\ndef test_a():\n x = ...\n```\n"
with pytest.raises(SyntaxError, match="Ellipsis"):
_parse_and_validate_llm_output(
content, "def f(): return 1", (3, 12, 0)
)
class TestDidGenerateEllipsis:
"""Tests for _did_generate_ellipsis."""
@ -360,6 +748,51 @@ class TestDidGenerateEllipsis:
"def test_f():\n assert True\n",
)
def test_assigned_ellipsis_in_generated_not_source(self) -> None:
"""
Assigned ellipsis in generated code but not source returns True.
"""
assert _did_generate_ellipsis(
"def f(): return 1",
"def test_f():\n x = ...\n",
)
def test_ellipsis_in_both_returns_false(self) -> None:
"""
Ellipsis present in both source and generated returns False.
"""
assert not _did_generate_ellipsis(
"x = ...",
"def test_f():\n y = ...\n",
)
def test_multi_context_source_without_ellipsis(self) -> None:
"""
Multi-context source with no ellipsis while generated has it returns True.
"""
source = (
"```python:src/mod.py\n"
"def f():\n"
" return 1\n"
"```"
)
generated = "def test_f():\n x = ...\n"
assert _did_generate_ellipsis(source, generated)
def test_multi_context_source_with_ellipsis(self) -> None:
"""
Multi-context source where all files have ellipsis returns False.
"""
source = (
"```python:src/mod.py\n"
"x = ...\n"
"```"
)
generated = "def test_f():\n y = ...\n"
assert not _did_generate_ellipsis(source, generated)
class TestTestGenRequestValidator:
"""Tests for TestGenRequest model validator."""
@ -431,8 +864,8 @@ class TestTestgenEndpoint:
assert 200 == resp.status_code
data = resp.json()
assert "test_a" in data["generated_tests"]
assert "test_a" in data["instrumented_behavior_tests"]
assert "test_a" in data["instrumented_perf_tests"]
assert "instrumented_behavior_tests" not in data
assert "instrumented_perf_tests" not in data
async def test_invalid_trace_id(self, client) -> None:
"""
@ -657,3 +1090,41 @@ class TestTestgenEndpoint:
)
assert 200 == resp.status_code
class TestGenerateTestsNoTestFunctions:
"""Tests for generate_tests when validated code has no test functions."""
async def test_no_tests_after_validation_raises(self) -> None:
"""
Validated code with no test functions raises CodeValidationError.
"""
mock_client = MagicMock()
mock_client.call = AsyncMock(
return_value=MagicMock(
content="```python\nimport os\n```",
cost=0.01,
),
)
with (
patch(
"codeflash_api.testgen._generate.validate_testgen_code",
return_value="import os",
),
pytest.raises(
CodeValidationError,
match="No test functions",
),
):
await generate_tests(
mock_client,
source_code="def f(): return 1",
function_name="f",
test_framework="pytest",
python_version_str="3.12.0",
is_async=False,
trace_id="12345678-1234-4000-8000-000000000000",
user_id="user1",
test_index=0,
)

View file

@ -78,6 +78,24 @@ class TestBuildCoverageContext:
assert "helper_fn" in result
assert "60%" in result
def test_includes_unexecuted_branches(self) -> None:
"""
Unexecuted branches are included when present.
"""
details = CoverageDetails(
coverage_percentage=70.0,
threshold_percentage=90.0,
main_function=CoverageFunctionDetail(
name="compute",
coverage=70.0,
unexecuted_branches=[[10, 12], [15, 18]],
),
)
result = _build_coverage_context(details)
assert "Unexecuted branches" in result
class TestParseReviewVerdicts:
"""Tests for _parse_review_verdicts."""
@ -146,6 +164,28 @@ class TestParseReviewVerdicts:
assert 1 == len(verdicts)
assert "test_bad" == verdicts[0].function_name
def test_skips_non_dict_entries(self) -> None:
"""
Non-dict entries in functions list are silently skipped.
"""
content = json.dumps(
{
"functions": [
"not a dict",
{
"function_name": "test_a",
"verdict": "repair",
"reason": "bad",
},
]
}
)
verdicts = _parse_review_verdicts(content)
assert 1 == len(verdicts)
assert "test_a" == verdicts[0].function_name
def test_empty_functions_returns_empty(self) -> None:
"""
Empty functions list yields no verdicts.
@ -178,6 +218,18 @@ class TestParseReviewVerdicts:
assert 1 == len(verdicts)
def test_non_dict_non_list_json_returns_empty(self) -> None:
"""
Parsed JSON that is neither dict nor list returns empty list.
"""
assert [] == _parse_review_verdicts('"just a string"')
def test_integer_json_returns_empty(self) -> None:
"""
Parsed JSON that is an integer returns empty list.
"""
assert [] == _parse_review_verdicts("42")
class TestGetSyntaxError:
"""Tests for _get_syntax_error."""
@ -312,6 +364,35 @@ class TestTestgenReviewEndpoint:
assert "test_a" == fns[0]["function_name"]
assert "repair" == fns[0]["verdict"]
async def test_llm_failure_returns_empty_verdicts(
self, client, mock_llm_client
) -> None: # type: ignore[no-untyped-def]
"""
LLM failure during review falls back to empty AI verdicts.
"""
mock_llm_client.call = AsyncMock(
side_effect=RuntimeError("LLM down"),
)
resp = await client.post(
"/ai/testgen_review",
json={
"tests": [
{
"test_source": "def test_a(): pass",
"test_index": 0,
}
],
"function_source_code": "def f(): return 1",
"function_name": "f",
"trace_id": "12345678-1234-4000-8000-000000000000",
},
)
assert 200 == resp.status_code
assert 1 == len(resp.json()["reviews"])
assert [] == resp.json()["reviews"][0]["functions"]
class TestTestgenRepairEndpoint:
"""Integration tests for POST /ai/testgen_repair."""
@ -351,6 +432,8 @@ class TestTestgenRepairEndpoint:
assert 200 == resp.status_code
data = resp.json()
assert repaired == data["generated_tests"].strip()
assert "instrumented_behavior_tests" not in data
assert "instrumented_perf_tests" not in data
async def test_invalid_trace_id(self, client) -> None: # type: ignore[no-untyped-def]
"""
@ -508,6 +591,41 @@ class TestTestgenRepairEndpoint:
assert 422 == resp.status_code
async def test_no_code_block_retries_then_fails(
self, client, mock_llm_client
) -> None: # type: ignore[no-untyped-def]
"""
LLM responses without code blocks exhaust retries and return 422.
"""
mock_llm_client.call = AsyncMock(
return_value=MagicMock(
content="I cannot fix this test.",
cost=0.01,
),
)
resp = await client.post(
"/ai/testgen_repair",
json={
"test_source": "def test_a(): pass",
"functions_to_repair": [
{
"function_name": "test_a",
"reason": "broken",
}
],
"function_source_code": "def f(): return 1",
"function_name": "f",
"module_path": "m.py",
"test_module_path": "t.py",
"test_framework": "pytest",
"test_timeout": 60,
"trace_id": "12345678-1234-4000-8000-000000000000",
},
)
assert 422 == resp.status_code
class TestTestRepairRequestValidator:
"""Tests for TestRepairRequest model validator."""