Add unit tests for _candidate_gen (generate, repair, refinement)

Covers happy paths and error paths for generate_candidates,
repair_failed_candidates, and generate_refinement_candidates.
Tests AI service errors, unparseable markdown, missing runtime
data, and repair failures.
This commit is contained in:
Kevin Turcios 2026-04-23 02:23:52 -05:00
parent 815eba00c0
commit cf7cf60936

View file

@@ -0,0 +1,376 @@
"""Tests for candidate generation strategies."""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from codeflash_core import Candidate
from codeflash_python.pipeline._candidate_gen import (
generate_candidates,
generate_refinement_candidates,
repair_failed_candidates,
)
def _mock_ctx() -> MagicMock:
ctx = MagicMock()
ctx.plugin.language_id = "python"
ctx.project_root = MagicMock()
ctx.test_cfg.pytest_cmd = "pytest"
return ctx
def _mock_fn_input() -> MagicMock:
fn_input = MagicMock()
fn_input.function.qualified_name = "mod.func"
fn_input.function.is_async = False
fn_input.source_code = "def func(): pass"
return fn_input
def _mock_code_context() -> MagicMock:
code_context = MagicMock()
code_context.read_writable_code.markdown = (
"```python\ndef func(): pass\n```"
)
code_context.read_only = ""
code_context.helper_functions = []
return code_context
class TestGenerateCandidates:
    """Behavioral tests for generate_candidates: parsing, errors, and request fields."""

    def test_returns_parsed_candidates(self) -> None:
        """Markdown-fenced code is parsed into plain Python."""
        context = _mock_ctx()
        context.ai_client.get_candidates.return_value = [
            Candidate(
                code="```python\ndef func(): return 1\n```",
                explanation="optimized",
                candidate_id="c1",
            ),
        ]

        candidates = generate_candidates(
            ctx=context,
            function_trace_id="trace-1",
            fn_input=_mock_fn_input(),
            code_context=_mock_code_context(),
        )

        assert len(candidates) == 1
        assert candidates[0].code.strip() == "def func(): return 1"
        assert candidates[0].candidate_id == "c1"

    def test_skips_unparseable_candidates(self) -> None:
        """Candidates with no parseable code blocks are dropped."""
        context = _mock_ctx()
        context.ai_client.get_candidates.return_value = [
            Candidate(
                code="no code fences here",
                explanation="bad",
                candidate_id="c1",
            ),
        ]

        candidates = generate_candidates(
            ctx=context,
            function_trace_id="trace-1",
            fn_input=_mock_fn_input(),
            code_context=_mock_code_context(),
        )

        assert candidates == []

    def test_ai_service_error_returns_empty(self) -> None:
        """AI service exceptions are caught and return an empty list."""
        context = _mock_ctx()
        context.ai_client.get_candidates.side_effect = RuntimeError("boom")

        candidates = generate_candidates(
            ctx=context,
            function_trace_id="trace-1",
            fn_input=_mock_fn_input(),
            code_context=_mock_code_context(),
        )

        assert candidates == []

    def test_baseline_runtime_forwarded(self) -> None:
        """When baseline is provided, runtime and loop count are sent."""
        context = _mock_ctx()
        context.ai_client.get_candidates.return_value = []
        mock_baseline = MagicMock()
        mock_baseline.runtime = 50000
        mock_baseline.benchmarking_test_results.number_of_loops.return_value = 10

        generate_candidates(
            ctx=context,
            function_trace_id="trace-1",
            fn_input=_mock_fn_input(),
            code_context=_mock_code_context(),
            baseline=mock_baseline,
        )

        # Inspect the request object passed positionally to the AI client.
        sent_request = context.ai_client.get_candidates.call_args[0][0]
        assert sent_request.baseline_runtime_ns == 50000
        assert sent_request.loop_count == 10

    def test_no_baseline_sends_none(self) -> None:
        """Without baseline, runtime fields are None."""
        context = _mock_ctx()
        context.ai_client.get_candidates.return_value = []

        generate_candidates(
            ctx=context,
            function_trace_id="trace-1",
            fn_input=_mock_fn_input(),
            code_context=_mock_code_context(),
        )

        sent_request = context.ai_client.get_candidates.call_args[0][0]
        assert sent_request.baseline_runtime_ns is None
        assert sent_request.loop_count is None

    def test_multiple_candidates_all_parsed(self) -> None:
        """Multiple valid candidates are all returned."""
        context = _mock_ctx()
        context.ai_client.get_candidates.return_value = [
            Candidate(
                code="```python\ndef func(): return 1\n```",
                explanation="v1",
                candidate_id="c1",
            ),
            Candidate(
                code="```python\ndef func(): return 2\n```",
                explanation="v2",
                candidate_id="c2",
            ),
        ]

        candidates = generate_candidates(
            ctx=context,
            function_trace_id="trace-1",
            fn_input=_mock_fn_input(),
            code_context=_mock_code_context(),
        )

        assert len(candidates) == 2
        assert candidates[0].candidate_id == "c1"
        assert candidates[1].candidate_id == "c2"
