test: add 262 tests for previously untested core modules

- test_danom_result.py: 58 tests for Ok/Err Result monad
- test_danom_stream.py: 65 tests for Stream pipeline operations
- test_model.py: 57 tests for core data models and serialization
- test_pipeline.py: 59 tests for pipeline utilities and candidate evaluation
- test_normalizer.py: 23 tests for code normalization including SyntaxError handling
This commit is contained in:
Kevin Turcios 2026-04-24 01:36:14 -05:00
parent 90a46d732c
commit fd88580ac8
5 changed files with 2953 additions and 0 deletions

View file

@ -0,0 +1,490 @@
from __future__ import annotations
import attrs
import pytest
from codeflash_core.danom.result import Err, Ok, Result
class TestOk:
    """Behavioral tests for the Ok variant of Result."""

    def test_construction_with_value(self):
        """Ok stores the supplied value in ``inner``."""
        wrapped = Ok(42)
        assert wrapped.inner == 42

    def test_construction_default_none(self):
        """Omitting the value leaves ``inner`` as None."""
        assert Ok().inner is None

    def test_is_ok_returns_true(self):
        """is_ok reports True for an Ok."""
        assert Ok(1).is_ok() is True

    def test_unwrap_returns_inner(self):
        """unwrap hands back the wrapped value."""
        assert Ok("hello").unwrap() == "hello"

    def test_unwrap_none_value(self):
        """unwrap yields None when the Ok wraps None."""
        assert Ok(None).unwrap() is None

    def test_map_applies_func(self):
        """map transforms the inner value and re-wraps it in Ok."""
        doubled = Ok(3).map(lambda v: v * 2)
        assert doubled == Ok(6)

    def test_map_with_extra_args(self):
        """Extra positional and keyword args are forwarded by map."""
        summed = Ok(10).map(lambda v, y, z=0: v + y + z, 5, z=3)
        assert summed == Ok(18)

    def test_map_err_is_noop(self):
        """map_err leaves an Ok untouched and returns the same object."""
        original = Ok(42)
        assert original.map_err(lambda e: "transformed") is original

    def test_and_then_returns_func_result(self):
        """and_then returns whatever the callback produced."""
        chained = Ok(5).and_then(lambda v: Ok(v + 1))
        assert chained == Ok(6)

    def test_and_then_with_extra_args(self):
        """Extra positional and keyword args are forwarded by and_then."""
        chained = Ok(1).and_then(lambda v, y, z=0: Ok(v + y + z), 2, z=3)
        assert chained == Ok(6)

    def test_and_then_can_return_err(self):
        """The and_then callback may produce an Err."""
        assert Ok(5).and_then(lambda v: Err("fail")) == Err("fail")

    def test_or_else_is_noop(self):
        """or_else leaves an Ok untouched and returns the same object."""
        original = Ok(42)
        assert original.or_else(lambda e: Ok(0)) is original
class TestErr:
    """Behavioral tests for the Err variant of Result."""

    def test_construction_with_error(self):
        """Err stores the supplied error value."""
        failure = Err("bad")
        assert failure.error == "bad"

    def test_construction_default_none(self):
        """Omitting the error leaves ``error`` as None."""
        assert Err().error is None

    def test_is_ok_returns_false(self):
        """is_ok reports False for an Err."""
        assert Err("x").is_ok() is False

    def test_unwrap_raises_wrapped_exception(self):
        """unwrap re-raises a wrapped Exception instance."""
        caught = ValueError("boom")
        with pytest.raises(ValueError, match="boom"):
            Err(caught).unwrap()

    def test_unwrap_non_exception_raises_valueerror(self):
        """unwrap refuses to raise when the error is not an Exception."""
        with pytest.raises(ValueError, match="Err does not have a caught error"):
            Err("not an exception").unwrap()

    def test_unwrap_none_error_raises_valueerror(self):
        """unwrap refuses when the error is None."""
        with pytest.raises(ValueError, match="Err does not have a caught error"):
            Err(None).unwrap()

    def test_map_is_noop(self):
        """map leaves an Err untouched and returns the same object."""
        failure = Err("fail")
        assert failure.map(lambda v: v * 2) is failure

    def test_map_err_applies_func(self):
        """map_err transforms the error value."""
        assert Err("bad").map_err(lambda e: e.upper()) == Err("BAD")

    def test_map_err_with_extra_args(self):
        """Extra positional and keyword args are forwarded by map_err."""
        assert Err(10).map_err(lambda e, x, y=0: e + x + y, 5, y=3) == Err(18)

    def test_and_then_is_noop(self):
        """and_then leaves an Err untouched and returns the same object."""
        failure = Err("fail")
        assert failure.and_then(lambda v: Ok(v + 1)) is failure

    def test_or_else_applies_func(self):
        """or_else feeds the error to the callback and returns its result."""
        assert Err("fail").or_else(lambda e: Ok("recovered")) == Ok("recovered")

    def test_or_else_with_extra_args(self):
        """Extra positional and keyword args are forwarded by or_else."""
        recovered = Err("x").or_else(
            lambda e, suffix, sep="-": Ok(e + sep + suffix), "y", sep="+"
        )
        assert recovered == Ok("x+y")

    def test_or_else_can_return_err(self):
        """The or_else callback may itself produce an Err."""
        assert Err("first").or_else(lambda e: Err("second")) == Err("second")

    def test_input_args_stored(self):
        """The input_args tuple is retained on the instance."""
        captured = (("a", "b"), {"key": "val"})
        assert Err("fail", input_args=captured).input_args == captured

    def test_input_args_default_empty_tuple(self):
        """input_args falls back to an empty tuple."""
        assert Err("fail").input_args == ()

    def test_traceback_stored(self):
        """The traceback string is retained on the instance."""
        assert Err("fail", traceback="line 1\nline 2").traceback == "line 1\nline 2"

    def test_traceback_default_empty_string(self):
        """traceback falls back to an empty string."""
        assert Err("fail").traceback == ""
class TestErrEquality:
    """Tests covering Err.__eq__ and Err.__hash__."""

    def test_equal_errs_with_same_exception_type_and_message(self):
        """Matching exception type, message, and input_args compare equal."""
        assert Err(ValueError("x")) == Err(ValueError("x"))

    def test_unequal_errs_different_message(self):
        """A differing message breaks equality."""
        assert Err(ValueError("x")) != Err(ValueError("y"))

    def test_unequal_errs_different_type(self):
        """A differing exception type breaks equality."""
        assert Err(ValueError("x")) != Err(TypeError("x"))

    def test_unequal_errs_different_input_args(self):
        """Differing input_args break equality."""
        first = Err("x", input_args=(("a",), {}))
        second = Err("x", input_args=(("b",), {}))
        assert first != second

    def test_err_not_equal_to_non_err(self):
        """An Err never equals a plain non-Err value."""
        assert Err("x") != "x"
        assert Err(42) != 42

    def test_equal_errs_with_string_errors(self):
        """Identical string errors compare equal."""
        assert Err("fail") == Err("fail")

    def test_hash_consistent_with_equality(self):
        """Equal Errs hash to the same value."""
        assert hash(Err(ValueError("x"))) == hash(Err(ValueError("x")))

    def test_hash_differs_for_unequal(self):
        """Distinct error values are unlikely to collide in hash."""
        assert hash(Err("a")) != hash(Err("b"))
class TestOkEquality:
    """Tests covering the attrs-generated Ok equality and hashing."""

    def test_equal_ok_values(self):
        """Matching inner values compare equal."""
        assert Ok(42) == Ok(42)

    def test_unequal_ok_values(self):
        """Differing inner values compare unequal."""
        assert Ok(1) != Ok(2)

    def test_ok_not_equal_to_err(self):
        """An Ok never equals an Err, even over the same value."""
        assert Ok(1) != Err(1)

    def test_ok_not_equal_to_plain_value(self):
        """An Ok never equals its unwrapped value."""
        assert Ok(42) != 42

    def test_hash_consistent_with_equality(self):
        """Equal Oks hash to the same value."""
        assert hash(Ok("a")) == hash(Ok("a"))

    def test_ok_usable_in_sets(self):
        """Duplicate Oks collapse when stored in a set."""
        assert len({Ok(1), Ok(1), Ok(2)}) == 2
class TestResultStaticMethods:
    """Tests for Result.unit, result_is_ok, and result_unwrap."""

    def test_unit_wraps_in_ok(self):
        """Result.unit lifts a value into Ok."""
        assert Result.unit(99) == Ok(99)

    def test_result_is_ok_for_ok(self):
        """result_is_ok is True for an Ok."""
        assert Result.result_is_ok(Ok(1)) is True

    def test_result_is_ok_for_err(self):
        """result_is_ok is False for an Err."""
        assert Result.result_is_ok(Err("fail")) is False

    def test_result_unwrap_ok(self):
        """result_unwrap extracts the inner value of an Ok."""
        assert Result.result_unwrap(Ok(42)) == 42

    def test_result_unwrap_err_raises(self):
        """result_unwrap re-raises the Exception held by an Err."""
        with pytest.raises(RuntimeError, match="boom"):
            Result.result_unwrap(Err(RuntimeError("boom")))
