Mirror of https://github.com/codeflash-ai/codeflash-agent.git, synced 2026-05-04 18:25:19 +00:00.
Add unit tests for _candidate_gen (generate, repair, refinement)
Covers happy paths and error paths for generate_candidates, repair_failed_candidates, and generate_refinement_candidates. Tests AI service errors, unparseable markdown, missing runtime data, and repair failures.
This commit is contained in:
parent
815eba00c0
commit
cf7cf60936
1 changed files with 376 additions and 0 deletions
376
packages/codeflash-python/tests/test_candidate_gen.py
Normal file
376
packages/codeflash-python/tests/test_candidate_gen.py
Normal file
|
|
@ -0,0 +1,376 @@
|
|||
"""Tests for candidate generation strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from codeflash_core import Candidate
|
||||
from codeflash_python.pipeline._candidate_gen import (
|
||||
generate_candidates,
|
||||
generate_refinement_candidates,
|
||||
repair_failed_candidates,
|
||||
)
|
||||
|
||||
|
||||
def _mock_ctx() -> MagicMock:
|
||||
ctx = MagicMock()
|
||||
ctx.plugin.language_id = "python"
|
||||
ctx.project_root = MagicMock()
|
||||
ctx.test_cfg.pytest_cmd = "pytest"
|
||||
return ctx
|
||||
|
||||
|
||||
def _mock_fn_input() -> MagicMock:
|
||||
fn_input = MagicMock()
|
||||
fn_input.function.qualified_name = "mod.func"
|
||||
fn_input.function.is_async = False
|
||||
fn_input.source_code = "def func(): pass"
|
||||
return fn_input
|
||||
|
||||
|
||||
def _mock_code_context() -> MagicMock:
|
||||
code_context = MagicMock()
|
||||
code_context.read_writable_code.markdown = (
|
||||
"```python\ndef func(): pass\n```"
|
||||
)
|
||||
code_context.read_only = ""
|
||||
code_context.helper_functions = []
|
||||
return code_context
|
||||
|
||||
|
||||
class TestGenerateCandidates:
    """Tests for generate_candidates."""

    def test_returns_parsed_candidates(self) -> None:
        """Markdown-fenced code is parsed into plain Python."""
        ctx = _mock_ctx()
        ctx.ai_client.get_candidates.return_value = [
            Candidate(
                code="```python\ndef func(): return 1\n```",
                explanation="optimized",
                candidate_id="c1",
            ),
        ]

        candidates = generate_candidates(
            ctx=ctx,
            function_trace_id="trace-1",
            fn_input=_mock_fn_input(),
            code_context=_mock_code_context(),
        )

        assert len(candidates) == 1
        assert candidates[0].code.strip() == "def func(): return 1"
        assert candidates[0].candidate_id == "c1"

    def test_skips_unparseable_candidates(self) -> None:
        """Candidates with no parseable code blocks are dropped."""
        ctx = _mock_ctx()
        ctx.ai_client.get_candidates.return_value = [
            Candidate(
                code="no code fences here",
                explanation="bad",
                candidate_id="c1",
            ),
        ]

        candidates = generate_candidates(
            ctx=ctx,
            function_trace_id="trace-1",
            fn_input=_mock_fn_input(),
            code_context=_mock_code_context(),
        )

        assert candidates == []

    def test_ai_service_error_returns_empty(self) -> None:
        """AI service exceptions are caught and return an empty list."""
        ctx = _mock_ctx()
        ctx.ai_client.get_candidates.side_effect = RuntimeError("boom")

        candidates = generate_candidates(
            ctx=ctx,
            function_trace_id="trace-1",
            fn_input=_mock_fn_input(),
            code_context=_mock_code_context(),
        )

        assert candidates == []

    def test_baseline_runtime_forwarded(self) -> None:
        """When baseline is provided, runtime and loop count are sent."""
        ctx = _mock_ctx()
        ctx.ai_client.get_candidates.return_value = []

        baseline = MagicMock()
        baseline.runtime = 50000
        baseline.benchmarking_test_results.number_of_loops.return_value = 10

        generate_candidates(
            ctx=ctx,
            function_trace_id="trace-1",
            fn_input=_mock_fn_input(),
            code_context=_mock_code_context(),
            baseline=baseline,
        )

        # The request object is the first positional argument of the AI call.
        sent_request = ctx.ai_client.get_candidates.call_args[0][0]
        assert sent_request.baseline_runtime_ns == 50000
        assert sent_request.loop_count == 10

    def test_no_baseline_sends_none(self) -> None:
        """Without baseline, runtime fields are None."""
        ctx = _mock_ctx()
        ctx.ai_client.get_candidates.return_value = []

        generate_candidates(
            ctx=ctx,
            function_trace_id="trace-1",
            fn_input=_mock_fn_input(),
            code_context=_mock_code_context(),
        )

        sent_request = ctx.ai_client.get_candidates.call_args[0][0]
        assert sent_request.baseline_runtime_ns is None
        assert sent_request.loop_count is None

    def test_multiple_candidates_all_parsed(self) -> None:
        """Multiple valid candidates are all returned."""
        ctx = _mock_ctx()
        ctx.ai_client.get_candidates.return_value = [
            Candidate(
                code="```python\ndef func(): return 1\n```",
                explanation="v1",
                candidate_id="c1",
            ),
            Candidate(
                code="```python\ndef func(): return 2\n```",
                explanation="v2",
                candidate_id="c2",
            ),
        ]

        candidates = generate_candidates(
            ctx=ctx,
            function_trace_id="trace-1",
            fn_input=_mock_fn_input(),
            code_context=_mock_code_context(),
        )

        assert len(candidates) == 2
        assert candidates[0].candidate_id == "c1"
        assert candidates[1].candidate_id == "c2"
|
||||
|
||||
|
||||
class TestRepairFailedCandidates:
    """Tests for repair_failed_candidates."""

    def test_empty_diffs_returns_empty(self) -> None:
        """No failed diffs means nothing to repair."""
        repaired = repair_failed_candidates(
            ai_client=MagicMock(),
            function_trace_id="trace-1",
            failed_candidate_diffs={},
            failed_candidate_code={},
            fn_input=_mock_fn_input(),
        )

        assert repaired == []

    @patch("codeflash_python.ai._refinement.code_repair")
    def test_repairs_failed_candidate(self, mock_repair: MagicMock) -> None:
        """A failed candidate is sent for repair and returned."""
        fixed_candidate = Candidate(
            code="def func(): return 42",
            explanation="fixed",
            candidate_id="c1-repaired",
        )
        mock_repair.return_value = fixed_candidate

        repaired = repair_failed_candidates(
            ai_client=MagicMock(),
            function_trace_id="trace-1",
            failed_candidate_diffs={"c1": ["diff1"]},
            failed_candidate_code={"c1": "def func(): return bad"},
            fn_input=_mock_fn_input(),
        )

        assert len(repaired) == 1
        assert repaired[0] is fixed_candidate

    @patch("codeflash_python.ai._refinement.code_repair")
    def test_repair_exception_skips_candidate(self, mock_repair: MagicMock) -> None:
        """When repair raises, the candidate is skipped."""
        mock_repair.side_effect = RuntimeError("repair failed")

        repaired = repair_failed_candidates(
            ai_client=MagicMock(),
            function_trace_id="trace-1",
            failed_candidate_diffs={"c1": ["diff1"]},
            failed_candidate_code={"c1": "def func(): pass"},
            fn_input=_mock_fn_input(),
        )

        assert repaired == []

    @patch("codeflash_python.ai._refinement.code_repair")
    def test_repair_returns_none_skips(self, mock_repair: MagicMock) -> None:
        """When repair returns None, the candidate is dropped."""
        mock_repair.return_value = None

        repaired = repair_failed_candidates(
            ai_client=MagicMock(),
            function_trace_id="trace-1",
            failed_candidate_diffs={"c1": ["diff1"]},
            failed_candidate_code={"c1": "def func(): pass"},
            fn_input=_mock_fn_input(),
        )

        assert repaired == []

    def test_missing_code_skips_candidate(self) -> None:
        """If candidate code is missing from the dict, it's skipped."""
        repaired = repair_failed_candidates(
            ai_client=MagicMock(),
            function_trace_id="trace-1",
            failed_candidate_diffs={"c1": ["diff1"]},
            failed_candidate_code={},
            fn_input=_mock_fn_input(),
        )

        assert repaired == []
