mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
Add integration tests for parallel candidate evaluation
Tests overlay isolation, concurrent dispatch, thread safety, exception handling, and the full evaluate_candidate_isolated flow with mocked subprocess execution.
This commit is contained in:
parent
2cea1ab784
commit
ebd239bbbc
1 changed file with 564 additions and 0 deletions
|
|
@ -0,0 +1,564 @@
|
|||
"""Integration tests for parallel candidate evaluation.
|
||||
|
||||
Exercises the full evaluation pipeline — overlay creation, candidate
|
||||
code replacement, behavioral test comparison, benchmarking, and
|
||||
concurrent dispatch — with mocked subprocess execution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
from codeflash_core import Candidate, EvaluationContext
|
||||
from codeflash_python._model import FunctionToOptimize
|
||||
from codeflash_python.pipeline._candidate_eval import (
|
||||
evaluate_candidate_isolated,
|
||||
run_tests_and_benchmark,
|
||||
)
|
||||
from codeflash_python.pipeline._function_optimizer import (
|
||||
_evaluate_batch_parallel,
|
||||
)
|
||||
from codeflash_python.testing.models import (
|
||||
FunctionTestInvocation,
|
||||
InvocationId,
|
||||
TestConfig,
|
||||
TestFile,
|
||||
TestFiles,
|
||||
TestResults,
|
||||
)
|
||||
from codeflash_python.verification.models import OriginalCodeBaseline
|
||||
|
||||
# -- Fixtures --------------------------------------------------------
|
||||
|
||||
|
||||
def _make_invocation(
    *,
    did_pass: bool = True,
    runtime: int = 1_000_000,
    loop_index: int = 0,
    test_name: str = "test_fn",
    file_name: str = "test_mod.py",
) -> FunctionTestInvocation:
    """Construct one FunctionTestInvocation (passing, 1ms by default)."""
    from codeflash_python.test_discovery.models import TestType

    # Identity of the invocation; the function under test is always
    # "target_fn" to match the fixture project's module.
    invocation_id = InvocationId(
        test_module_path=file_name,
        test_class_name=None,
        test_function_name=test_name,
        function_getting_tested="target_fn",
        iteration_id="0",
    )
    return FunctionTestInvocation(
        id=invocation_id,
        loop_index=loop_index,
        file_name=Path(file_name),
        did_pass=did_pass,
        timed_out=False,
        runtime=runtime,
        return_value=42,
        test_framework="pytest",
        test_type=TestType.EXISTING_UNIT_TEST,
    )
|
||||
|
||||
|
||||
def _make_test_results(
    *, runtime: int = 1_000_000, count: int = 1
) -> TestResults:
    """Return a TestResults containing *count* passing invocations."""
    results = TestResults()
    # Each invocation gets a distinct loop index and test name so the
    # container treats them as separate test cases.
    for idx in range(count):
        invocation = _make_invocation(
            runtime=runtime,
            loop_index=idx,
            test_name=f"test_fn_{idx}",
        )
        results.add(invocation)
    return results
|
||||
|
||||
|
||||
@pytest.fixture(name="project")
def _project(tmp_path: Path) -> tuple[Path, Path]:
    """Create a tiny src-layout project; return (root, module path)."""
    root = tmp_path / "project"
    pkg = root / "src" / "mypkg"
    # parents=True also creates root and src in one call.
    pkg.mkdir(parents=True)
    (root / "pyproject.toml").write_text("[project]\nname='demo'\n")
    (pkg / "__init__.py").write_text("")
    mod = pkg / "core.py"
    mod.write_text("def target_fn():\n return 42\n")
    return root, mod
|
||||
|
||||
|
||||
@pytest.fixture(name="fn_input")
def _fn_input(project: tuple[Path, Path]) -> Any:
    """Build a FunctionInput targeting target_fn in the project module."""
    from codeflash_python.pipeline._optimizer import FunctionInput

    _root, mod = project
    code_text = mod.read_text("utf-8")
    target = FunctionToOptimize(
        function_name="target_fn",
        file_path=mod,
    )
    # source_code and normalized_code are intentionally identical here:
    # the fixture module needs no normalization.
    return FunctionInput(
        function=target,
        module_path=mod,
        source_code=code_text,
        normalized_code=code_text,
        module_ast=ast.parse(code_text),
        validated_code={},
    )
|
||||
|
||||
|
||||
@pytest.fixture(name="baseline")
def _baseline() -> OriginalCodeBaseline:
    """An original-code baseline whose measured runtime is 1_000_000 ns."""
    # Behavioral and benchmarking results share the same 1ms runtime;
    # line profiling is left empty.
    return OriginalCodeBaseline(
        behavior_test_results=_make_test_results(runtime=1_000_000),
        benchmarking_test_results=_make_test_results(runtime=1_000_000),
        runtime=1_000_000,
        line_profile_results=TestResults(),
    )
|
||||
|
||||
|
||||
@pytest.fixture(name="ctx")
def _ctx(project: tuple[Path, Path]) -> Any:
    """A mocked OptimizationContext carrying a real TestConfig."""
    root, _mod = project
    # Only project_root and test_cfg are read by the code under test;
    # everything else on the context can stay a MagicMock attribute.
    ctx = MagicMock()
    ctx.project_root = root
    ctx.test_cfg = TestConfig(
        tests_project_rootdir=root,
        pytest_cmd="pytest",
        tests_root=str(root / "tests"),
    )
    return ctx
|
||||
|
||||
|
||||
@pytest.fixture(name="test_files")
def _test_files(project: tuple[Path, Path]) -> TestFiles:
    """TestFiles wrapping a single trivial test module written to disk."""
    root, _mod = project
    test_path = root / "tests" / "test_core.py"
    # exist_ok: other fixtures may have created tests/ already.
    test_path.parent.mkdir(exist_ok=True)
    test_path.write_text("def test_target(): pass\n")
    return TestFiles(
        test_files=[TestFile(original_file_path=test_path)],
    )
|
||||
|
||||
|
||||
# -- Tests -----------------------------------------------------------
|
||||
|
||||
|
||||
def _mock_behavioral_ok(
|
||||
baseline_results: TestResults,
|
||||
bench_runtime: int = 500_000,
|
||||
) -> tuple[Any, Any, Any]:
|
||||
"""Return mock callables for behavioral + benchmark success."""
|
||||
xml_sentinel = Path("/fake/results.xml")
|
||||
|
||||
def _run_behavioral(**kwargs: Any) -> tuple[Path, Any, Any, Any]:
|
||||
return (xml_sentinel, MagicMock(), None, None)
|
||||
|
||||
def _run_benchmarking(**kwargs: Any) -> tuple[Path, Any]:
|
||||
return (xml_sentinel, MagicMock())
|
||||
|
||||
def _parse(
|
||||
*,
|
||||
test_xml_path: Any,
|
||||
test_files: Any,
|
||||
test_config: Any,
|
||||
optimization_iteration: int,
|
||||
run_result: Any,
|
||||
) -> TestResults:
|
||||
return _make_test_results(runtime=bench_runtime)
|
||||
|
||||
return _run_behavioral, _run_benchmarking, _parse
|
||||
|
||||
|
||||
class TestEvaluateCandidateIsolated:
    """evaluate_candidate_isolated with project overlays.

    Each test stubs the subprocess-backed pipeline steps
    (run_behavioral_tests / run_benchmarking_tests / parse_test_results /
    compare_test_results) so the evaluation flow runs fully in-process.
    """

    def test_successful_candidate_records_speedup(
        self,
        project: tuple[Path, Path],
        fn_input: Any,
        baseline: OriginalCodeBaseline,
        ctx: Any,
        test_files: TestFiles,
    ) -> None:
        """A faster candidate gets a positive speedup recorded."""
        candidate = Candidate(
            code="def target_fn():\n return 42\n",
            explanation="optimized",
            candidate_id="c1",
        )
        eval_ctx = EvaluationContext()
        # Out-params that evaluate_candidate_isolated fills in.
        failed_code: dict[str, str] = {}
        failed_diffs: dict[str, list[Any]] = {}
        bench_results: dict[str, TestResults] = {}

        # Baseline runtime is 1_000_000 ns (fixture); the mocked parse
        # reports 500_000 ns, so a positive speedup is expected.
        run_beh, run_bench, parse = _mock_behavioral_ok(
            baseline.behavior_test_results,
            bench_runtime=500_000,
        )

        # Route every external step through in-process mocks; the
        # verification-module hooks are patched to no-ops as well.
        with (
            patch(
                "codeflash_python.pipeline._candidate_eval"
                ".run_behavioral_tests",
                side_effect=run_beh,
            ),
            patch(
                "codeflash_python.pipeline._candidate_eval"
                ".run_benchmarking_tests",
                side_effect=run_bench,
            ),
            patch(
                "codeflash_python.pipeline._candidate_eval"
                ".parse_test_results",
                side_effect=parse,
            ),
            patch(
                "codeflash_python.pipeline._candidate_eval"
                ".compare_test_results",
                return_value=(True, []),
            ),
            patch(
                "codeflash_python.verification._baseline"
                ".add_async_perf_decorator",
                return_value={},
            ),
            patch(
                "codeflash_python.verification._baseline"
                ".revert_async_decorator",
            ),
        ):
            speedup = evaluate_candidate_isolated(
                candidate=candidate,
                fn_input=fn_input,
                baseline=baseline,
                eval_ctx=eval_ctx,
                test_files=test_files,
                test_env={},
                ctx=ctx,
                failed_candidate_code=failed_code,
                failed_candidate_diffs=failed_diffs,
                candidate_bench_results=bench_results,
            )

        assert speedup is not None
        assert speedup > 0
        assert eval_ctx.is_correct["c1"] is True
        assert "c1" in eval_ctx.optimizations_post

    def test_original_source_unchanged(
        self,
        project: tuple[Path, Path],
        fn_input: Any,
        baseline: OriginalCodeBaseline,
        ctx: Any,
        test_files: TestFiles,
    ) -> None:
        """The original module file is never modified."""
        _, mod = project
        # Snapshot the on-disk source before evaluation.
        original_content = mod.read_text("utf-8")

        # Candidate deliberately differs from the module's source so any
        # write-through to the original file would be detectable.
        candidate = Candidate(
            code="def target_fn():\n return 99\n",
            explanation="changed",
            candidate_id="c2",
        )
        eval_ctx = EvaluationContext()

        run_beh, run_bench, parse = _mock_behavioral_ok(
            baseline.behavior_test_results,
        )

        with (
            patch(
                "codeflash_python.pipeline._candidate_eval"
                ".run_behavioral_tests",
                side_effect=run_beh,
            ),
            patch(
                "codeflash_python.pipeline._candidate_eval"
                ".run_benchmarking_tests",
                side_effect=run_bench,
            ),
            patch(
                "codeflash_python.pipeline._candidate_eval"
                ".parse_test_results",
                side_effect=parse,
            ),
            patch(
                "codeflash_python.pipeline._candidate_eval"
                ".compare_test_results",
                return_value=(True, []),
            ),
            patch(
                "codeflash_python.verification._baseline"
                ".add_async_perf_decorator",
                return_value={},
            ),
            patch(
                "codeflash_python.verification._baseline"
                ".revert_async_decorator",
            ),
        ):
            evaluate_candidate_isolated(
                candidate=candidate,
                fn_input=fn_input,
                baseline=baseline,
                eval_ctx=eval_ctx,
                test_files=test_files,
                test_env={},
                ctx=ctx,
                failed_candidate_code={},
                failed_candidate_diffs={},
                candidate_bench_results={},
            )

        assert original_content == mod.read_text("utf-8")

    def test_failed_candidate_stored(
        self,
        project: tuple[Path, Path],
        fn_input: Any,
        baseline: OriginalCodeBaseline,
        ctx: Any,
        test_files: TestFiles,
    ) -> None:
        """A failing candidate is recorded in failed_candidate_code."""
        candidate = Candidate(
            code="def target_fn():\n return 99\n",
            explanation="bad",
            candidate_id="c3",
        )
        eval_ctx = EvaluationContext()
        failed_code: dict[str, str] = {}

        run_beh, _, parse = _mock_behavioral_ok(
            baseline.behavior_test_results,
        )

        # No benchmarking/baseline patches here: with compare_test_results
        # returning (False, []), the flow is expected to stop before
        # benchmarking — confirm against the implementation if this changes.
        with (
            patch(
                "codeflash_python.pipeline._candidate_eval"
                ".run_behavioral_tests",
                side_effect=run_beh,
            ),
            patch(
                "codeflash_python.pipeline._candidate_eval"
                ".parse_test_results",
                side_effect=parse,
            ),
            patch(
                "codeflash_python.pipeline._candidate_eval"
                ".compare_test_results",
                return_value=(False, []),
            ),
        ):
            result = evaluate_candidate_isolated(
                candidate=candidate,
                fn_input=fn_input,
                baseline=baseline,
                eval_ctx=eval_ctx,
                test_files=test_files,
                test_env={},
                ctx=ctx,
                failed_candidate_code=failed_code,
                failed_candidate_diffs={},
                candidate_bench_results={},
            )

        assert result is None
        assert "c3" in failed_code
|
||||
|
||||
|
||||
class TestEvaluateBatchParallel:
    """_evaluate_batch_parallel concurrent dispatch."""

    def test_single_candidate_runs_sequentially(self) -> None:
        """One candidate does not spawn a thread pool."""
        called: list[str] = []

        def try_fn(c: Candidate) -> None:
            called.append(c.candidate_id)

        candidates = [
            Candidate(code="x", explanation="e", candidate_id="c1"),
        ]
        _evaluate_batch_parallel(candidates, try_fn)
        # Exactly one evaluation, in order.
        assert ["c1"] == called

    def test_multiple_candidates_all_evaluated(self) -> None:
        """All candidates are evaluated when given multiple."""
        evaluated: set[str] = set()
        # Guards the shared set against concurrent mutation.
        lock = threading.Lock()

        def try_fn(c: Candidate) -> None:
            with lock:
                evaluated.add(c.candidate_id)

        candidates = [
            Candidate(
                code="x",
                explanation="e",
                candidate_id=f"c{i}",
            )
            for i in range(6)
        ]
        _evaluate_batch_parallel(candidates, try_fn)
        assert {f"c{i}" for i in range(6)} == evaluated

    def test_uses_multiple_threads(self) -> None:
        """Multiple candidates run on different threads."""
        thread_ids: set[int] = set()
        lock = threading.Lock()
        # Barrier of 3 forces all candidates to be in-flight at once;
        # the timeout keeps a sequential implementation from hanging.
        barrier = threading.Barrier(3, timeout=5)

        def try_fn(c: Candidate) -> None:
            barrier.wait()
            with lock:
                thread_ids.add(threading.current_thread().ident or 0)

        candidates = [
            Candidate(
                code="x",
                explanation="e",
                candidate_id=f"c{i}",
            )
            for i in range(3)
        ]
        _evaluate_batch_parallel(candidates, try_fn, max_workers=3)
        # >= 2 (not == 3) tolerates worker reuse by the pool.
        assert len(thread_ids) >= 2

    def test_exception_in_one_does_not_block_others(self) -> None:
        """An exception in one candidate doesn't prevent others."""
        evaluated: set[str] = set()
        lock = threading.Lock()

        def try_fn(c: Candidate) -> None:
            # Only c1 blows up; the other three must still run.
            if c.candidate_id == "c1":
                msg = "boom"
                raise RuntimeError(msg)
            with lock:
                evaluated.add(c.candidate_id)

        candidates = [
            Candidate(
                code="x",
                explanation="e",
                candidate_id=f"c{i}",
            )
            for i in range(4)
        ]
        _evaluate_batch_parallel(candidates, try_fn)
        assert {"c0", "c2", "c3"} == evaluated

    def test_concurrent_overlay_isolation(
        self,
        project: tuple[Path, Path],
        fn_input: Any,
        baseline: OriginalCodeBaseline,
        ctx: Any,
        test_files: TestFiles,
    ) -> None:
        """Multiple candidates evaluated in parallel don't corrupt each other."""
        eval_ctx = EvaluationContext()
        valid: list[Candidate] = []
        diff_lengths: list[int] = []
        _lock = threading.Lock()

        # Four identical, correct candidates evaluated concurrently.
        candidates = [
            Candidate(
                code="def target_fn():\n return 42\n",
                explanation=f"opt{i}",
                candidate_id=f"c{i}",
            )
            for i in range(4)
        ]

        run_beh, run_bench, parse = _mock_behavioral_ok(
            baseline.behavior_test_results,
            bench_runtime=500_000,
        )

        def _try_candidate(c: Candidate) -> None:
            # Per-candidate out-params are fresh dicts, so only eval_ctx
            # is shared across threads.
            sp = evaluate_candidate_isolated(
                candidate=c,
                fn_input=fn_input,
                baseline=baseline,
                eval_ctx=eval_ctx,
                test_files=test_files,
                test_env={},
                ctx=ctx,
                failed_candidate_code={},
                failed_candidate_diffs={},
                candidate_bench_results={},
            )
            if sp is not None and sp > 0:
                with _lock:
                    valid.append(c)
                    diff_lengths.append(0)

        with (
            patch(
                "codeflash_python.pipeline._candidate_eval"
                ".run_behavioral_tests",
                side_effect=run_beh,
            ),
            patch(
                "codeflash_python.pipeline._candidate_eval"
                ".run_benchmarking_tests",
                side_effect=run_bench,
            ),
            patch(
                "codeflash_python.pipeline._candidate_eval"
                ".parse_test_results",
                side_effect=parse,
            ),
            patch(
                "codeflash_python.pipeline._candidate_eval"
                ".compare_test_results",
                return_value=(True, []),
            ),
            patch(
                "codeflash_python.verification._baseline"
                ".add_async_perf_decorator",
                return_value={},
            ),
            patch(
                "codeflash_python.verification._baseline"
                ".revert_async_decorator",
            ),
        ):
            _evaluate_batch_parallel(candidates, _try_candidate)

        assert 4 == len(valid)
        assert all(
            eval_ctx.is_correct[f"c{i}"] for i in range(4)
        )

        # The shared on-disk module must be untouched by all overlays.
        _, mod = project
        assert "def target_fn():\n return 42\n" == mod.read_text(
            "utf-8",
        )
|
||||
Loading…
Reference in a new issue