class TestNestedResult:
    """Tests for Results nested inside Results."""

    def test_ok_wrapping_ok(self):
        """An Ok may contain another Ok; unwrap peels one layer at a time."""
        outer = Ok(Ok(42))
        assert outer.unwrap() == Ok(42)
        assert outer.unwrap().unwrap() == 42

    def test_ok_wrapping_err(self):
        """An Ok may contain an Err."""
        inner = Ok(Err("inner")).unwrap()
        assert inner.is_ok() is False
        assert inner.error == "inner"

    def test_map_on_nested_ok(self):
        """map can reach into a nested Ok via a nested map call."""
        transformed = Ok(Ok(5)).map(lambda inner: inner.map(lambda v: v * 2))
        assert transformed == Ok(Ok(10))
class TestErrTracebackExtraction:
    """Tests for the traceback detail extraction performed by Err."""

    def test_details_populated_from_exception_traceback(self):
        """A raised Exception yields at least one frame entry in details."""
        try:
            msg = "test error"
            raise ValueError(msg)  # noqa: TRY301
        except ValueError as exc:
            failure = Err(exc)
        frames = failure.details
        assert len(frames) > 0
        # Each frame record carries the expected bookkeeping keys.
        for key in ("file", "func", "line_no", "locals"):
            assert key in frames[0]

    def test_details_empty_for_non_exception(self):
        """Non-Exception errors produce no frame details."""
        assert Err("just a string").details == []

    def test_details_empty_for_exception_without_traceback(self):
        """An Exception that was never raised carries no traceback."""
        assert Err(ValueError("no traceback")).details == []
class TestErrImmutability:
    """Tests for the frozen behavior of Ok and Err."""

    def test_ok_is_frozen(self):
        """Assigning to an Ok attribute raises FrozenInstanceError."""
        frozen_ok = Ok(42)
        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            frozen_ok.inner = 99  # type: ignore[misc]

    def test_err_is_frozen(self):
        """Assigning to an Err attribute raises FrozenInstanceError."""
        frozen_err = Err("fail")
        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            frozen_err.error = "other"  # type: ignore[misc]
class TestErrValidators:
    """Tests for the attrs field validators on Err."""

    def test_input_args_must_be_tuple(self):
        """A non-tuple input_args is rejected with TypeError."""
        with pytest.raises(TypeError):
            Err("fail", input_args=["not", "a", "tuple"])  # type: ignore[arg-type]

    def test_traceback_must_be_string(self):
        """A non-string traceback is rejected with TypeError."""
        with pytest.raises(TypeError):
            Err("fail", traceback=123)  # type: ignore[arg-type]

View file

@ -0,0 +1,761 @@
from __future__ import annotations
import asyncio
import attrs
import pytest
from codeflash_core.danom.stream import (
_FILTER,
_MAP,
_TAP,
Stream,
_apply_fns,
_Nothing,
_par_apply_fns,
)
class TestStreamFromIterable:
    """Tests for the Stream.from_iterable constructor."""

    def test_from_list(self):
        """A list becomes a tuple-backed seq."""
        assert Stream.from_iterable([1, 2, 3]).seq == (1, 2, 3)

    def test_from_tuple(self):
        """Tuples are accepted directly."""
        assert Stream.from_iterable((4, 5)).seq == (4, 5)

    def test_from_generator(self):
        """Generators are consumed and stored as a tuple."""
        assert Stream.from_iterable(x * 2 for x in range(3)).seq == (0, 2, 4)

    def test_from_empty(self):
        """An empty iterable yields an empty seq."""
        assert Stream.from_iterable([]).seq == ()

    def test_from_string(self):
        """Strings are split into individual characters."""
        assert Stream.from_iterable("abc").seq == ("a", "b", "c")

    def test_from_range(self):
        """Range objects are accepted directly."""
        assert Stream.from_iterable(range(4)).seq == (0, 1, 2, 3)

    def test_ops_default_empty(self):
        """A newly built stream has an empty ops queue."""
        assert Stream.from_iterable([1]).ops == ()
class TestStreamBool:
    """Tests for Stream.__bool__."""

    def test_truthy_when_nonempty(self):
        """Non-empty streams are truthy."""
        assert bool(Stream.from_iterable([1])) is True

    def test_falsy_when_empty(self):
        """Empty streams are falsy."""
        assert bool(Stream.from_iterable([])) is False
class TestStreamMap:
    """Tests for Stream.map."""

    def test_single_map(self):
        """One map function transforms every element."""
        collected = Stream.from_iterable([1, 2, 3]).map(lambda v: v * 10).collect()
        assert collected == (10, 20, 30)

    def test_multiple_fns_in_one_call(self):
        """Several functions in one map call compose left to right."""
        collected = (
            Stream.from_iterable([1, 2])
            .map(lambda v: v + 1, lambda v: v * 3)
            .collect()
        )
        assert collected == (6, 9)

    def test_map_empty_stream(self):
        """Mapping an empty stream collects to an empty tuple."""
        assert Stream.from_iterable([]).map(lambda v: v + 1).collect() == ()

    def test_map_preserves_immutability(self):
        """map builds a new stream and never mutates the source."""
        source = Stream.from_iterable([1, 2])
        derived = source.map(lambda v: v * 2)
        assert source.ops == ()
        assert len(derived.ops) == 1
        assert derived is not source
class TestStreamFilter:
    """Tests for Stream.filter."""

    def test_single_filter(self):
        """One predicate drops non-matching elements."""
        collected = (
            Stream.from_iterable([1, 2, 3, 4])
            .filter(lambda v: v % 2 == 0)
            .collect()
        )
        assert collected == (2, 4)

    def test_filter_removes_all(self):
        """A never-true predicate empties the stream."""
        assert Stream.from_iterable([1, 3, 5]).filter(lambda v: v > 100).collect() == ()

    def test_filter_keeps_all(self):
        """An always-true predicate keeps every element."""
        collected = (
            Stream.from_iterable([2, 4, 6]).filter(lambda v: v % 2 == 0).collect()
        )
        assert collected == (2, 4, 6)

    def test_multiple_filters(self):
        """Several predicates in one call apply left to right."""
        collected = (
            Stream.from_iterable(range(20))
            .filter(lambda v: v % 2 == 0, lambda v: v > 5)
            .collect()
        )
        assert collected == (6, 8, 10, 12, 14, 16, 18)

    def test_filter_empty_stream(self):
        """Filtering an empty stream collects to an empty tuple."""
        assert Stream.from_iterable([]).filter(lambda v: True).collect() == ()
class TestStreamTap:
    """Tests for Stream.tap."""

    def test_tap_side_effect(self):
        """tap runs the side effect while passing elements through unchanged."""
        observed = []
        collected = Stream.from_iterable([1, 2, 3]).tap(observed.append).collect()
        assert collected == (1, 2, 3)
        assert observed == [1, 2, 3]

    def test_tap_receives_deepcopy(self):
        """tap gets deep copies, so mutation never leaks into the stream."""
        source = [{"a": 1}, {"a": 2}]
        received = []

        def clobber(item):
            item["a"] = 999
            received.append(item)

        collected = Stream.from_iterable(source).tap(clobber).collect()
        assert collected == ({"a": 1}, {"a": 2})
        assert received == [{"a": 999}, {"a": 999}]
class TestStreamChaining:
    """Tests chaining several operation types together."""

    def test_map_then_filter(self):
        """map then filter run in declaration order."""
        collected = (
            Stream.from_iterable([1, 2, 3, 4, 5])
            .map(lambda v: v * 2)
            .filter(lambda v: v > 4)
            .collect()
        )
        assert collected == (6, 8, 10)

    def test_filter_then_map(self):
        """filter then map run in declaration order."""
        collected = (
            Stream.from_iterable([1, 2, 3, 4, 5])
            .filter(lambda v: v > 2)
            .map(lambda v: v**2)
            .collect()
        )
        assert collected == (9, 16, 25)

    def test_map_filter_tap_chain(self):
        """map, filter, and tap compose in a single pipeline."""
        observed = []
        collected = (
            Stream.from_iterable([1, 2, 3, 4])
            .map(lambda v: v + 10)
            .filter(lambda v: v % 2 == 0)
            .tap(observed.append)
            .collect()
        )
        assert collected == (12, 14)
        assert observed == [12, 14]

    def test_multiple_map_calls(self):
        """Back-to-back map calls compose left to right."""
        collected = (
            Stream.from_iterable([1, 2])
            .map(lambda v: v + 1)
            .map(lambda v: v * 10)
            .collect()
        )
        assert collected == (20, 30)
class TestStreamFold:
    """Tests for Stream.fold."""

    def test_sum(self):
        """Addition folds to the element sum."""
        assert Stream.from_iterable([1, 2, 3, 4]).fold(0, lambda acc, v: acc + v) == 10

    def test_fold_with_operations(self):
        """Queued ops run before the reduction."""
        total = (
            Stream.from_iterable([1, 2, 3, 4])
            .filter(lambda v: v % 2 == 0)
            .map(lambda v: v * 10)
            .fold(0, lambda acc, v: acc + v)
        )
        assert total == 60

    def test_fold_empty_stream(self):
        """An empty stream folds to the initial accumulator."""
        assert Stream.from_iterable([]).fold(42, lambda acc, v: acc + v) == 42

    def test_fold_string_concat(self):
        """Strings can be folded by concatenation."""
        joined = Stream.from_iterable(["a", "b", "c"]).fold("", lambda acc, v: acc + v)
        assert joined == "abc"
