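"""Tests for codeflash_python's verification layer.

Covers compare_test_results, performance_gain, the TestDiff and
OptimizedCandidateResult models, and TestResults invocation-id helpers.
"""
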
from __future__ import annotations

from pathlib import Path

import attrs
import pytest

from codeflash_python._model import VerificationType
from codeflash_python.test_discovery.models import TestType
from codeflash_python.testing.models import (
    FunctionTestInvocation,
    InvocationId,
    TestResults,
)
from codeflash_python.verification._verification import (
    compare_test_results,
    performance_gain,
)
from codeflash_python.verification.models import (
    OptimizedCandidateResult,
    TestDiff,
    TestDiffScope,
)


def make_invocation(  # noqa: PLR0913
    *,
    test_module: str = "test_module",
    test_class: str | None = None,
    test_function: str = "test_func",
    target_function: str = "target_func",
    iteration_id: str = "0",
    loop_index: int = 1,
    did_pass: bool = True,
    runtime: int = 1000,
    test_type: TestType = TestType.EXISTING_UNIT_TEST,
    return_value: object | None = None,
    timed_out: bool = False,
    verification_type: str | None = VerificationType.FUNCTION_CALL,
    stdout: str | None = None,
) -> FunctionTestInvocation:
    """Build a single FunctionTestInvocation."""
    return FunctionTestInvocation(
        loop_index=loop_index,
        id=InvocationId(
            test_module_path=test_module,
            test_class_name=test_class,
            test_function_name=test_function,
            function_getting_tested=target_function,
            iteration_id=iteration_id,
        ),
        file_name=Path("/fake/test.py"),
        did_pass=did_pass,
        runtime=runtime,
        test_framework="pytest",
        test_type=test_type,
        return_value=return_value,
        timed_out=timed_out,
        verification_type=verification_type,
        stdout=stdout,
    )


def make_results(*invocations: FunctionTestInvocation) -> TestResults:
    """Build TestResults from invocations."""
    results = TestResults()
    for inv in invocations:
        results.add(inv)
    return results


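# What the assertions below pin down about compare_test_results (an
# inference from this suite, not a quote of the implementation): it returns
# a (match, diffs) tuple; empty results on either side yield (False, []);
# timed-out original invocations are skipped; and pass/fail, return-value,
# and stdout mismatches each contribute a TestDiff to diffs.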
class TestCompareTestResults:
    """Behavioral-equivalence comparison via compare_test_results."""

    def test_matching_results(self) -> None:
        """Identical results return (True, [])."""
        inv = make_invocation(return_value=42)
        original = make_results(inv)
        candidate = make_results(inv)

        match, diffs = compare_test_results(original, candidate)

        assert match is True
        assert [] == diffs

    def test_empty_original_returns_false(self) -> None:
        """Empty original results return (False, [])."""
        original = TestResults()
        candidate = make_results(make_invocation())

        match, diffs = compare_test_results(original, candidate)

        assert match is False
        assert [] == diffs

    def test_empty_candidate_returns_false(self) -> None:
        """Empty candidate results return (False, [])."""
        original = make_results(make_invocation())
        candidate = TestResults()

        match, diffs = compare_test_results(original, candidate)

        assert match is False
        assert [] == diffs

    def test_both_empty_returns_false(self) -> None:
        """Two empty result sets return (False, [])."""
        match, diffs = compare_test_results(TestResults(), TestResults())

        assert match is False
        assert [] == diffs

    def test_pass_fail_mismatch(self) -> None:
        """A passing original with a failing candidate produces a DID_PASS diff."""
        original = make_results(make_invocation(did_pass=True, return_value=42))
        candidate = make_results(make_invocation(did_pass=False, return_value=42))

        match, diffs = compare_test_results(original, candidate)

        assert match is False
        assert 1 == len(diffs)
        assert TestDiffScope.DID_PASS == diffs[0].scope
        assert diffs[0].original_pass is True
        assert diffs[0].candidate_pass is False

    def test_return_value_mismatch(self) -> None:
        """Different return values produce a RETURN_VALUE diff."""
        original = make_results(make_invocation(did_pass=True, return_value=42))
        candidate = make_results(make_invocation(did_pass=True, return_value=99))

        match, diffs = compare_test_results(original, candidate)

        assert match is False
        assert 1 == len(diffs)
        assert TestDiffScope.RETURN_VALUE == diffs[0].scope

    def test_stdout_mismatch(self) -> None:
        """Equal return values but different stdout produce a STDOUT diff."""
        original = make_results(
            make_invocation(did_pass=True, return_value=None, stdout="hello"),
        )
        candidate = make_results(
            make_invocation(did_pass=True, return_value=None, stdout="goodbye"),
        )

        match, diffs = compare_test_results(original, candidate)

        assert match is False
        assert 1 == len(diffs)
        assert TestDiffScope.STDOUT == diffs[0].scope

    def test_pass_fail_only_skips_return_values(self) -> None:
        """When pass_fail_only=True, return-value diffs are ignored."""
        original = make_results(make_invocation(did_pass=True, return_value=42))
        candidate = make_results(make_invocation(did_pass=True, return_value=99))

        match, diffs = compare_test_results(original, candidate, pass_fail_only=True)

        assert match is True
        assert [] == diffs

    def test_timed_out_tests_are_skipped(self) -> None:
        """Timed-out original tests are not compared."""
        original = make_results(
            make_invocation(timed_out=True, return_value=42),
            make_invocation(
                test_function="test_ok",
                timed_out=False,
                return_value=10,
            ),
        )
        candidate = make_results(
            make_invocation(return_value=99),
            make_invocation(test_function="test_ok", return_value=10),
        )

        match, diffs = compare_test_results(original, candidate)

        assert match is True
        assert [] == diffs

    def test_all_timed_out_returns_false(self) -> None:
        """If every original test timed out, returns (False, [])."""
        original = make_results(make_invocation(timed_out=True))
        candidate = make_results(make_invocation(return_value=99))

        match, _diffs = compare_test_results(original, candidate)

        assert match is False

    def test_candidate_extra_results_ignored(self) -> None:
        """Extra candidate test IDs absent from the original do not break the match."""
        original = make_results(make_invocation(return_value=42))
        candidate = make_results(
            make_invocation(return_value=42),
            make_invocation(test_function="test_extra", return_value=100),
        )

        match, diffs = compare_test_results(original, candidate)

        assert match is True
        assert [] == diffs

    def test_helper_init_state_missing_in_candidate_ok(self) -> None:
        """An INIT_STATE_HELPER invocation missing from the candidate is accepted."""
        original = make_results(
            make_invocation(
                verification_type=VerificationType.INIT_STATE_HELPER,
                return_value=42,
                test_function="test_init_helper",
            ),
            make_invocation(test_function="test_normal", return_value=10),
        )
        candidate = make_results(
            make_invocation(test_function="test_normal", return_value=10),
        )

        match, diffs = compare_test_results(original, candidate)

        assert match is True
        assert [] == diffs

    def test_multiple_diffs_collected(self) -> None:
        """Multiple mismatches produce multiple TestDiff entries."""
        original = make_results(
            make_invocation(test_function="test_a", did_pass=True, return_value=1),
            make_invocation(test_function="test_b", did_pass=True, return_value=2),
        )
        candidate = make_results(
            make_invocation(test_function="test_a", did_pass=False, return_value=1),
            make_invocation(test_function="test_b", did_pass=True, return_value=999),
        )

        match, diffs = compare_test_results(original, candidate)

        assert match is False
        assert 2 == len(diffs)
        scopes = {d.scope for d in diffs}
        assert TestDiffScope.DID_PASS in scopes
        assert TestDiffScope.RETURN_VALUE in scopes


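# The speedup formula implied by the assertions below (inferred from this
# suite, not quoted from the implementation):
#
#     gain = (original_runtime_ns - optimized_runtime_ns) / optimized_runtime_ns
#
# with gain pinned to 0.0 when optimized_runtime_ns == 0, so halving the
# runtime reads as 1.0 (100% faster) and a slowdown reads as negative.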
class TestPerformanceGain:
    """performance_gain speedup calculation."""

    def test_faster_code(self) -> None:
        """original=1000, optimized=500 gives gain=1.0 (100% faster)."""
        assert 1.0 == performance_gain(
            original_runtime_ns=1000,
            optimized_runtime_ns=500,
        )

    def test_same_speed(self) -> None:
        """original=1000, optimized=1000 gives gain=0.0."""
        assert 0.0 == performance_gain(
            original_runtime_ns=1000,
            optimized_runtime_ns=1000,
        )

    def test_slower_code(self) -> None:
        """original=500, optimized=1000 gives negative gain."""
        assert -0.5 == performance_gain(
            original_runtime_ns=500,
            optimized_runtime_ns=1000,
        )

    def test_zero_optimized_returns_zero(self) -> None:
        """optimized=0 returns gain=0.0."""
        assert 0.0 == performance_gain(
            original_runtime_ns=1000,
            optimized_runtime_ns=0,
        )

    def test_large_speedup(self) -> None:
        """original=10000, optimized=100 gives gain=99.0."""
        assert 99.0 == performance_gain(
            original_runtime_ns=10000,
            optimized_runtime_ns=100,
        )

    def test_marginal_improvement(self) -> None:
        """original=1000, optimized=999 gives a small positive gain."""
        result = performance_gain(
            original_runtime_ns=1000,
            optimized_runtime_ns=999,
        )
        assert result > 0.0
        assert result < 0.01


class TestTestDiffScope:
    """TestDiffScope enum values."""

    def test_values(self) -> None:
        """The three enum members exist with the expected string values."""
        assert "return_value" == TestDiffScope.RETURN_VALUE.value
        assert "stdout" == TestDiffScope.STDOUT.value
        assert "did_pass" == TestDiffScope.DID_PASS.value


class TestTestDiff:
    """TestDiff frozen data class."""

    def test_construction(self) -> None:
        """Can be constructed with all fields."""
        diff = TestDiff(
            scope=TestDiffScope.RETURN_VALUE,
            original_pass=True,
            candidate_pass=True,
            original_value="42",
            candidate_value="99",
            test_src_code="def test_foo(): ...",
            candidate_pytest_error="AssertionError",
            original_pytest_error=None,
        )

        assert TestDiffScope.RETURN_VALUE == diff.scope
        assert diff.original_pass is True
        assert diff.candidate_pass is True
        assert "42" == diff.original_value
        assert "99" == diff.candidate_value
        assert "def test_foo(): ..." == diff.test_src_code
        assert "AssertionError" == diff.candidate_pytest_error
        assert diff.original_pytest_error is None

    def test_frozen(self) -> None:
        """Attribute assignment raises FrozenInstanceError."""
        diff = TestDiff(
            scope=TestDiffScope.DID_PASS,
            original_pass=True,
            candidate_pass=False,
        )

        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            diff.scope = TestDiffScope.STDOUT  # type: ignore[misc]

    def test_default_none_fields(self) -> None:
        """Optional fields default to None."""
        diff = TestDiff(
            scope=TestDiffScope.STDOUT,
            original_pass=True,
            candidate_pass=True,
        )

        assert diff.original_value is None
        assert diff.candidate_value is None
        assert diff.test_src_code is None
        assert diff.candidate_pytest_error is None
        assert diff.original_pytest_error is None


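# A minimal sketch of the TestDiff shape these tests assume (field names
# and defaults come from the assertions above; the decorator and the
# annotations are inferred, not quoted from
# codeflash_python.verification.models):
#
#     @attrs.frozen
#     class TestDiff:
#         scope: TestDiffScope
#         original_pass: bool
#         candidate_pass: bool
#         original_value: str | None = None
#         candidate_value: str | None = None
#         test_src_code: str | None = None
#         candidate_pytest_error: str | None = None
#         original_pytest_error: str | None = None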
class TestOptimizedCandidateResult:
    """OptimizedCandidateResult frozen data class."""

    def test_construction(self) -> None:
        """Can be constructed with all required fields."""
        behavior = TestResults()
        benchmarking = TestResults()

        result = OptimizedCandidateResult(
            max_loop_count=5,
            best_test_runtime=1000,
            behavior_test_results=behavior,
            benchmarking_test_results=benchmarking,
            optimization_candidate_index=0,
            total_candidate_timing=5000,
        )

        assert 5 == result.max_loop_count
        assert 1000 == result.best_test_runtime

    def test_frozen(self) -> None:
        """Attribute assignment raises FrozenInstanceError."""
        result = OptimizedCandidateResult(
            max_loop_count=1,
            best_test_runtime=100,
            behavior_test_results=TestResults(),
            benchmarking_test_results=TestResults(),
            optimization_candidate_index=0,
            total_candidate_timing=100,
        )

        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            result.max_loop_count = 99  # type: ignore[misc]

    def test_field_access(self) -> None:
        """All fields are accessible and hold what was passed in."""
        behavior = make_results(make_invocation())
        benchmarking = TestResults()

        result = OptimizedCandidateResult(
            max_loop_count=10,
            best_test_runtime=500,
            behavior_test_results=behavior,
            benchmarking_test_results=benchmarking,
            optimization_candidate_index=2,
            total_candidate_timing=3000,
        )

        assert 10 == result.max_loop_count
        assert 500 == result.best_test_runtime
        assert result.behavior_test_results is behavior
        assert result.benchmarking_test_results is benchmarking
        assert 2 == result.optimization_candidate_index
        assert 3000 == result.total_candidate_timing


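# unique_invocation_loop_id is read here as a single hashable key combining
# an invocation's id with its loop_index; that reading is inferred from the
# property name and the set assertions below, not from the model code.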
class TestGetAllUniqueInvocationLoopIds:
    """TestResults.get_all_unique_invocation_loop_ids."""

    def test_returns_set_of_ids(self) -> None:
        """Returns the correct set of unique invocation loop ids."""
        inv_a = make_invocation(test_function="test_a")
        inv_b = make_invocation(test_function="test_b")
        results = make_results(inv_a, inv_b)

        ids = results.get_all_unique_invocation_loop_ids()

        assert 2 == len(ids)
        assert inv_a.unique_invocation_loop_id in ids
        assert inv_b.unique_invocation_loop_id in ids

    def test_empty_results(self) -> None:
        """Empty results produce an empty set."""
        results = TestResults()

        assert set() == results.get_all_unique_invocation_loop_ids()