class TestRepairFailedCandidates:
    """Behavioral tests for repair_failed_candidates: happy path and skip conditions."""

    def test_empty_diffs_returns_empty(self) -> None:
        """No failed diffs means nothing to repair."""
        repaired_list = repair_failed_candidates(
            ai_client=MagicMock(),
            function_trace_id="trace-1",
            failed_candidate_diffs={},
            failed_candidate_code={},
            fn_input=_mock_fn_input(),
        )
        assert repaired_list == []

    @patch("codeflash_python.ai._refinement.code_repair")
    def test_repairs_failed_candidate(
        self,
        mock_repair: MagicMock,
    ) -> None:
        """A failed candidate is sent for repair and returned."""
        fixed_candidate = Candidate(
            code="def func(): return 42",
            explanation="fixed",
            candidate_id="c1-repaired",
        )
        mock_repair.return_value = fixed_candidate

        repaired_list = repair_failed_candidates(
            ai_client=MagicMock(),
            function_trace_id="trace-1",
            failed_candidate_diffs={"c1": ["diff1"]},
            failed_candidate_code={"c1": "def func(): return bad"},
            fn_input=_mock_fn_input(),
        )

        assert len(repaired_list) == 1
        assert repaired_list[0] is fixed_candidate

    @patch("codeflash_python.ai._refinement.code_repair")
    def test_repair_exception_skips_candidate(
        self,
        mock_repair: MagicMock,
    ) -> None:
        """When repair raises, the candidate is skipped."""
        mock_repair.side_effect = RuntimeError("repair failed")

        repaired_list = repair_failed_candidates(
            ai_client=MagicMock(),
            function_trace_id="trace-1",
            failed_candidate_diffs={"c1": ["diff1"]},
            failed_candidate_code={"c1": "def func(): pass"},
            fn_input=_mock_fn_input(),
        )

        assert repaired_list == []

    @patch("codeflash_python.ai._refinement.code_repair")
    def test_repair_returns_none_skips(
        self,
        mock_repair: MagicMock,
    ) -> None:
        """When repair returns None, the candidate is dropped."""
        mock_repair.return_value = None

        repaired_list = repair_failed_candidates(
            ai_client=MagicMock(),
            function_trace_id="trace-1",
            failed_candidate_diffs={"c1": ["diff1"]},
            failed_candidate_code={"c1": "def func(): pass"},
            fn_input=_mock_fn_input(),
        )

        assert repaired_list == []

    def test_missing_code_skips_candidate(self) -> None:
        """If candidate code is missing from the dict, it's skipped."""
        repaired_list = repair_failed_candidates(
            ai_client=MagicMock(),
            function_trace_id="trace-1",
            failed_candidate_diffs={"c1": ["diff1"]},
            failed_candidate_code={},
            fn_input=_mock_fn_input(),
        )
        assert repaired_list == []