class TestStreamPartition:
    """Tests for Stream.partition."""

    def test_basic_partition(self):
        """Elements split into matching and non-matching streams."""
        matching, rest = Stream.from_iterable([1, 2, 3, 4, 5]).partition(
            lambda v: v > 3
        )
        assert matching.collect() == (4, 5)
        assert rest.collect() == (1, 2, 3)

    def test_partition_all_true(self):
        """An always-true predicate leaves the false side empty."""
        matching, rest = Stream.from_iterable([2, 4, 6]).partition(
            lambda v: v % 2 == 0
        )
        assert matching.collect() == (2, 4, 6)
        assert rest.collect() == ()

    def test_partition_all_false(self):
        """A never-true predicate leaves the true side empty."""
        matching, rest = Stream.from_iterable([1, 3, 5]).partition(
            lambda v: v % 2 == 0
        )
        assert matching.collect() == ()
        assert rest.collect() == (1, 3, 5)

    def test_partition_empty(self):
        """An empty stream partitions into two empty streams."""
        matching, rest = Stream.from_iterable([]).partition(lambda v: v > 0)
        assert matching.collect() == ()
        assert rest.collect() == ()

    def test_partition_applies_pending_ops(self):
        """Pending ops materialize before the split happens."""
        matching, rest = (
            Stream.from_iterable([1, 2, 3, 4])
            .map(lambda v: v * 10)
            .partition(lambda v: v >= 30)
        )
        assert matching.collect() == (30, 40)
        assert rest.collect() == (10, 20)
class TestStreamCollect:
    """Tests for Stream.collect."""

    def test_collect_no_ops(self):
        """With no queued ops, collect returns the seed elements."""
        assert Stream.from_iterable([1, 2, 3]).collect() == (1, 2, 3)

    def test_collect_returns_tuple(self):
        """collect always produces a tuple."""
        assert isinstance(Stream.from_iterable([1]).collect(), tuple)

    def test_collect_single_element(self):
        """A one-element stream collects to a one-tuple."""
        assert Stream.from_iterable([42]).map(lambda v: v + 1).collect() == (43,)
class TestStreamParCollect:
    """Tests for Stream.par_collect."""

    def test_par_collect_threads(self):
        """Thread-based collection matches the sequential result."""
        collected = (
            Stream.from_iterable(range(20))
            .map(lambda v: v * 2)
            .par_collect(workers=2, use_threads=True)
        )
        assert collected == tuple(v * 2 for v in range(20))

    def test_par_collect_empty(self):
        """An empty stream parallel-collects to an empty tuple."""
        assert Stream.from_iterable([]).par_collect(workers=2, use_threads=True) == ()

    def test_par_collect_with_filter(self):
        """Filter ops survive parallel collection."""
        collected = (
            Stream.from_iterable(range(10))
            .filter(lambda v: v % 2 == 0)
            .par_collect(workers=2, use_threads=True)
        )
        assert collected == (0, 2, 4, 6, 8)
class TestStreamAsyncCollect:
    """Tests for Stream.async_collect."""

    @pytest.mark.asyncio
    async def test_async_map(self):
        """Coroutine map functions are awaited per element."""

        async def times_ten(v):
            return v * 10

        collected = await (
            Stream.from_iterable([1, 2, 3]).map(times_ten).async_collect()
        )
        assert collected == (10, 20, 30)

    @pytest.mark.asyncio
    async def test_async_collect_no_ops(self):
        """With no ops queued, async_collect behaves like collect."""
        assert await Stream.from_iterable([1, 2]).async_collect() == (1, 2)

    @pytest.mark.asyncio
    async def test_async_filter(self):
        """Coroutine predicates drop non-matching elements."""

        async def gt_two(v):
            return v > 2

        collected = await (
            Stream.from_iterable([1, 2, 3, 4]).filter(gt_two).async_collect()
        )
        assert collected == (3, 4)

    @pytest.mark.asyncio
    async def test_async_tap(self):
        """Coroutine taps run side effects without changing elements."""
        observed = []

        async def record(v):
            observed.append(v)

        collected = await Stream.from_iterable([10, 20]).tap(record).async_collect()
        assert collected == (10, 20)
        assert observed == [10, 20]
class TestApplyFns:
    """Tests for the lazy _apply_fns helper."""

    def test_map_op(self):
        """MAP entries transform each element."""
        pipeline = ((_MAP, lambda v: v + 1),)
        assert tuple(_apply_fns((1, 2, 3), pipeline)) == (2, 3, 4)

    def test_filter_op(self):
        """FILTER entries drop non-matching elements."""
        pipeline = ((_FILTER, lambda v: v > 2),)
        assert tuple(_apply_fns((1, 2, 3, 4), pipeline)) == (3, 4)

    def test_tap_op(self):
        """TAP entries run side effects and pass elements through."""
        observed = []
        pipeline = ((_TAP, observed.append),)
        assert tuple(_apply_fns((1, 2), pipeline)) == (1, 2)
        assert len(observed) == 2

    def test_combined_ops(self):
        """Mixed op kinds execute in declaration order."""
        observed = []
        pipeline = (
            (_MAP, lambda v: v * 2),
            (_FILTER, lambda v: v > 4),
            (_TAP, observed.append),
        )
        assert tuple(_apply_fns((1, 2, 3, 4), pipeline)) == (6, 8)
        assert observed == [6, 8]

    def test_empty_ops(self):
        """With no ops, elements pass through untouched."""
        assert tuple(_apply_fns((1, 2), ())) == (1, 2)

    def test_empty_elements(self):
        """With no elements, nothing is produced regardless of ops."""
        assert tuple(_apply_fns((), ((_MAP, lambda v: v + 1),))) == ()

    def test_filter_short_circuits(self):
        """Ops after a rejecting filter never see the rejected element."""
        seen_by_map = []

        def spy(v):
            seen_by_map.append(v)
            return v

        pipeline = ((_FILTER, lambda v: v > 5), (_MAP, spy))
        tuple(_apply_fns((1, 10), pipeline))
        assert seen_by_map == [10]
class TestParApplyFns:
    """Tests for the eager _par_apply_fns helper."""

    def test_returns_tuple(self):
        """The result is a materialized tuple rather than a generator."""
        out = _par_apply_fns((1, 2, 3), ((_MAP, lambda v: v + 1),))
        assert isinstance(out, tuple)
        assert out == (2, 3, 4)

    def test_empty_elements(self):
        """No input elements produce an empty tuple."""
        assert _par_apply_fns((), ((_MAP, lambda v: v),)) == ()

    def test_filter_and_map(self):
        """Filter plus map compose correctly in the eager variant."""
        pipeline = ((_FILTER, lambda v: v % 2 == 0), (_MAP, lambda v: v * 10))
        assert _par_apply_fns((1, 2, 3, 4), pipeline) == (20, 40)
class TestNothing:
    """Tests for the _Nothing sentinel enum."""

    def test_nothing_is_singleton(self):
        """
        NOTHING is the sole member of the _Nothing enum.
        """
        assert 1 == len(_Nothing)
        # Compare against the member obtained by iterating the enum: the
        # original `NOTHING is NOTHING` check was a tautology (an object is
        # always itself) and could never fail, so it verified nothing.
        assert next(iter(_Nothing)) is _Nothing.NOTHING

    def test_nothing_is_not_equal_to_none(self):
        """
        NOTHING is distinct from None.
        """
        assert _Nothing.NOTHING != None  # noqa: E711
class TestStreamImmutability:
    """Tests for Stream's frozen attrs behavior."""

    def test_cannot_set_seq(self):
        """Reassigning seq raises FrozenInstanceError."""
        frozen = Stream.from_iterable([1, 2])
        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            frozen.seq = (3, 4)

    def test_cannot_set_ops(self):
        """Reassigning ops raises FrozenInstanceError."""
        frozen = Stream.from_iterable([1, 2])
        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            frozen.ops = ()

    def test_seq_must_be_tuple(self):
        """A non-tuple seq is rejected at construction with TypeError."""
        with pytest.raises(TypeError):
            Stream(seq=[1, 2, 3])
class TestEdgeCases:
    """Edge cases and boundary conditions for Stream."""

    def test_single_element_map(self):
        """A one-element stream maps correctly."""
        assert Stream.from_iterable([5]).map(lambda v: v * 3).collect() == (15,)

    def test_single_element_filter_keeps(self):
        """A matching lone element survives the filter."""
        assert Stream.from_iterable([5]).filter(lambda v: v > 0).collect() == (5,)

    def test_single_element_filter_removes(self):
        """A non-matching lone element is dropped."""
        assert Stream.from_iterable([5]).filter(lambda v: v > 10).collect() == ()

    def test_none_values_in_stream(self):
        """None is a perfectly valid stream element."""
        assert Stream.from_iterable([None, 1, None]).collect() == (None, 1, None)

    def test_nested_streams(self):
        """Streams can hold other streams as elements."""
        inner = Stream.from_iterable([1, 2])
        collected = Stream.from_iterable([inner, inner]).collect()
        assert len(collected) == 2
        assert collected[0].collect() == (1, 2)

    def test_large_stream(self):
        """A 10k-element pipeline runs without error."""
        collected = (
            Stream.from_iterable(range(10000))
            .map(lambda v: v + 1)
            .filter(lambda v: v % 1000 == 0)
            .collect()
        )
        assert collected == tuple(range(1000, 10001, 1000))

    def test_fold_with_workers_threads(self):
        """Fold with threaded workers matches the sequential sum."""
        total = Stream.from_iterable(range(100)).fold(
            0, lambda acc, v: acc + v, workers=2, use_threads=True
        )
        assert total == 4950