|
||||
|
||||
|
||||
class TestGenerateRefinementCandidates:
    """Tests for generate_refinement_candidates."""

    def test_empty_valid_returns_empty(self) -> None:
        """No valid candidates means nothing to refine."""
        eval_ctx = MagicMock()

        refined = generate_refinement_candidates(
            ai_client=MagicMock(),
            function_trace_id="trace-1",
            baseline_lp_markdown="",
            valid=[],
            eval_ctx=eval_ctx,
            fn_input=_mock_fn_input(),
            baseline=MagicMock(runtime=100000),
            code_context=_mock_code_context(),
        )

        assert refined == []

    @patch("codeflash_python.ai._refinement.optimize_code_refinement")
    def test_refinement_returns_candidates(self, mock_refine: MagicMock) -> None:
        """Valid candidates produce refinement requests."""
        expected = [
            Candidate(
                code="def func(): return 99",
                explanation="refined",
                candidate_id="r1",
            ),
        ]
        mock_refine.return_value = expected

        survivors = [
            Candidate(
                code="def func(): return 1",
                explanation="orig",
                candidate_id="c1",
            ),
        ]
        eval_ctx = MagicMock()
        eval_ctx.optimized_runtimes = {"c1": 25000}
        eval_ctx.speedup_ratios = {"c1": 4.0}

        refined = generate_refinement_candidates(
            ai_client=MagicMock(),
            function_trace_id="trace-1",
            baseline_lp_markdown="line profile data",
            valid=survivors,
            eval_ctx=eval_ctx,
            fn_input=_mock_fn_input(),
            baseline=MagicMock(runtime=100000),
            code_context=_mock_code_context(),
        )

        assert refined == expected

    @patch("codeflash_python.ai._refinement.optimize_code_refinement")
    def test_refinement_exception_returns_empty(self, mock_refine: MagicMock) -> None:
        """Refinement exceptions are caught."""
        mock_refine.side_effect = RuntimeError("API error")

        survivors = [
            Candidate(
                code="def func(): return 1",
                explanation="orig",
                candidate_id="c1",
            ),
        ]
        eval_ctx = MagicMock()
        eval_ctx.optimized_runtimes = {"c1": 25000}
        eval_ctx.speedup_ratios = {"c1": 4.0}

        refined = generate_refinement_candidates(
            ai_client=MagicMock(),
            function_trace_id="trace-1",
            baseline_lp_markdown="",
            valid=survivors,
            eval_ctx=eval_ctx,
            fn_input=_mock_fn_input(),
            baseline=MagicMock(runtime=100000),
            code_context=_mock_code_context(),
        )

        assert refined == []

    def test_missing_runtime_skips_candidate(self) -> None:
        """Candidates without runtime data are skipped."""
        survivors = [
            Candidate(
                code="def func(): return 1",
                explanation="orig",
                candidate_id="c1",
            ),
        ]
        # No runtime or speedup entries for "c1" — nothing should be refined.
        eval_ctx = MagicMock()
        eval_ctx.optimized_runtimes = {}
        eval_ctx.speedup_ratios = {}

        with patch("codeflash_python.ai._refinement.optimize_code_refinement") as mock_refine:
            refined = generate_refinement_candidates(
                ai_client=MagicMock(),
                function_trace_id="trace-1",
                baseline_lp_markdown="",
                valid=survivors,
                eval_ctx=eval_ctx,
                fn_input=_mock_fn_input(),
                baseline=MagicMock(runtime=100000),
                code_context=_mock_code_context(),
            )

        mock_refine.assert_not_called()
        assert refined == []
|
||||
Loading…
Reference in a new issue