class TestGenerateRefinementCandidates:
    """Behavioral tests for generate_refinement_candidates: happy path, errors, skips."""

    def test_empty_valid_returns_empty(self) -> None:
        """No valid candidates means nothing to refine."""
        mock_eval_ctx = MagicMock()

        refined_list = generate_refinement_candidates(
            ai_client=MagicMock(),
            function_trace_id="trace-1",
            baseline_lp_markdown="",
            valid=[],
            eval_ctx=mock_eval_ctx,
            fn_input=_mock_fn_input(),
            baseline=MagicMock(runtime=100000),
            code_context=_mock_code_context(),
        )

        assert refined_list == []

    @patch(
        "codeflash_python.ai._refinement.optimize_code_refinement",
    )
    def test_refinement_returns_candidates(
        self,
        mock_refine: MagicMock,
    ) -> None:
        """Valid candidates produce refinement requests."""
        expected_refined = [
            Candidate(
                code="def func(): return 99",
                explanation="refined",
                candidate_id="r1",
            ),
        ]
        mock_refine.return_value = expected_refined
        valid_candidates = [
            Candidate(
                code="def func(): return 1",
                explanation="orig",
                candidate_id="c1",
            ),
        ]
        mock_eval_ctx = MagicMock()
        mock_eval_ctx.optimized_runtimes = {"c1": 25000}
        mock_eval_ctx.speedup_ratios = {"c1": 4.0}

        refined_list = generate_refinement_candidates(
            ai_client=MagicMock(),
            function_trace_id="trace-1",
            baseline_lp_markdown="line profile data",
            valid=valid_candidates,
            eval_ctx=mock_eval_ctx,
            fn_input=_mock_fn_input(),
            baseline=MagicMock(runtime=100000),
            code_context=_mock_code_context(),
        )

        assert refined_list == expected_refined

    @patch(
        "codeflash_python.ai._refinement.optimize_code_refinement",
    )
    def test_refinement_exception_returns_empty(
        self,
        mock_refine: MagicMock,
    ) -> None:
        """Refinement exceptions are caught."""
        mock_refine.side_effect = RuntimeError("API error")
        valid_candidates = [
            Candidate(
                code="def func(): return 1",
                explanation="orig",
                candidate_id="c1",
            ),
        ]
        mock_eval_ctx = MagicMock()
        mock_eval_ctx.optimized_runtimes = {"c1": 25000}
        mock_eval_ctx.speedup_ratios = {"c1": 4.0}

        refined_list = generate_refinement_candidates(
            ai_client=MagicMock(),
            function_trace_id="trace-1",
            baseline_lp_markdown="",
            valid=valid_candidates,
            eval_ctx=mock_eval_ctx,
            fn_input=_mock_fn_input(),
            baseline=MagicMock(runtime=100000),
            code_context=_mock_code_context(),
        )

        assert refined_list == []

    def test_missing_runtime_skips_candidate(self) -> None:
        """Candidates without runtime data are skipped."""
        valid_candidates = [
            Candidate(
                code="def func(): return 1",
                explanation="orig",
                candidate_id="c1",
            ),
        ]
        mock_eval_ctx = MagicMock()
        mock_eval_ctx.optimized_runtimes = {}
        mock_eval_ctx.speedup_ratios = {}

        with patch(
            "codeflash_python.ai._refinement.optimize_code_refinement",
        ) as mock_refine:
            refined_list = generate_refinement_candidates(
                ai_client=MagicMock(),
                function_trace_id="trace-1",
                baseline_lp_markdown="",
                valid=valid_candidates,
                eval_ctx=mock_eval_ctx,
                fn_input=_mock_fn_input(),
                baseline=MagicMock(runtime=100000),
                code_context=_mock_code_context(),
            )
            # The AI refinement call must never happen without runtime data.
            mock_refine.assert_not_called()

        assert refined_list == []