View file

@ -0,0 +1,665 @@
from __future__ import annotations
import attrs
import pytest
from codeflash_core._model import (
_NS_PER_DAY,
_NS_PER_HOUR,
_NS_PER_MIN,
_NS_PER_MS,
_NS_PER_S,
_NS_PER_US,
BenchmarkDetail,
Candidate,
FileDiffContent,
OptimizationRequest,
OptimizationReviewResult,
PrComment,
humanize_runtime,
)
class TestTimeConstants:
    """Tests for the nanosecond conversion constants."""

    def test_ns_per_us(self):
        """A microsecond is 1,000 nanoseconds."""
        assert _NS_PER_US == 1_000

    def test_ns_per_ms(self):
        """A millisecond is 1,000,000 nanoseconds."""
        assert _NS_PER_MS == 1_000_000

    def test_ns_per_s(self):
        """A second is 1,000,000,000 nanoseconds."""
        assert _NS_PER_S == 1_000_000_000

    def test_ns_per_min(self):
        """A minute is 60,000,000,000 nanoseconds."""
        assert _NS_PER_MIN == 60_000_000_000

    def test_ns_per_hour(self):
        """An hour is 3,600,000,000,000 nanoseconds."""
        assert _NS_PER_HOUR == 3_600_000_000_000

    def test_ns_per_day(self):
        """A day is 86,400,000,000,000 nanoseconds."""
        assert _NS_PER_DAY == 86_400_000_000_000

    def test_constants_are_consistent(self):
        """Each unit is derivable from the next smaller one."""
        assert _NS_PER_US * 1_000 == _NS_PER_MS
        assert _NS_PER_MS * 1_000 == _NS_PER_S
        assert _NS_PER_S * 60 == _NS_PER_MIN
        assert _NS_PER_MIN * 60 == _NS_PER_HOUR
        assert _NS_PER_HOUR * 24 == _NS_PER_DAY
class TestOptimizationRequest:
    """Tests for the OptimizationRequest model."""

    def test_required_fields(self):
        """Only the three required fields are needed to construct one."""
        request = OptimizationRequest(
            source_code="def f(): pass",
            language="python",
            language_version="3.12",
        )
        assert request.source_code == "def f(): pass"
        assert request.language == "python"
        assert request.language_version == "3.12"

    def test_default_values(self):
        """Every optional field carries its documented default."""
        request = OptimizationRequest(
            source_code="x",
            language="python",
            language_version="3.12",
        )
        assert request.context_code == ""
        assert request.is_async is False
        assert request.is_numerical_code is None
        assert request.codeflash_version == ""
        assert request.baseline_runtime_ns is None
        assert request.loop_count is None
        assert request.line_profiler_results is None
        assert request.test_input_examples is None

    def test_all_fields(self):
        """Each field accepts an explicit value."""
        request = OptimizationRequest(
            source_code="def f(): pass",
            language="javascript",
            language_version="ES2022",
            context_code="import os",
            is_async=True,
            is_numerical_code=True,
            codeflash_version="1.0.0",
            baseline_runtime_ns=5000,
            loop_count=100,
            line_profiler_results="Line # Hits",
            test_input_examples="example()",
        )
        assert request.language == "javascript"
        assert request.language_version == "ES2022"
        assert request.context_code == "import os"
        assert request.is_async is True
        assert request.is_numerical_code is True
        assert request.codeflash_version == "1.0.0"
        assert request.baseline_runtime_ns == 5000
        assert request.loop_count == 100
        assert request.line_profiler_results == "Line # Hits"
        assert request.test_input_examples == "example()"

    def test_frozen(self):
        """Field assignment raises FrozenInstanceError."""
        request = OptimizationRequest(
            source_code="x",
            language="python",
            language_version="3.12",
        )
        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            request.source_code = "y"

    def test_empty_source_code(self):
        """An empty source_code string is accepted."""
        request = OptimizationRequest(
            source_code="",
            language="python",
            language_version="3.12",
        )
        assert request.source_code == ""
class TestCandidate:
    """Tests for the Candidate model."""

    def test_required_fields(self):
        """Only code and explanation are required."""
        candidate = Candidate(code="def f(): pass", explanation="simplified")
        assert candidate.code == "def f(): pass"
        assert candidate.explanation == "simplified"

    def test_default_values(self):
        """Every optional field defaults to an empty string."""
        candidate = Candidate(code="x", explanation="y")
        assert candidate.candidate_id == ""
        assert candidate.source == ""
        assert candidate.parent_id == ""
        assert candidate.code_markdown == ""

    def test_all_fields(self):
        """Each field accepts an explicit value."""
        candidate = Candidate(
            code="def f(): pass",
            explanation="optimized",
            candidate_id="abc-123",
            source="ai-service",
            parent_id="parent-0",
            code_markdown="```python\ndef f(): pass\n```",
        )
        assert candidate.candidate_id == "abc-123"
        assert candidate.source == "ai-service"
        assert candidate.parent_id == "parent-0"
        assert candidate.code_markdown == "```python\ndef f(): pass\n```"

    def test_frozen(self):
        """Field assignment raises FrozenInstanceError."""
        candidate = Candidate(code="x", explanation="y")
        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            candidate.code = "z"

    def test_empty_strings(self):
        """Empty strings satisfy the required fields."""
        candidate = Candidate(code="", explanation="")
        assert candidate.code == ""
        assert candidate.explanation == ""
class TestOptimizationReviewResult:
    """Tests for the OptimizationReviewResult model."""

    def test_construction(self):
        """Both fields round-trip through the constructor."""
        outcome = OptimizationReviewResult(
            review="high",
            explanation="Well-tested optimization.",
        )
        assert outcome.review == "high"
        assert outcome.explanation == "Well-tested optimization."

    def test_frozen(self):
        """Field assignment raises FrozenInstanceError."""
        outcome = OptimizationReviewResult(review="low", explanation="risky")
        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            outcome.review = "high"

    def test_empty_strings(self):
        """Empty strings are accepted for both fields."""
        outcome = OptimizationReviewResult(review="", explanation="")
        assert outcome.review == ""
        assert outcome.explanation == ""
class TestFileDiffContent:
    """Tests for the FileDiffContent model."""

    def test_construction(self):
        """Old and new content round-trip through the constructor."""
        diff = FileDiffContent(old_content="before", new_content="after")
        assert diff.old_content == "before"
        assert diff.new_content == "after"

    def test_frozen(self):
        """Field assignment raises FrozenInstanceError."""
        diff = FileDiffContent(old_content="a", new_content="b")
        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            diff.old_content = "c"

    def test_empty_strings(self):
        """Empty strings are accepted for both fields."""
        diff = FileDiffContent(old_content="", new_content="")
        assert diff.old_content == ""
        assert diff.new_content == ""
class TestBenchmarkDetail:
    """Tests for the BenchmarkDetail model."""

    @pytest.fixture(name="detail")
    def _detail(self):
        """Provide a representative BenchmarkDetail."""
        return BenchmarkDetail(
            benchmark_name="test_suite",
            test_function="test_compute",
            original_timing="100ms",
            expected_new_timing="50ms",
            speedup_percent=50.0,
        )

    def test_construction(self, detail):
        """Every field round-trips through the constructor."""
        assert detail.benchmark_name == "test_suite"
        assert detail.test_function == "test_compute"
        assert detail.original_timing == "100ms"
        assert detail.expected_new_timing == "50ms"
        assert detail.speedup_percent == 50.0

    def test_frozen(self, detail):
        """Field assignment raises FrozenInstanceError."""
        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            detail.benchmark_name = "other"

    def test_to_string(self, detail):
        """to_string emits the three-line summary in the exact format."""
        assert detail.to_string() == (
            "Original timing for test_suite::test_compute: 100ms\n"
            "Expected new timing for test_suite::test_compute: 50ms\n"
            "Benchmark speedup for test_suite::test_compute: 50.00%\n"
        )

    def test_to_string_fractional_speedup(self):
        """Fractional speedups are rendered to two decimal places."""
        fractional = BenchmarkDetail(
            benchmark_name="bench",
            test_function="test_fn",
            original_timing="200ns",
            expected_new_timing="133ns",
            speedup_percent=33.333,
        )
        assert "33.33%" in fractional.to_string()

    def test_to_dict(self, detail):
        """to_dict exposes every field as a plain mapping."""
        assert detail.to_dict() == {
            "benchmark_name": "test_suite",
            "test_function": "test_compute",
            "original_timing": "100ms",
            "expected_new_timing": "50ms",
            "speedup_percent": 50.0,
        }

    def test_to_dict_roundtrip(self, detail):
        """The to_dict output reconstructs an equal BenchmarkDetail."""
        assert BenchmarkDetail(**detail.to_dict()) == detail

    def test_empty_strings(self):
        """Empty name and timing strings are accepted."""
        blank = BenchmarkDetail(
            benchmark_name="",
            test_function="",
            original_timing="",
            expected_new_timing="",
            speedup_percent=0.0,
        )
        assert blank.benchmark_name == ""
        assert blank.test_function == ""
class TestPrComment:
    """Unit tests covering PrComment."""

    @pytest.fixture(name="pr_comment")
    def _pr_comment(self):
        """
        Provide a fully-populated PrComment instance.
        """
        return PrComment(
            optimization_explanation="Replaced loop with vectorized op.",
            best_runtime=50_000_000,
            original_runtime=100_000_000,
            function_name="compute",
            relative_file_path="src/compute.py",
            speedup_x="2.0x",
            speedup_pct="50%",
            loop_count=10,
            report_table={"test_a": {"before": 100, "after": 50}},
        )

    def test_construction(self, pr_comment):
        """
        Every required constructor argument is stored on the instance.
        """
        assert pr_comment.optimization_explanation == (
            "Replaced loop with vectorized op."
        )
        assert pr_comment.best_runtime == 50_000_000
        assert pr_comment.original_runtime == 100_000_000
        assert pr_comment.function_name == "compute"
        assert pr_comment.relative_file_path == "src/compute.py"
        assert pr_comment.speedup_x == "2.0x"
        assert pr_comment.speedup_pct == "50%"
        assert pr_comment.loop_count == 10
        assert pr_comment.report_table == (
            {"test_a": {"before": 100, "after": 50}}
        )

    def test_optional_defaults(self, pr_comment):
        """
        All optional fields default to None.
        """
        assert pr_comment.benchmark_details is None
        assert pr_comment.original_async_throughput is None
        assert pr_comment.best_async_throughput is None

    def test_frozen(self, pr_comment):
        """
        Attribute assignment on the frozen class raises FrozenInstanceError.
        """
        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            pr_comment.function_name = "other"

    def test_to_json_keys(self, pr_comment):
        """
        to_json produces exactly the expected set of keys.
        """
        payload = pr_comment.to_json()
        assert set(payload.keys()) == {
            "optimization_explanation",
            "best_runtime",
            "original_runtime",
            "function_name",
            "file_path",
            "speedup_x",
            "speedup_pct",
            "loop_count",
            "report_table",
            "benchmark_details",
        }

    def test_to_json_humanizes_runtimes(self, pr_comment):
        """
        Runtime fields pass through humanize_runtime.
        """
        payload = pr_comment.to_json()
        assert payload["best_runtime"] == humanize_runtime(50_000_000)
        assert payload["original_runtime"] == humanize_runtime(100_000_000)

    def test_to_json_file_path_key(self, pr_comment):
        """
        relative_file_path is exposed under the 'file_path' key.
        """
        assert pr_comment.to_json()["file_path"] == "src/compute.py"

    def test_to_json_no_benchmark_details(self, pr_comment):
        """
        benchmark_details serializes as None when never set.
        """
        assert pr_comment.to_json()["benchmark_details"] is None

    def test_to_json_with_benchmark_details(self):
        """
        Provided benchmark_details appear in the JSON payload.
        """
        detail = BenchmarkDetail(
            benchmark_name="suite",
            test_function="test_fn",
            original_timing="10ms",
            expected_new_timing="5ms",
            speedup_percent=50.0,
        )
        comment = PrComment(
            optimization_explanation="optimized",
            best_runtime=5_000_000,
            original_runtime=10_000_000,
            function_name="fn",
            relative_file_path="f.py",
            speedup_x="2x",
            speedup_pct="50%",
            loop_count=1,
            report_table={},
            benchmark_details=(detail,),
        )
        assert comment.to_json()["benchmark_details"] == (detail,)

    def test_to_json_empty_benchmark_details_tuple(self):
        """
        An empty benchmark_details tuple serializes as None.
        """
        comment = PrComment(
            optimization_explanation="opt",
            best_runtime=1000,
            original_runtime=2000,
            function_name="fn",
            relative_file_path="f.py",
            speedup_x="2x",
            speedup_pct="50%",
            loop_count=1,
            report_table={},
            benchmark_details=(),
        )
        assert comment.to_json()["benchmark_details"] is None

    def test_to_json_excludes_async_when_none(self, pr_comment):
        """
        Async throughput keys are absent when both values are None.
        """
        payload = pr_comment.to_json()
        assert "original_async_throughput" not in payload
        assert "best_async_throughput" not in payload

    def test_to_json_includes_async_when_set(self):
        """
        Async throughput keys appear when both values are provided.
        """
        comment = PrComment(
            optimization_explanation="opt",
            best_runtime=1000,
            original_runtime=2000,
            function_name="fn",
            relative_file_path="f.py",
            speedup_x="2x",
            speedup_pct="50%",
            loop_count=1,
            report_table={},
            original_async_throughput=500,
            best_async_throughput=1000,
        )
        payload = comment.to_json()
        assert payload["original_async_throughput"] == 500
        assert payload["best_async_throughput"] == 1000

    def test_to_json_excludes_async_when_only_original_set(self):
        """
        Async keys stay absent when only original_async_throughput is given.
        """
        comment = PrComment(
            optimization_explanation="opt",
            best_runtime=1000,
            original_runtime=2000,
            function_name="fn",
            relative_file_path="f.py",
            speedup_x="2x",
            speedup_pct="50%",
            loop_count=1,
            report_table={},
            original_async_throughput=500,
        )
        payload = comment.to_json()
        assert "original_async_throughput" not in payload
        assert "best_async_throughput" not in payload
class TestHumanizeRuntime:
    """Unit tests covering humanize_runtime."""

    def test_single_nanosecond(self):
        """
        1 ns renders with the singular unit name.
        """
        assert humanize_runtime(1) == "1.00 nanosecond"

    def test_small_nanoseconds(self):
        """
        Sub-microsecond values render in nanoseconds.
        """
        assert humanize_runtime(500) == "500 nanoseconds"

    def test_microsecond_range(self):
        """
        Microsecond-scale values use the microseconds unit.
        """
        assert "microseconds" in humanize_runtime(5_000)

    def test_millisecond_range(self):
        """
        Millisecond-scale values use the milliseconds unit.
        """
        assert "milliseconds" in humanize_runtime(5_000_000)

    def test_second_range(self):
        """
        Second-scale values use the seconds unit.
        """
        assert "seconds" in humanize_runtime(2_000_000_000)

    def test_minute_range(self):
        """
        Minute-scale values use the minutes unit.
        """
        assert "minutes" in humanize_runtime(120_000_000_000)

    def test_hour_range(self):
        """
        Hour-scale values use the hours unit.
        """
        assert "hours" in humanize_runtime(2 * _NS_PER_HOUR)

    def test_day_range(self):
        """
        Day-scale values use the days unit.
        """
        assert "days" in humanize_runtime(2 * _NS_PER_DAY)

    def test_exact_one_microsecond(self):
        """
        Exactly one microsecond uses the singular unit.
        """
        rendered = humanize_runtime(1_000)
        assert "microsecond" in rendered
        assert "microseconds" not in rendered

    def test_exact_one_second(self):
        """
        Exactly one second uses the singular unit.
        """
        rendered = humanize_runtime(_NS_PER_S)
        assert "second" in rendered
        assert "seconds" not in rendered

    def test_exact_one_minute(self):
        """
        Exactly one minute uses the singular unit.
        """
        rendered = humanize_runtime(_NS_PER_MIN)
        assert "minute" in rendered
        assert "minutes" not in rendered

    def test_exact_one_hour(self):
        """
        Exactly one hour uses the singular unit.
        """
        rendered = humanize_runtime(_NS_PER_HOUR)
        assert "hour" in rendered
        assert "hours" not in rendered

    def test_exact_one_day(self):
        """
        Exactly one day uses the singular unit.
        """
        rendered = humanize_runtime(_NS_PER_DAY)
        assert "day" in rendered
        assert "days" not in rendered

    def test_returns_string(self):
        """
        humanize_runtime always yields a str.
        """
        assert isinstance(humanize_runtime(42), str)

    def test_contains_space_between_value_and_unit(self):
        """
        Output is exactly '<value> <unit>' with one separating space.
        """
        tokens = humanize_runtime(12345).split(" ")
        assert len(tokens) == 2

# ===== end of test_model.py — next file: test_pipeline.py (diff hunk @ -0,0 +1,701 @@) =====
from __future__ import annotations
import attrs
import pytest
from codeflash_core._model import Candidate
from codeflash_core._pipeline import (
CandidateForest,
CandidateNode,
EvaluationContext,
create_rank_dictionary,
dedup_candidates,
diff_length,
filter_refined_candidates,
performance_gain,
select_best,
)
@pytest.fixture(name="candidate")
def _candidate():
    """
    Build the minimal Candidate shared across these tests.
    """
    minimal = Candidate(
        code="def f(): pass",
        explanation="no-op",
        candidate_id="c1",
    )
    return minimal
@pytest.fixture(name="eval_ctx")
def _eval_ctx():
    """
    Build a fresh EvaluationContext for each test.
    """
    return EvaluationContext()
class TestPerformanceGain:
    """Unit tests covering performance_gain."""

    def test_two_x_speedup(self):
        """
        Halving the runtime corresponds to a gain of 1.0.
        """
        gain = performance_gain(
            original_runtime_ns=200, optimized_runtime_ns=100
        )
        assert gain == 1.0

    def test_no_speedup(self):
        """
        Identical runtimes correspond to a gain of 0.0.
        """
        gain = performance_gain(
            original_runtime_ns=100, optimized_runtime_ns=100
        )
        assert gain == 0.0

    def test_slowdown(self):
        """
        Doubling the runtime corresponds to a gain of -0.5.
        """
        gain = performance_gain(
            original_runtime_ns=100, optimized_runtime_ns=200
        )
        assert gain == -0.5

    def test_optimized_zero_returns_zero(self):
        """
        A zero optimized runtime short-circuits to 0.0 (no division by zero).
        """
        gain = performance_gain(
            original_runtime_ns=100, optimized_runtime_ns=0
        )
        assert gain == 0.0

    def test_both_zero(self):
        """
        Two zero runtimes also yield 0.0.
        """
        gain = performance_gain(
            original_runtime_ns=0, optimized_runtime_ns=0
        )
        assert gain == 0.0

    def test_original_zero_optimized_nonzero(self):
        """
        A zero original runtime against a nonzero optimized one gives -1.0.
        """
        gain = performance_gain(
            original_runtime_ns=0, optimized_runtime_ns=100
        )
        assert gain == -1.0

    def test_large_speedup(self):
        """
        A 1000x runtime reduction produces a gain of 999.0.
        """
        gain = performance_gain(
            original_runtime_ns=1_000_000, optimized_runtime_ns=1_000
        )
        assert gain == 999.0
class TestDiffLength:
    """Unit tests covering diff_length."""

    def test_identical_strings(self):
        """
        No edits means a diff length of zero.
        """
        assert diff_length("hello\n", "hello\n") == 0

    def test_empty_strings(self):
        """
        Two empty inputs also diff to zero.
        """
        assert diff_length("", "") == 0

    def test_one_line_changed(self):
        """
        Changing one line yields a positive diff length.
        """
        assert diff_length("line1\nline2\n", "line1\nchanged\n") > 0

    def test_addition(self):
        """
        Appending a line yields a positive diff length.
        """
        assert diff_length("a\n", "a\nb\n") > 0

    def test_deletion(self):
        """
        Dropping a line yields a positive diff length.
        """
        assert diff_length("a\nb\n", "a\n") > 0

    def test_completely_different(self):
        """
        A full rewrite diffs larger than a one-character edit.
        """
        one_char_edit = diff_length("aaa\n", "aab\n")
        full_rewrite = diff_length("aaa\nbbb\nccc\n", "xxx\nyyy\nzzz\n")
        assert full_rewrite > one_char_edit
class TestCreateRankDictionary:
    """Unit tests covering create_rank_dictionary."""

    def test_ascending_order(self):
        """
        Sorted input maps each index to itself.
        """
        assert create_rank_dictionary([10, 20, 30]) == {0: 0, 1: 1, 2: 2}

    def test_descending_order(self):
        """
        Reverse-sorted input maps indices to mirrored ranks.
        """
        assert create_rank_dictionary([30, 20, 10]) == {0: 2, 1: 1, 2: 0}

    def test_single_element(self):
        """
        One value receives rank 0.
        """
        assert create_rank_dictionary([42]) == {0: 0}

    def test_empty_list(self):
        """
        No values produce an empty mapping.
        """
        assert create_rank_dictionary([]) == {}

    def test_duplicate_values(self):
        """
        Tied values still receive distinct index-based ranks.
        """
        ranks = create_rank_dictionary([5, 5, 5])
        assert set(ranks.values()) == {0, 1, 2}
class TestCandidateNode:
    """Unit tests covering CandidateNode."""

    def test_candidate_id_delegates(self, candidate):
        """
        candidate_id forwards to the wrapped Candidate's id.
        """
        assert CandidateNode(candidate=candidate).candidate_id == "c1"

    def test_is_leaf_true(self, candidate):
        """
        A childless node reports itself as a leaf.
        """
        assert CandidateNode(candidate=candidate).is_leaf() is True

    def test_is_leaf_false(self, candidate):
        """
        Attaching a child clears the leaf status.
        """
        parent = CandidateNode(candidate=candidate)
        child = CandidateNode(
            candidate=Candidate(
                code="def g(): pass",
                explanation="child",
                candidate_id="c2",
            ),
            parent=parent,
        )
        parent.children.append(child)
        assert parent.is_leaf() is False

    def test_path_to_root_single(self, candidate):
        """
        A root node's path contains only its own candidate.
        """
        path = CandidateNode(candidate=candidate).path_to_root()
        assert len(path) == 1
        assert path[0] is candidate

    def test_path_to_root_chain(self):
        """
        A three-node chain yields candidates in root-first order.
        """
        root_cand = Candidate(code="a", explanation="", candidate_id="c1")
        mid_cand = Candidate(code="b", explanation="", candidate_id="c2")
        leaf_cand = Candidate(code="c", explanation="", candidate_id="c3")
        root = CandidateNode(candidate=root_cand)
        mid = CandidateNode(candidate=mid_cand, parent=root)
        leaf = CandidateNode(candidate=leaf_cand, parent=mid)
        ids = [cand.candidate_id for cand in leaf.path_to_root()]
        assert ids == ["c1", "c2", "c3"]

    def test_default_children_empty(self, candidate):
        """
        Nodes are created with no children.
        """
        assert CandidateNode(candidate=candidate).children == []
class TestCandidateForest:
    """Unit tests covering CandidateForest."""

    def test_add_and_get(self, candidate):
        """
        A stored candidate can be fetched back by its id.
        """
        forest = CandidateForest()
        forest.add(candidate)
        node = forest.get("c1")
        assert node is not None
        assert node.candidate_id == "c1"

    def test_len(self):
        """
        len() tracks the number of stored nodes.
        """
        forest = CandidateForest()
        assert len(forest) == 0
        forest.add(Candidate(code="a", explanation="", candidate_id="c1"))
        assert len(forest) == 1

    def test_get_missing(self):
        """
        Unknown ids come back as None.
        """
        assert CandidateForest().get("nope") is None

    def test_parent_child_linking(self):
        """
        A parent_id connects the child node to its parent node.
        """
        forest = CandidateForest()
        forest.add(Candidate(code="a", explanation="", candidate_id="p1"))
        forest.add(
            Candidate(
                code="b",
                explanation="",
                candidate_id="c1",
                parent_id="p1",
            )
        )
        child_node = forest.get("c1")
        assert child_node is not None
        assert child_node.parent is not None
        assert child_node.parent.candidate_id == "p1"

    def test_child_added_before_parent(self):
        """
        Out-of-order insertion creates a placeholder filled in later.
        """
        forest = CandidateForest()
        forest.add(
            Candidate(
                code="b",
                explanation="",
                candidate_id="c1",
                parent_id="p1",
            )
        )
        assert len(forest) == 2
        forest.add(Candidate(code="a", explanation="", candidate_id="p1"))
        assert len(forest) == 2
        parent_node = forest.get("p1")
        assert parent_node is not None
        assert parent_node.candidate.code == "a"

    def test_multiple_roots(self):
        """
        Parentless candidates coexist as separate roots.
        """
        forest = CandidateForest()
        forest.add(Candidate(code="a", explanation="", candidate_id="r1"))
        forest.add(Candidate(code="b", explanation="", candidate_id="r2"))
        assert len(forest) == 2
        assert forest.get("r1").parent is None
        assert forest.get("r2").parent is None
class TestEvaluationContext:
    """Unit tests covering EvaluationContext."""

    def test_record_failed(self, eval_ctx):
        """
        record_failed marks the candidate incorrect with null metrics.
        """
        eval_ctx.record_failed("c1")
        assert eval_ctx.is_correct["c1"] is False
        assert eval_ctx.optimized_runtimes["c1"] is None
        assert eval_ctx.speedup_ratios["c1"] is None

    def test_record_success(self, eval_ctx):
        """
        record_success captures runtime, speedup, and a correct flag.
        """
        eval_ctx.record_success("c1", runtime=500.0, speedup=1.5)
        assert eval_ctx.is_correct["c1"] is True
        assert eval_ctx.optimized_runtimes["c1"] == 500.0
        assert eval_ctx.speedup_ratios["c1"] == 1.5

    def test_get_speedup_missing(self, eval_ctx):
        """
        Unknown candidate ids yield a None speedup.
        """
        assert eval_ctx.get_speedup("unknown") is None

    def test_get_runtime_missing(self, eval_ctx):
        """
        Unknown candidate ids yield a None runtime.
        """
        assert eval_ctx.get_runtime("unknown") is None

    def test_get_speedup_after_record(self, eval_ctx):
        """
        get_speedup reflects the recorded value.
        """
        eval_ctx.record_success("c1", runtime=100.0, speedup=2.0)
        assert eval_ctx.get_speedup("c1") == 2.0

    def test_get_runtime_after_record(self, eval_ctx):
        """
        get_runtime reflects the recorded value.
        """
        eval_ctx.record_success("c1", runtime=100.0, speedup=2.0)
        assert eval_ctx.get_runtime("c1") == 100.0

    def test_record_line_profile(self, eval_ctx):
        """
        record_line_profile stores the raw profiler output string.
        """
        eval_ctx.record_line_profile("c1", "profile output")
        assert eval_ctx.line_profiler_results["c1"] == "profile output"

    def test_register_new(self, eval_ctx):
        """
        register_new indexes the candidate by normalized code with diff data.
        """
        eval_ctx.register_new(
            normalized_code="norm",
            candidate_id="c1",
            flat_code="def f(): pass",
            original_flat_code="def f(): return 1",
        )
        entry = eval_ctx.code_to_id["norm"]
        assert entry["candidate_id"] == "c1"
        assert entry["shorter_code"] == "def f(): pass"
        assert entry["diff_len"] > 0

    def test_handle_duplicate_copies_prior_results(self, eval_ctx):
        """
        handle_duplicate copies correctness and metrics from the earlier twin.
        """
        eval_ctx.record_success("c1", runtime=100.0, speedup=2.0)
        eval_ctx.register_new(
            normalized_code="norm",
            candidate_id="c1",
            flat_code="def f(): pass",
            original_flat_code="def g(): pass",
        )
        eval_ctx.handle_duplicate(
            candidate_id="c2",
            normalized_code="norm",
            original_flat_code="def g(): pass",
            flat_code="def f(): pass",
        )
        assert eval_ctx.is_correct["c2"] is True
        assert eval_ctx.optimized_runtimes["c2"] == 100.0
        assert eval_ctx.speedup_ratios["c2"] == 2.0

    def test_handle_duplicate_updates_shorter_code(self, eval_ctx):
        """
        handle_duplicate keeps whichever flat code has the smaller diff.
        """
        eval_ctx.register_new(
            normalized_code="norm",
            candidate_id="c1",
            flat_code="def f():\n x = 1\n return x\n",
            original_flat_code="def f():\n return 1\n",
        )
        eval_ctx.record_success("c1", runtime=100.0, speedup=2.0)
        eval_ctx.handle_duplicate(
            candidate_id="c2",
            normalized_code="norm",
            original_flat_code="def f():\n return 1\n",
            flat_code="def f():\n return 1\n",
        )
        assert eval_ctx.code_to_id["norm"]["shorter_code"] == (
            "def f():\n return 1\n"
        )

    def test_starts_empty(self, eval_ctx):
        """
        A freshly constructed context has no recorded data.
        """
        assert eval_ctx.speedup_ratios == {}
        assert eval_ctx.optimized_runtimes == {}
        assert eval_ctx.is_correct == {}
        assert eval_ctx.valid_candidates == []
class TestDedupCandidates:
    """Unit tests covering dedup_candidates."""

    def test_empty_input(self):
        """
        No candidates in means no candidates out.
        """
        kept = dedup_candidates(
            [],
            normalize_fn=lambda code: code,
            original_normalized="original",
        )
        assert kept == []

    def test_removes_identical_to_original(self):
        """
        A candidate normalizing to the original is dropped.
        """
        cand = Candidate(code="original", explanation="", candidate_id="c1")
        kept = dedup_candidates(
            [cand],
            normalize_fn=lambda code: code,
            original_normalized="original",
        )
        assert kept == []

    def test_removes_intra_batch_duplicates(self):
        """
        Within one batch only the first of two twins survives.
        """
        first = Candidate(code="opt", explanation="", candidate_id="c1")
        second = Candidate(code="opt", explanation="", candidate_id="c2")
        kept = dedup_candidates(
            [first, second],
            normalize_fn=lambda code: code,
            original_normalized="original",
        )
        assert len(kept) == 1
        assert kept[0].candidate_id == "c1"

    def test_removes_cross_batch_duplicates(self):
        """
        A candidate already present in cross_batch is dropped.
        """
        cand = Candidate(code="opt", explanation="", candidate_id="c1")
        kept = dedup_candidates(
            [cand],
            normalize_fn=lambda code: code,
            original_normalized="original",
            cross_batch={"opt": {"candidate_id": "prior"}},
        )
        assert kept == []

    def test_keeps_unique_candidates(self):
        """
        Distinct candidates survive in their input order.
        """
        first = Candidate(code="a", explanation="", candidate_id="c1")
        second = Candidate(code="b", explanation="", candidate_id="c2")
        kept = dedup_candidates(
            [first, second],
            normalize_fn=lambda code: code,
            original_normalized="original",
        )
        assert len(kept) == 2
        assert [cand.candidate_id for cand in kept] == ["c1", "c2"]

    def test_normalize_fn_failure_keeps_candidate(self):
        """
        Normalization errors do not discard the candidate.
        """
        def exploding_normalize(code):
            raise ValueError("boom")

        cand = Candidate(code="a", explanation="", candidate_id="c1")
        kept = dedup_candidates(
            [cand],
            normalize_fn=exploding_normalize,
            original_normalized="original",
        )
        assert len(kept) == 1
        assert kept[0].candidate_id == "c1"

    def test_seen_set_is_populated(self):
        """
        The caller-supplied seen set records normalized codes.
        """
        seen: set[str] = set()
        cand = Candidate(code="a", explanation="", candidate_id="c1")
        dedup_candidates(
            [cand],
            normalize_fn=lambda code: code,
            original_normalized="original",
            seen=seen,
        )
        assert "a" in seen

    def test_seen_set_prevents_second_batch(self):
        """
        Codes from earlier batches count as duplicates via seen.
        """
        seen: set[str] = {"a"}
        cand = Candidate(code="a", explanation="", candidate_id="c1")
        kept = dedup_candidates(
            [cand],
            normalize_fn=lambda code: code,
            original_normalized="original",
            seen=seen,
        )
        assert kept == []
class TestFilterRefinedCandidates:
    """Unit tests covering filter_refined_candidates."""

    def test_under_limit_returns_all(self, eval_ctx):
        """
        Fewer candidates than the cap pass through untouched.
        """
        forest = CandidateForest()
        pool = [
            Candidate(code="a", explanation="", candidate_id="c1"),
            Candidate(code="b", explanation="", candidate_id="c2"),
        ]
        kept = filter_refined_candidates(
            pool, eval_ctx, forest, "original", max_candidates=5
        )
        assert len(kept) == 2

    def test_at_limit_returns_all(self, eval_ctx):
        """
        A pool exactly at the cap is returned whole.
        """
        forest = CandidateForest()
        pool = []
        for idx in range(5):
            pool.append(
                Candidate(code="a", explanation="", candidate_id=f"c{idx}")
            )
        kept = filter_refined_candidates(
            pool, eval_ctx, forest, "original", max_candidates=5
        )
        assert len(kept) == 5

    def test_over_limit_trims(self, eval_ctx):
        """
        A pool above the cap is reduced to max_candidates entries.
        """
        forest = CandidateForest()
        pool = []
        for idx in range(10):
            pool.append(
                Candidate(
                    code=f"code{idx}", explanation="", candidate_id=f"c{idx}"
                )
            )
        kept = filter_refined_candidates(
            pool, eval_ctx, forest, "original", max_candidates=3
        )
        assert len(kept) == 3

    def test_prefers_lower_parent_runtime(self):
        """
        Children of faster parents outrank children of slower ones.
        """
        eval_ctx = EvaluationContext()
        forest = CandidateForest()
        forest.add(Candidate(code="pf", explanation="", candidate_id="pf"))
        forest.add(Candidate(code="ps", explanation="", candidate_id="ps"))
        eval_ctx.record_success("pf", runtime=10.0, speedup=5.0)
        eval_ctx.record_success("ps", runtime=1000.0, speedup=0.1)
        template = [
            Candidate(
                code="slow_child",
                explanation="",
                candidate_id="c_slow",
                parent_id="ps",
            ),
            Candidate(
                code="fast_child",
                explanation="",
                candidate_id="c_fast",
                parent_id="pf",
            ),
        ] * 4
        pool = [
            attrs.evolve(cand, candidate_id=f"{cand.candidate_id}_{idx}")
            for idx, cand in enumerate(template)
        ]
        kept = filter_refined_candidates(
            pool, eval_ctx, forest, "original", max_candidates=2
        )
        assert len(kept) == 2
        assert all("c_fast" in survivor.candidate_id for survivor in kept)
class TestSelectBest:
    """Unit tests covering select_best."""

    def test_empty_returns_none(self, eval_ctx):
        """
        With no candidates there is nothing to select.
        """
        assert select_best(eval_ctx, 1000, [], []) is None

    def test_single_candidate(self, eval_ctx):
        """
        A lone candidate is always the winner.
        """
        eval_ctx.record_success("c1", runtime=50.0, speedup=1.0)
        assert select_best(eval_ctx, 1000, [10], ["c1"]) == "c1"

    def test_picks_faster_candidate(self, eval_ctx):
        """
        Equal diff lengths leave runtime as the deciding factor.
        """
        eval_ctx.record_success("c1", runtime=500.0, speedup=1.0)
        eval_ctx.record_success("c2", runtime=100.0, speedup=4.0)
        assert select_best(eval_ctx, 1000, [10, 10], ["c1", "c2"]) == "c2"

    def test_diff_length_contributes_to_score(self, eval_ctx):
        """
        With identical runtimes the shorter diff (listed first) wins.
        """
        eval_ctx.record_success("short_diff", runtime=100.0, speedup=1.0)
        eval_ctx.record_success("long_diff", runtime=100.0, speedup=1.0)
        winner = select_best(
            eval_ctx, 1000, [5, 500], ["short_diff", "long_diff"]
        )
        assert winner == "short_diff"

    def test_runtime_weighted_more_than_diff(self, eval_ctx):
        """
        A large runtime advantage outweighs a longer diff.
        """
        eval_ctx.record_success("c1", runtime=500.0, speedup=1.0)
        eval_ctx.record_success("c2", runtime=50.0, speedup=9.0)
        assert select_best(eval_ctx, 1000, [5, 100], ["c1", "c2"]) == "c2"

    def test_missing_runtime_uses_original(self, eval_ctx):
        """
        Candidates with no recorded runtime fall back to original_runtime_ns.
        """
        eval_ctx.record_success("c1", runtime=50.0, speedup=1.0)
        assert select_best(eval_ctx, 1000, [10, 10], ["c1", "c2"]) == "c1"

# ===== end of test_pipeline.py — next file: test_normalizer.py (diff hunk @ -0,0 +1,336 @@) =====
from __future__ import annotations
import textwrap
from codeflash_python.analysis._normalizer import normalize_python_code
class TestBasicNormalization:
    """Unit tests for canonical renaming of local variables."""

    def test_single_variable_renamed(self):
        """
        The only local is rewritten as var_0 everywhere it appears.
        """
        code = textwrap.dedent("""\
            def f():
                x = 1
                return x
        """)
        normalized = normalize_python_code(code)
        assert "var_0 = 1" in normalized
        assert "return var_0" in normalized

    def test_multiple_variables_renamed_sequentially(self):
        """
        Locals are numbered var_0, var_1, ... by order of first assignment.
        """
        code = textwrap.dedent("""\
            def f():
                alpha = 1
                beta = 2
                return alpha + beta
        """)
        normalized = normalize_python_code(code)
        assert "var_0 = 1" in normalized
        assert "var_1 = 2" in normalized
        assert "return var_0 + var_1" in normalized

    def test_same_variable_reused_gets_same_canonical_name(self):
        """
        Re-assignments and reads of one variable share one canonical name.
        """
        code = textwrap.dedent("""\
            def f():
                total = 0
                total = total + 1
                return total
        """)
        normalized = normalize_python_code(code)
        assert "var_0 = 0" in normalized
        assert "var_0 = var_0 + 1" in normalized
        assert "return var_0" in normalized
class TestFunctionNamePreservation:
    """Unit tests verifying definition names are never renamed."""

    def test_function_name_preserved(self):
        """
        The enclosing function name is left untouched.
        """
        code = textwrap.dedent("""\
            def compute_sum():
                x = 1
                return x
        """)
        normalized = normalize_python_code(code)
        assert "def compute_sum():" in normalized

    def test_class_name_preserved(self):
        """
        The enclosing class name is left untouched.
        """
        code = textwrap.dedent("""\
            class MyClass:
                def method(self):
                    x = 1
                    return x
        """)
        normalized = normalize_python_code(code)
        assert "class MyClass:" in normalized
class TestParameterPreservation:
    """Unit tests verifying parameters keep their original names."""

    def test_positional_parameter_preserved(self):
        """
        A positional parameter survives renaming.
        """
        code = textwrap.dedent("""\
            def f(value):
                x = value + 1
                return x
        """)
        normalized = normalize_python_code(code)
        assert "value" in normalized
        assert "var_0 = value + 1" in normalized

    def test_args_and_kwargs_preserved(self):
        """
        Starred *args/**kwargs parameters survive renaming.
        """
        code = textwrap.dedent("""\
            def f(*args, **kwargs):
                x = args
                y = kwargs
                return x, y
        """)
        normalized = normalize_python_code(code)
        assert "args" in normalized
        assert "kwargs" in normalized

    def test_keyword_only_parameter_preserved(self):
        """
        Keyword-only parameters survive renaming.
        """
        code = textwrap.dedent("""\
            def f(*, key):
                x = key
                return x
        """)
        normalized = normalize_python_code(code)
        assert "key" in normalized
        assert "var_0 = key" in normalized

    def test_self_parameter_preserved(self):
        """
        The conventional self parameter on methods survives renaming.
        """
        code = textwrap.dedent("""\
            class C:
                def method(self):
                    x = self
                    return x
        """)
        normalized = normalize_python_code(code)
        assert "self" in normalized
class TestImportPreservation:
    """Unit tests verifying imported names keep their spelling."""

    def test_import_name_preserved(self):
        """
        Plain import bindings survive renaming.
        """
        code = textwrap.dedent("""\
            import os
            def f():
                x = os.path.join('a', 'b')
                return x
        """)
        normalized = normalize_python_code(code)
        assert "os" in normalized

    def test_from_import_name_preserved(self):
        """
        from-import bindings survive renaming.
        """
        code = textwrap.dedent("""\
            from os.path import join
            def f():
                x = join('a', 'b')
                return x
        """)
        normalized = normalize_python_code(code)
        assert "join" in normalized

    def test_aliased_import_preserved(self):
        """
        Aliased imports survive under their alias name.
        """
        code = textwrap.dedent("""\
            import numpy as np
            def f():
                x = np.array([1])
                return x
        """)
        normalized = normalize_python_code(code)
        assert "np" in normalized
class TestDocstringRemoval:
    """Unit tests for stripping docstrings during normalization."""

    def test_docstring_removed_by_default(self):
        """
        Function docstrings vanish under the default settings.
        """
        code = textwrap.dedent('''\
            def f():
                """This is a docstring."""
                x = 1
                return x
        ''')
        normalized = normalize_python_code(code)
        assert "docstring" not in normalized

    def test_module_docstring_removed(self):
        """
        Module-level docstrings vanish as well.
        """
        code = textwrap.dedent('''\
            """Module docstring."""
            def f():
                x = 1
                return x
        ''')
        normalized = normalize_python_code(code)
        assert "Module docstring" not in normalized

    def test_class_docstring_removed(self):
        """
        Class-level docstrings vanish as well.
        """
        code = textwrap.dedent('''\
            class C:
                """Class docstring."""
                def method(self):
                    x = 1
                    return x
        ''')
        normalized = normalize_python_code(code)
        assert "Class docstring" not in normalized
class TestDocstringPreservation:
    """Unit tests for keeping docstrings when remove_docstrings=False."""

    def test_docstring_preserved_when_flag_false(self):
        """
        Function docstrings survive when the flag is disabled.
        """
        code = textwrap.dedent('''\
            def f():
                """This is a docstring."""
                x = 1
                return x
        ''')
        normalized = normalize_python_code(code, remove_docstrings=False)
        assert "This is a docstring." in normalized

    def test_all_docstrings_preserved_when_flag_false(self):
        """
        Module, class, and method docstrings all survive together.
        """
        code = textwrap.dedent('''\
            """Module doc."""
            class C:
                """Class doc."""
                def method(self):
                    """Method doc."""
                    x = 1
                    return x
        ''')
        normalized = normalize_python_code(code, remove_docstrings=False)
        assert "Module doc." in normalized
        assert "Class doc." in normalized
        assert "Method doc." in normalized
class TestSyntaxErrorHandling:
    """Unit tests for graceful handling of unparseable input."""

    def test_syntax_error_returns_original(self):
        """
        Syntactically broken code is echoed back untouched.
        """
        broken = "def f(\n x = :\n"
        assert normalize_python_code(broken) == broken

    def test_incomplete_code_returns_original(self):
        """
        A truncated definition is echoed back untouched.
        """
        truncated = "def f("
        assert normalize_python_code(truncated) == truncated
class TestEmptyInput:
    """Unit tests for empty and whitespace-only input."""

    def test_empty_string(self):
        """
        Empty input normalizes to the empty string.
        """
        assert normalize_python_code("") == ""

    def test_whitespace_only(self):
        """
        Pure whitespace also normalizes to the empty string.
        """
        assert normalize_python_code(" \n\n ") == ""
class TestMultipleFunctions:
    """Unit tests for inputs containing several function definitions."""

    def test_variables_in_separate_functions_get_independent_names(self):
        """
        Each function's locals restart numbering at var_0 in its own scope.
        """
        code = textwrap.dedent("""\
            def first():
                alpha = 1
                return alpha
            def second():
                beta = 2
                return beta
        """)
        normalized = normalize_python_code(code)
        assert "def first():" in normalized
        assert "def second():" in normalized
        assert "var_0 = 1" in normalized
        assert "var_0 = 2" in normalized

    def test_nested_function_variables_normalized(self):
        """
        Nested scopes continue numbering from the enclosing function.
        """
        code = textwrap.dedent("""\
            def outer():
                x = 1
                def inner():
                    y = 2
                    return y
                return x
        """)
        normalized = normalize_python_code(code)
        assert "def outer():" in normalized
        assert "def inner():" in normalized
        assert "var_0 = 1" in normalized
        assert "var_1 = 2" in normalized