diff --git a/packages/codeflash-core/tests/test_danom_result.py b/packages/codeflash-core/tests/test_danom_result.py new file mode 100644 index 0000000..98dd97b --- /dev/null +++ b/packages/codeflash-core/tests/test_danom_result.py @@ -0,0 +1,490 @@ +from __future__ import annotations + +import attrs +import pytest + +from codeflash_core.danom.result import Err, Ok, Result + + +class TestOk: + """Tests for Ok variant of Result.""" + + def test_construction_with_value(self): + """ + Ok wraps the given value. + """ + ok = Ok(42) + + assert 42 == ok.inner + + def test_construction_default_none(self): + """ + Ok defaults to None when no value is given. + """ + ok = Ok() + + assert ok.inner is None + + def test_is_ok_returns_true(self): + """ + is_ok returns True for Ok. + """ + assert Ok(1).is_ok() is True + + def test_unwrap_returns_inner(self): + """ + unwrap returns the wrapped value. + """ + assert "hello" == Ok("hello").unwrap() + + def test_unwrap_none_value(self): + """ + unwrap returns None when Ok wraps None. + """ + assert Ok(None).unwrap() is None + + def test_map_applies_func(self): + """ + map applies the function to the inner value and wraps in Ok. + """ + result = Ok(3).map(lambda x: x * 2) + + assert Ok(6) == result + + def test_map_with_extra_args(self): + """ + map passes extra positional and keyword args to the function. + """ + result = Ok(10).map(lambda x, y, z=0: x + y + z, 5, z=3) + + assert Ok(18) == result + + def test_map_err_is_noop(self): + """ + map_err returns self unchanged for Ok. + """ + ok = Ok(42) + result = ok.map_err(lambda e: "transformed") + + assert result is ok + + def test_and_then_returns_func_result(self): + """ + and_then applies the function and returns its raw result. + """ + result = Ok(5).and_then(lambda x: Ok(x + 1)) + + assert Ok(6) == result + + def test_and_then_with_extra_args(self): + """ + and_then passes extra positional and keyword args to the function. + """ + result = Ok(1).and_then(lambda x, y, z=0: Ok(x + y + z), 2, z=3) + + assert Ok(6) == result + + def test_and_then_can_return_err(self): + """ + and_then can return an Err from the function. + """ + result = Ok(5).and_then(lambda x: Err("fail")) + + assert Err("fail") == result + + def test_or_else_is_noop(self): + """ + or_else returns self unchanged for Ok. + """ + ok = Ok(42) + result = ok.or_else(lambda e: Ok(0)) + + assert result is ok + + +class TestErr: + """Tests for Err variant of Result.""" + + def test_construction_with_error(self): + """ + Err wraps the given error value. + """ + err = Err("bad") + + assert "bad" == err.error + + def test_construction_default_none(self): + """ + Err defaults to None error when no value is given. + """ + err = Err() + + assert err.error is None + + def test_is_ok_returns_false(self): + """ + is_ok returns False for Err. + """ + assert Err("x").is_ok() is False + + def test_unwrap_raises_wrapped_exception(self): + """ + unwrap re-raises the wrapped Exception. + """ + exc = ValueError("boom") + + with pytest.raises(ValueError, match="boom"): + Err(exc).unwrap() + + def test_unwrap_non_exception_raises_valueerror(self): + """ + unwrap raises ValueError when the error is not an Exception. + """ + with pytest.raises( + ValueError, match="Err does not have a caught error" + ): + Err("not an exception").unwrap() + + def test_unwrap_none_error_raises_valueerror(self): + """ + unwrap raises ValueError when the error is None. + """ + with pytest.raises( + ValueError, match="Err does not have a caught error" + ): + Err(None).unwrap() + + def test_map_is_noop(self): + """ + map returns self unchanged for Err. + """ + err = Err("fail") + result = err.map(lambda x: x * 2) + + assert result is err + + def test_map_err_applies_func(self): + """ + map_err applies the function to the error value. + """ + result = Err("bad").map_err(lambda e: e.upper()) + + assert Err("BAD") == result + + def test_map_err_with_extra_args(self): + """ + map_err passes extra positional and keyword args to the function. + """ + result = Err(10).map_err(lambda e, x, y=0: e + x + y, 5, y=3) + + assert Err(18) == result + + def test_and_then_is_noop(self): + """ + and_then returns self unchanged for Err. + """ + err = Err("fail") + result = err.and_then(lambda x: Ok(x + 1)) + + assert result is err + + def test_or_else_applies_func(self): + """ + or_else applies the function to the error and returns its result. + """ + result = Err("fail").or_else(lambda e: Ok("recovered")) + + assert Ok("recovered") == result + + def test_or_else_with_extra_args(self): + """ + or_else passes extra positional and keyword args to the function. + """ + result = Err("x").or_else( + lambda e, suffix, sep="-": Ok(e + sep + suffix), "y", sep="+" + ) + + assert Ok("x+y") == result + + def test_or_else_can_return_err(self): + """ + or_else can return another Err from the function. + """ + result = Err("first").or_else(lambda e: Err("second")) + + assert Err("second") == result + + def test_input_args_stored(self): + """ + input_args are stored on the Err instance. + """ + args = (("a", "b"), {"key": "val"}) + err = Err("fail", input_args=args) + + assert args == err.input_args + + def test_input_args_default_empty_tuple(self): + """ + input_args defaults to an empty tuple. + """ + assert () == Err("fail").input_args + + def test_traceback_stored(self): + """ + traceback string is stored on the Err instance. + """ + err = Err("fail", traceback="line 1\nline 2") + + assert "line 1\nline 2" == err.traceback + + def test_traceback_default_empty_string(self): + """ + traceback defaults to an empty string. + """ + assert "" == Err("fail").traceback + + +class TestErrEquality: + """Tests for Err.__eq__ and __hash__.""" + + def test_equal_errs_with_same_exception_type_and_message(self): + """ + Two Errs with same exception type, message, and input_args are equal. + """ + assert Err(ValueError("x")) == Err(ValueError("x")) + + def test_unequal_errs_different_message(self): + """ + Errs with different messages are not equal. + """ + assert Err(ValueError("x")) != Err(ValueError("y")) + + def test_unequal_errs_different_type(self): + """ + Errs with different exception types are not equal. + """ + assert Err(ValueError("x")) != Err(TypeError("x")) + + def test_unequal_errs_different_input_args(self): + """ + Errs with different input_args are not equal. + """ + err1 = Err("x", input_args=(("a",), {})) + err2 = Err("x", input_args=(("b",), {})) + + assert err1 != err2 + + def test_err_not_equal_to_non_err(self): + """ + Err is not equal to a non-Err object. + """ + assert Err("x") != "x" + assert Err(42) != 42 + + def test_equal_errs_with_string_errors(self): + """ + Two Errs with identical string errors are equal. + """ + assert Err("fail") == Err("fail") + + def test_hash_consistent_with_equality(self): + """ + Equal Err instances produce the same hash. + """ + err1 = Err(ValueError("x")) + err2 = Err(ValueError("x")) + + assert hash(err1) == hash(err2) + + def test_hash_differs_for_unequal(self): + """ + Different Err instances are unlikely to share a hash. + """ + assert hash(Err("a")) != hash(Err("b")) + + +class TestOkEquality: + """Tests for Ok equality and hashing (attrs-generated).""" + + def test_equal_ok_values(self): + """ + Two Ok instances with the same inner value are equal. + """ + assert Ok(42) == Ok(42) + + def test_unequal_ok_values(self): + """ + Two Ok instances with different inner values are not equal. + """ + assert Ok(1) != Ok(2) + + def test_ok_not_equal_to_err(self): + """ + Ok and Err are never equal. + """ + assert Ok(1) != Err(1) + + def test_ok_not_equal_to_plain_value(self): + """ + Ok is not equal to a plain unwrapped value. + """ + assert Ok(42) != 42 + + def test_hash_consistent_with_equality(self): + """ + Equal Ok instances produce the same hash. + """ + assert hash(Ok("a")) == hash(Ok("a")) + + def test_ok_usable_in_sets(self): + """ + Ok instances can be stored in a set. + """ + s = {Ok(1), Ok(1), Ok(2)} + + assert 2 == len(s) + + +class TestResultStaticMethods: + """Tests for Result.unit, result_is_ok, and result_unwrap.""" + + def test_unit_wraps_in_ok(self): + """ + Result.unit wraps a value in Ok. + """ + result = Result.unit(99) + + assert Ok(99) == result + + def test_result_is_ok_for_ok(self): + """ + result_is_ok returns True for an Ok. + """ + assert Result.result_is_ok(Ok(1)) is True + + def test_result_is_ok_for_err(self): + """ + result_is_ok returns False for an Err. + """ + assert Result.result_is_ok(Err("fail")) is False + + def test_result_unwrap_ok(self): + """ + result_unwrap returns the inner value for Ok. + """ + assert 42 == Result.result_unwrap(Ok(42)) + + def test_result_unwrap_err_raises(self): + """ + result_unwrap raises for an Err with an Exception. + """ + with pytest.raises(RuntimeError, match="boom"): + Result.result_unwrap(Err(RuntimeError("boom"))) + + +class TestNestedResult: + """Tests for nested Result values.""" + + def test_ok_wrapping_ok(self): + """ + Ok can wrap another Ok. + """ + nested = Ok(Ok(42)) + + assert Ok(42) == nested.unwrap() + assert 42 == nested.unwrap().unwrap() + + def test_ok_wrapping_err(self): + """ + Ok can wrap an Err. + """ + nested = Ok(Err("inner")) + inner = nested.unwrap() + + assert False is inner.is_ok() + assert "inner" == inner.error + + def test_map_on_nested_ok(self): + """ + map on a nested Ok applies the function to the inner Result. + """ + nested = Ok(Ok(5)) + result = nested.map(lambda inner_ok: inner_ok.map(lambda x: x * 2)) + + assert Ok(Ok(10)) == result + + +class TestErrTracebackExtraction: + """Tests for Err traceback detail extraction.""" + + def test_details_populated_from_exception_traceback(self): + """ + details list is populated when error is an Exception with a traceback. + """ + try: + msg = "test error" + raise ValueError(msg) # noqa: TRY301 + except ValueError as exc: + err = Err(exc) + + assert len(err.details) > 0 + assert "file" in err.details[0] + assert "func" in err.details[0] + assert "line_no" in err.details[0] + assert "locals" in err.details[0] + + def test_details_empty_for_non_exception(self): + """ + details list is empty when error is not an Exception. + """ + err = Err("just a string") + + assert [] == err.details + + def test_details_empty_for_exception_without_traceback(self): + """ + details list is empty for an Exception created without raising. + """ + err = Err(ValueError("no traceback")) + + assert [] == err.details + + +class TestErrImmutability: + """Tests for frozen behavior of Ok and Err.""" + + def test_ok_is_frozen(self): + """ + Ok instances cannot be mutated. + """ + ok = Ok(42) + + with pytest.raises(attrs.exceptions.FrozenInstanceError): + ok.inner = 99 # type: ignore[misc] + + def test_err_is_frozen(self): + """ + Err instances cannot be mutated. + """ + err = Err("fail") + + with pytest.raises(attrs.exceptions.FrozenInstanceError): + err.error = "other" # type: ignore[misc] + + +class TestErrValidators: + """Tests for Err field validators.""" + + def test_input_args_must_be_tuple(self): + """ + input_args rejects non-tuple values. + """ + with pytest.raises(TypeError): + Err("fail", input_args=["not", "a", "tuple"]) # type: ignore[arg-type] + + def test_traceback_must_be_string(self): + """ + traceback rejects non-string values. + """ + with pytest.raises(TypeError): + Err("fail", traceback=123) # type: ignore[arg-type] diff --git a/packages/codeflash-core/tests/test_danom_stream.py b/packages/codeflash-core/tests/test_danom_stream.py new file mode 100644 index 0000000..3e87e3a --- /dev/null +++ b/packages/codeflash-core/tests/test_danom_stream.py @@ -0,0 +1,761 @@ +from __future__ import annotations + +import asyncio + +import attrs +import pytest + +from codeflash_core.danom.stream import ( + _FILTER, + _MAP, + _TAP, + Stream, + _apply_fns, + _Nothing, + _par_apply_fns, +) + + +class TestStreamFromIterable: + """Tests for Stream.from_iterable.""" + + def test_from_list(self): + """ + A list is converted into a Stream with a tuple seq. + """ + s = Stream.from_iterable([1, 2, 3]) + + assert (1, 2, 3) == s.seq + + def test_from_tuple(self): + """ + A tuple is accepted as an iterable. + """ + s = Stream.from_iterable((4, 5)) + + assert (4, 5) == s.seq + + def test_from_generator(self): + """ + A generator is consumed and stored as a tuple. + """ + s = Stream.from_iterable(x * 2 for x in range(3)) + + assert (0, 2, 4) == s.seq + + def test_from_empty(self): + """ + An empty iterable produces an empty stream. + """ + s = Stream.from_iterable([]) + + assert () == s.seq + + def test_from_string(self): + """ + A string iterable yields individual characters. + """ + s = Stream.from_iterable("abc") + + assert ("a", "b", "c") == s.seq + + def test_from_range(self): + """ + A range object is accepted as an iterable. + """ + s = Stream.from_iterable(range(4)) + + assert (0, 1, 2, 3) == s.seq + + def test_ops_default_empty(self): + """ + A fresh stream has no queued operations. + """ + s = Stream.from_iterable([1]) + + assert () == s.ops + + +class TestStreamBool: + """Tests for Stream.__bool__.""" + + def test_truthy_when_nonempty(self): + """ + A stream with elements is truthy. + """ + s = Stream.from_iterable([1]) + + assert bool(s) is True + + def test_falsy_when_empty(self): + """ + An empty stream is falsy. + """ + s = Stream.from_iterable([]) + + assert bool(s) is False + + +class TestStreamMap: + """Tests for Stream.map.""" + + def test_single_map(self): + """ + A single map function is applied to each element. + """ + result = ( + Stream.from_iterable([1, 2, 3]).map(lambda x: x * 10).collect() + ) + + assert (10, 20, 30) == result + + def test_multiple_fns_in_one_call(self): + """ + Multiple functions passed to a single map call are applied sequentially. + """ + result = ( + Stream.from_iterable([1, 2]) + .map(lambda x: x + 1, lambda x: x * 3) + .collect() + ) + + assert (6, 9) == result + + def test_map_empty_stream(self): + """ + Mapping over an empty stream produces an empty tuple. + """ + result = Stream.from_iterable([]).map(lambda x: x + 1).collect() + + assert () == result + + def test_map_preserves_immutability(self): + """ + Calling map returns a new stream without mutating the original. + """ + s1 = Stream.from_iterable([1, 2]) + s2 = s1.map(lambda x: x * 2) + + assert () == s1.ops + assert 1 == len(s2.ops) + assert s1 is not s2 + + +class TestStreamFilter: + """Tests for Stream.filter.""" + + def test_single_filter(self): + """ + A single filter predicate removes non-matching elements. + """ + result = ( + Stream.from_iterable([1, 2, 3, 4]) + .filter(lambda x: x % 2 == 0) + .collect() + ) + + assert (2, 4) == result + + def test_filter_removes_all(self): + """ + A predicate that matches nothing produces an empty tuple. + """ + result = ( + Stream.from_iterable([1, 3, 5]).filter(lambda x: x > 100).collect() + ) + + assert () == result + + def test_filter_keeps_all(self): + """ + A predicate that matches everything keeps all elements. + """ + result = ( + Stream.from_iterable([2, 4, 6]) + .filter(lambda x: x % 2 == 0) + .collect() + ) + + assert (2, 4, 6) == result + + def test_multiple_filters(self): + """ + Multiple filter predicates in one call are applied sequentially. + """ + result = ( + Stream.from_iterable(range(20)) + .filter(lambda x: x % 2 == 0, lambda x: x > 5) + .collect() + ) + + assert (6, 8, 10, 12, 14, 16, 18) == result + + def test_filter_empty_stream(self): + """ + Filtering an empty stream produces an empty tuple. + """ + result = Stream.from_iterable([]).filter(lambda x: True).collect() + + assert () == result + + +class TestStreamTap: + """Tests for Stream.tap.""" + + def test_tap_side_effect(self): + """ + Tap executes a side effect without altering elements. + """ + seen = [] + result = Stream.from_iterable([1, 2, 3]).tap(seen.append).collect() + + assert (1, 2, 3) == result + assert [1, 2, 3] == seen + + def test_tap_receives_deepcopy(self): + """ + Tap receives a deep copy so mutations do not affect the stream. + """ + originals = [{"a": 1}, {"a": 2}] + captured = [] + + def mutating_tap(x): + x["a"] = 999 + captured.append(x) + + result = Stream.from_iterable(originals).tap(mutating_tap).collect() + + assert ({"a": 1}, {"a": 2}) == result + assert [{"a": 999}, {"a": 999}] == captured + + +class TestStreamChaining: + """Tests for chaining multiple operations.""" + + def test_map_then_filter(self): + """ + Map followed by filter applies both in order. + """ + result = ( + Stream.from_iterable([1, 2, 3, 4, 5]) + .map(lambda x: x * 2) + .filter(lambda x: x > 4) + .collect() + ) + + assert (6, 8, 10) == result + + def test_filter_then_map(self): + """ + Filter followed by map applies both in order. + """ + result = ( + Stream.from_iterable([1, 2, 3, 4, 5]) + .filter(lambda x: x > 2) + .map(lambda x: x**2) + .collect() + ) + + assert (9, 16, 25) == result + + def test_map_filter_tap_chain(self): + """ + All three operation types can be chained together. + """ + tapped = [] + result = ( + Stream.from_iterable([1, 2, 3, 4]) + .map(lambda x: x + 10) + .filter(lambda x: x % 2 == 0) + .tap(tapped.append) + .collect() + ) + + assert (12, 14) == result + assert [12, 14] == tapped + + def test_multiple_map_calls(self): + """ + Multiple chained map calls compose sequentially. + """ + result = ( + Stream.from_iterable([1, 2]) + .map(lambda x: x + 1) + .map(lambda x: x * 10) + .collect() + ) + + assert (20, 30) == result + + +class TestStreamFold: + """Tests for Stream.fold.""" + + def test_sum(self): + """ + Fold with addition accumulates a sum. + """ + result = Stream.from_iterable([1, 2, 3, 4]).fold( + 0, lambda acc, x: acc + x + ) + + assert 10 == result + + def test_fold_with_operations(self): + """ + Fold applies queued operations before reducing. + """ + result = ( + Stream.from_iterable([1, 2, 3, 4]) + .filter(lambda x: x % 2 == 0) + .map(lambda x: x * 10) + .fold(0, lambda acc, x: acc + x) + ) + + assert 60 == result + + def test_fold_empty_stream(self): + """ + Folding an empty stream returns the initial value. + """ + result = Stream.from_iterable([]).fold(42, lambda acc, x: acc + x) + + assert 42 == result + + def test_fold_string_concat(self): + """ + Fold can concatenate strings. + """ + result = Stream.from_iterable(["a", "b", "c"]).fold( + "", lambda acc, x: acc + x + ) + + assert "abc" == result + + +class TestStreamPartition: + """Tests for Stream.partition.""" + + def test_basic_partition(self): + """ + Partition splits elements by the predicate into two streams. + """ + truthy, falsy = Stream.from_iterable([1, 2, 3, 4, 5]).partition( + lambda x: x > 3 + ) + + assert (4, 5) == truthy.collect() + assert (1, 2, 3) == falsy.collect() + + def test_partition_all_true(self): + """ + When all elements match, the false stream is empty. + """ + truthy, falsy = Stream.from_iterable([2, 4, 6]).partition( + lambda x: x % 2 == 0 + ) + + assert (2, 4, 6) == truthy.collect() + assert () == falsy.collect() + + def test_partition_all_false(self): + """ + When no elements match, the true stream is empty. + """ + truthy, falsy = Stream.from_iterable([1, 3, 5]).partition( + lambda x: x % 2 == 0 + ) + + assert () == truthy.collect() + assert (1, 3, 5) == falsy.collect() + + def test_partition_empty(self): + """ + Partitioning an empty stream yields two empty streams. + """ + truthy, falsy = Stream.from_iterable([]).partition(lambda x: x > 0) + + assert () == truthy.collect() + assert () == falsy.collect() + + def test_partition_applies_pending_ops(self): + """ + Partition materializes pending operations before splitting. + """ + truthy, falsy = ( + Stream.from_iterable([1, 2, 3, 4]) + .map(lambda x: x * 10) + .partition(lambda x: x >= 30) + ) + + assert (30, 40) == truthy.collect() + assert (10, 20) == falsy.collect() + + +class TestStreamCollect: + """Tests for Stream.collect.""" + + def test_collect_no_ops(self): + """ + Collecting with no operations returns the original elements. + """ + result = Stream.from_iterable([1, 2, 3]).collect() + + assert (1, 2, 3) == result + + def test_collect_returns_tuple(self): + """ + Collect always returns a tuple. + """ + result = Stream.from_iterable([1]).collect() + + assert isinstance(result, tuple) + + def test_collect_single_element(self): + """ + A single-element stream collects to a one-element tuple. + """ + result = Stream.from_iterable([42]).map(lambda x: x + 1).collect() + + assert (43,) == result + + +class TestStreamParCollect: + """Tests for Stream.par_collect.""" + + def test_par_collect_threads(self): + """ + Parallel collection with threads produces the same result as sequential. + """ + result = ( + Stream.from_iterable(range(20)) + .map(lambda x: x * 2) + .par_collect(workers=2, use_threads=True) + ) + + assert tuple(x * 2 for x in range(20)) == result + + def test_par_collect_empty(self): + """ + Parallel collection of an empty stream returns an empty tuple. + """ + result = Stream.from_iterable([]).par_collect( + workers=2, use_threads=True + ) + + assert () == result + + def test_par_collect_with_filter(self): + """ + Parallel collection respects filter operations. + """ + result = ( + Stream.from_iterable(range(10)) + .filter(lambda x: x % 2 == 0) + .par_collect(workers=2, use_threads=True) + ) + + assert (0, 2, 4, 6, 8) == result + + +class TestStreamAsyncCollect: + """Tests for Stream.async_collect.""" + + @pytest.mark.asyncio + async def test_async_map(self): + """ + Async map functions are awaited during collection. + """ + + async def times_ten(x): + return x * 10 + + result = await ( + Stream.from_iterable([1, 2, 3]).map(times_ten).async_collect() + ) + + assert (10, 20, 30) == result + + @pytest.mark.asyncio + async def test_async_collect_no_ops(self): + """ + Async collect with no ops falls back to synchronous collect. + """ + result = await Stream.from_iterable([1, 2]).async_collect() + + assert (1, 2) == result + + @pytest.mark.asyncio + async def test_async_filter(self): + """ + Async filter predicates remove non-matching elements. + """ + + async def gt_two(x): + return x > 2 + + result = await ( + Stream.from_iterable([1, 2, 3, 4]).filter(gt_two).async_collect() + ) + + assert (3, 4) == result + + @pytest.mark.asyncio + async def test_async_tap(self): + """ + Async tap executes side effects without altering elements. + """ + seen = [] + + async def record(x): + seen.append(x) + + result = await ( + Stream.from_iterable([10, 20]).tap(record).async_collect() + ) + + assert (10, 20) == result + assert [10, 20] == seen + + +class TestApplyFns: + """Tests for the _apply_fns helper.""" + + def test_map_op(self): + """ + MAP operations transform elements. + """ + ops = ((_MAP, lambda x: x + 1),) + result = tuple(_apply_fns((1, 2, 3), ops)) + + assert (2, 3, 4) == result + + def test_filter_op(self): + """ + FILTER operations remove non-matching elements. + """ + ops = ((_FILTER, lambda x: x > 2),) + result = tuple(_apply_fns((1, 2, 3, 4), ops)) + + assert (3, 4) == result + + def test_tap_op(self): + """ + TAP operations run side effects and pass elements through. + """ + seen = [] + ops = ((_TAP, seen.append),) + result = tuple(_apply_fns((1, 2), ops)) + + assert (1, 2) == result + assert 2 == len(seen) + + def test_combined_ops(self): + """ + Mixed MAP, FILTER, and TAP ops are applied in sequence. + """ + tapped = [] + ops = ( + (_MAP, lambda x: x * 2), + (_FILTER, lambda x: x > 4), + (_TAP, tapped.append), + ) + result = tuple(_apply_fns((1, 2, 3, 4), ops)) + + assert (6, 8) == result + assert [6, 8] == tapped + + def test_empty_ops(self): + """ + No operations yields elements unchanged. + """ + result = tuple(_apply_fns((1, 2), ())) + + assert (1, 2) == result + + def test_empty_elements(self): + """ + Empty elements yields nothing regardless of operations. + """ + ops = ((_MAP, lambda x: x + 1),) + result = tuple(_apply_fns((), ops)) + + assert () == result + + def test_filter_short_circuits(self): + """ + Once a filter rejects an element, subsequent ops are skipped. + """ + map_calls = [] + + def tracking_map(x): + map_calls.append(x) + return x + + ops = ( + (_FILTER, lambda x: x > 5), + (_MAP, tracking_map), + ) + tuple(_apply_fns((1, 10), ops)) + + assert [10] == map_calls + + +class TestParApplyFns: + """Tests for the _par_apply_fns helper.""" + + def test_returns_tuple(self): + """ + _par_apply_fns returns a tuple, not a generator. + """ + result = _par_apply_fns((1, 2, 3), ((_MAP, lambda x: x + 1),)) + + assert isinstance(result, tuple) + assert (2, 3, 4) == result + + def test_empty_elements(self): + """ + Empty elements produce an empty tuple. + """ + result = _par_apply_fns((), ((_MAP, lambda x: x),)) + + assert () == result + + def test_filter_and_map(self): + """ + Combined filter and map work in the eager variant. + """ + ops = ( + (_FILTER, lambda x: x % 2 == 0), + (_MAP, lambda x: x * 10), + ) + result = _par_apply_fns((1, 2, 3, 4), ops) + + assert (20, 40) == result + + +class TestNothing: + """Tests for the _Nothing sentinel.""" + + def test_nothing_is_singleton(self): + """ + NOTHING is the sole member of the _Nothing enum. + """ + assert 1 == len(_Nothing) + assert _Nothing.NOTHING is _Nothing.NOTHING + + def test_nothing_is_not_equal_to_none(self): + """ + NOTHING is distinct from None. + """ + assert _Nothing.NOTHING != None # noqa: E711 + + +class TestStreamImmutability: + """Tests for Stream's frozen/immutable behavior.""" + + def test_cannot_set_seq(self): + """ + Assigning to seq on a frozen stream raises an error. + """ + s = Stream.from_iterable([1, 2]) + + with pytest.raises(attrs.exceptions.FrozenInstanceError): + s.seq = (3, 4) + + def test_cannot_set_ops(self): + """ + Assigning to ops on a frozen stream raises an error. + """ + s = Stream.from_iterable([1, 2]) + + with pytest.raises(attrs.exceptions.FrozenInstanceError): + s.ops = () + + def test_seq_must_be_tuple(self): + """ + Constructing a stream with a non-tuple seq raises TypeError. + """ + with pytest.raises(TypeError): + Stream(seq=[1, 2, 3]) + + +class TestEdgeCases: + """Tests for edge cases and boundary conditions.""" + + def test_single_element_map(self): + """ + A stream with one element processes map correctly. + """ + result = Stream.from_iterable([5]).map(lambda x: x * 3).collect() + + assert (15,) == result + + def test_single_element_filter_keeps(self): + """ + A single element matching the filter is kept. + """ + result = Stream.from_iterable([5]).filter(lambda x: x > 0).collect() + + assert (5,) == result + + def test_single_element_filter_removes(self): + """ + A single element not matching the filter is removed. + """ + result = Stream.from_iterable([5]).filter(lambda x: x > 10).collect() + + assert () == result + + def test_none_values_in_stream(self): + """ + None values are valid stream elements. + """ + result = Stream.from_iterable([None, 1, None]).collect() + + assert (None, 1, None) == result + + def test_nested_streams(self): + """ + Streams can contain other streams as elements. + """ + inner = Stream.from_iterable([1, 2]) + outer = Stream.from_iterable([inner, inner]) + result = outer.collect() + + assert 2 == len(result) + assert (1, 2) == result[0].collect() + + def test_large_stream(self): + """ + A large stream processes without error. + """ + result = ( + Stream.from_iterable(range(10000)) + .map(lambda x: x + 1) + .filter(lambda x: x % 1000 == 0) + .collect() + ) + + assert ( + 1000, + 2000, + 3000, + 4000, + 5000, + 6000, + 7000, + 8000, + 9000, + 10000, + ) == result + + def test_fold_with_workers_threads(self): + """ + Fold with workers>1 and use_threads produces the same result. + """ + result = Stream.from_iterable(range(100)).fold( + 0, lambda acc, x: acc + x, workers=2, use_threads=True + ) + + assert 4950 == result diff --git a/packages/codeflash-core/tests/test_model.py b/packages/codeflash-core/tests/test_model.py new file mode 100644 index 0000000..1fca1c4 --- /dev/null +++ b/packages/codeflash-core/tests/test_model.py @@ -0,0 +1,665 @@ +from __future__ import annotations + +import attrs +import pytest + +from codeflash_core._model import ( + _NS_PER_DAY, + _NS_PER_HOUR, + _NS_PER_MIN, + _NS_PER_MS, + _NS_PER_S, + _NS_PER_US, + BenchmarkDetail, + Candidate, + FileDiffContent, + OptimizationRequest, + OptimizationReviewResult, + PrComment, + humanize_runtime, +) + + +class TestTimeConstants: + """Tests for nanosecond conversion constants.""" + + def test_ns_per_us(self): + """ + One microsecond equals 1,000 nanoseconds. + """ + assert 1_000 == _NS_PER_US + + def test_ns_per_ms(self): + """ + One millisecond equals 1,000,000 nanoseconds. + """ + assert 1_000_000 == _NS_PER_MS + + def test_ns_per_s(self): + """ + One second equals 1,000,000,000 nanoseconds. + """ + assert 1_000_000_000 == _NS_PER_S + + def test_ns_per_min(self): + """ + One minute equals 60,000,000,000 nanoseconds. + """ + assert 60_000_000_000 == _NS_PER_MIN + + def test_ns_per_hour(self): + """ + One hour equals 3,600,000,000,000 nanoseconds. + """ + assert 3_600_000_000_000 == _NS_PER_HOUR + + def test_ns_per_day(self): + """ + One day equals 86,400,000,000,000 nanoseconds. + """ + assert 86_400_000_000_000 == _NS_PER_DAY + + def test_constants_are_consistent(self): + """ + Each constant is derivable from the previous one. + """ + assert _NS_PER_MS == _NS_PER_US * 1_000 + assert _NS_PER_S == _NS_PER_MS * 1_000 + assert _NS_PER_MIN == _NS_PER_S * 60 + assert _NS_PER_HOUR == _NS_PER_MIN * 60 + assert _NS_PER_DAY == _NS_PER_HOUR * 24 + + +class TestOptimizationRequest: + """Tests for OptimizationRequest.""" + + def test_required_fields(self): + """ + Construction with only required fields succeeds. + """ + req = OptimizationRequest( + source_code="def f(): pass", + language="python", + language_version="3.12", + ) + assert "def f(): pass" == req.source_code + assert "python" == req.language + assert "3.12" == req.language_version + + def test_default_values(self): + """ + Optional fields use their documented defaults. + """ + req = OptimizationRequest( + source_code="x", + language="python", + language_version="3.12", + ) + assert "" == req.context_code + assert req.is_async is False + assert req.is_numerical_code is None + assert "" == req.codeflash_version + assert req.baseline_runtime_ns is None + assert req.loop_count is None + assert req.line_profiler_results is None + assert req.test_input_examples is None + + def test_all_fields(self): + """ + All fields can be set explicitly. + """ + req = OptimizationRequest( + source_code="def f(): pass", + language="javascript", + language_version="ES2022", + context_code="import os", + is_async=True, + is_numerical_code=True, + codeflash_version="1.0.0", + baseline_runtime_ns=5000, + loop_count=100, + line_profiler_results="Line # Hits", + test_input_examples="example()", + ) + assert "javascript" == req.language + assert "ES2022" == req.language_version + assert "import os" == req.context_code + assert req.is_async is True + assert req.is_numerical_code is True + assert "1.0.0" == req.codeflash_version + assert 5000 == req.baseline_runtime_ns + assert 100 == req.loop_count + assert "Line # Hits" == req.line_profiler_results + assert "example()" == req.test_input_examples + + def test_frozen(self): + """ + Mutating a field raises FrozenInstanceError. + """ + req = OptimizationRequest( + source_code="x", + language="python", + language_version="3.12", + ) + with pytest.raises(attrs.exceptions.FrozenInstanceError): + req.source_code = "y" + + def test_empty_source_code(self): + """ + An empty source_code string is allowed. + """ + req = OptimizationRequest( + source_code="", + language="python", + language_version="3.12", + ) + assert "" == req.source_code + + +class TestCandidate: + """Tests for Candidate.""" + + def test_required_fields(self): + """ + Construction with only required fields succeeds. + """ + c = Candidate(code="def f(): pass", explanation="simplified") + assert "def f(): pass" == c.code + assert "simplified" == c.explanation + + def test_default_values(self): + """ + Optional fields default to empty strings. + """ + c = Candidate(code="x", explanation="y") + assert "" == c.candidate_id + assert "" == c.source + assert "" == c.parent_id + assert "" == c.code_markdown + + def test_all_fields(self): + """ + All fields can be set explicitly. + """ + c = Candidate( + code="def f(): pass", + explanation="optimized", + candidate_id="abc-123", + source="ai-service", + parent_id="parent-0", + code_markdown="```python\ndef f(): pass\n```", + ) + assert "abc-123" == c.candidate_id + assert "ai-service" == c.source + assert "parent-0" == c.parent_id + assert "```python\ndef f(): pass\n```" == c.code_markdown + + def test_frozen(self): + """ + Mutating a field raises FrozenInstanceError. + """ + c = Candidate(code="x", explanation="y") + with pytest.raises(attrs.exceptions.FrozenInstanceError): + c.code = "z" + + def test_empty_strings(self): + """ + Empty strings are valid for required fields. + """ + c = Candidate(code="", explanation="") + assert "" == c.code + assert "" == c.explanation + + +class TestOptimizationReviewResult: + """Tests for OptimizationReviewResult.""" + + def test_construction(self): + """ + Both fields are stored correctly. + """ + r = OptimizationReviewResult( + review="high", + explanation="Well-tested optimization.", + ) + assert "high" == r.review + assert "Well-tested optimization." == r.explanation + + def test_frozen(self): + """ + Mutating a field raises FrozenInstanceError. + """ + r = OptimizationReviewResult(review="low", explanation="risky") + with pytest.raises(attrs.exceptions.FrozenInstanceError): + r.review = "high" + + def test_empty_strings(self): + """ + Empty strings are valid for both fields. + """ + r = OptimizationReviewResult(review="", explanation="") + assert "" == r.review + assert "" == r.explanation + + +class TestFileDiffContent: + """Tests for FileDiffContent.""" + + def test_construction(self): + """ + Both old and new content are stored. + """ + d = FileDiffContent(old_content="before", new_content="after") + assert "before" == d.old_content + assert "after" == d.new_content + + def test_frozen(self): + """ + Mutating a field raises FrozenInstanceError. + """ + d = FileDiffContent(old_content="a", new_content="b") + with pytest.raises(attrs.exceptions.FrozenInstanceError): + d.old_content = "c" + + def test_empty_strings(self): + """ + Empty strings are valid for both fields. + """ + d = FileDiffContent(old_content="", new_content="") + assert "" == d.old_content + assert "" == d.new_content + + +class TestBenchmarkDetail: + """Tests for BenchmarkDetail.""" + + @pytest.fixture(name="detail") + def _detail(self): + """ + A sample BenchmarkDetail for testing. + """ + return BenchmarkDetail( + benchmark_name="test_suite", + test_function="test_compute", + original_timing="100ms", + expected_new_timing="50ms", + speedup_percent=50.0, + ) + + def test_construction(self, detail): + """ + All fields are stored correctly. + """ + assert "test_suite" == detail.benchmark_name + assert "test_compute" == detail.test_function + assert "100ms" == detail.original_timing + assert "50ms" == detail.expected_new_timing + assert 50.0 == detail.speedup_percent + + def test_frozen(self, detail): + """ + Mutating a field raises FrozenInstanceError. + """ + with pytest.raises(attrs.exceptions.FrozenInstanceError): + detail.benchmark_name = "other" + + def test_to_string(self, detail): + """ + to_string returns a multi-line summary with the correct format. + """ + result = detail.to_string() + expected = ( + "Original timing for test_suite::test_compute: 100ms\n" + "Expected new timing for test_suite::test_compute: 50ms\n" + "Benchmark speedup for test_suite::test_compute: 50.00%\n" + ) + assert expected == result + + def test_to_string_fractional_speedup(self): + """ + to_string formats fractional speedup percentages to two decimals. + """ + d = BenchmarkDetail( + benchmark_name="bench", + test_function="test_fn", + original_timing="200ns", + expected_new_timing="133ns", + speedup_percent=33.333, + ) + assert "33.33%" in d.to_string() + + def test_to_dict(self, detail): + """ + to_dict returns all fields as a plain dictionary. + """ + expected = { + "benchmark_name": "test_suite", + "test_function": "test_compute", + "original_timing": "100ms", + "expected_new_timing": "50ms", + "speedup_percent": 50.0, + } + assert expected == detail.to_dict() + + def test_to_dict_roundtrip(self, detail): + """ + A BenchmarkDetail can be reconstructed from its to_dict output. + """ + d = detail.to_dict() + reconstructed = BenchmarkDetail(**d) + assert detail == reconstructed + + def test_empty_strings(self): + """ + Empty strings are valid for name and timing fields. + """ + d = BenchmarkDetail( + benchmark_name="", + test_function="", + original_timing="", + expected_new_timing="", + speedup_percent=0.0, + ) + assert "" == d.benchmark_name + assert "" == d.test_function + + +class TestPrComment: + """Tests for PrComment.""" + + @pytest.fixture(name="pr_comment") + def _pr_comment(self): + """ + A sample PrComment for testing. + """ + return PrComment( + optimization_explanation="Replaced loop with vectorized op.", + best_runtime=50_000_000, + original_runtime=100_000_000, + function_name="compute", + relative_file_path="src/compute.py", + speedup_x="2.0x", + speedup_pct="50%", + loop_count=10, + report_table={"test_a": {"before": 100, "after": 50}}, + ) + + def test_construction(self, pr_comment): + """ + All required fields are stored correctly. + """ + assert "Replaced loop with vectorized op." == ( + pr_comment.optimization_explanation + ) + assert 50_000_000 == pr_comment.best_runtime + assert 100_000_000 == pr_comment.original_runtime + assert "compute" == pr_comment.function_name + assert "src/compute.py" == pr_comment.relative_file_path + assert "2.0x" == pr_comment.speedup_x + assert "50%" == pr_comment.speedup_pct + assert 10 == pr_comment.loop_count + assert {"test_a": {"before": 100, "after": 50}} == ( + pr_comment.report_table + ) + + def test_optional_defaults(self, pr_comment): + """ + Optional fields default to None. + """ + assert pr_comment.benchmark_details is None + assert pr_comment.original_async_throughput is None + assert pr_comment.best_async_throughput is None + + def test_frozen(self, pr_comment): + """ + Mutating a field raises FrozenInstanceError. + """ + with pytest.raises(attrs.exceptions.FrozenInstanceError): + pr_comment.function_name = "other" + + def test_to_json_keys(self, pr_comment): + """ + to_json returns the expected set of keys. + """ + result = pr_comment.to_json() + expected_keys = { + "optimization_explanation", + "best_runtime", + "original_runtime", + "function_name", + "file_path", + "speedup_x", + "speedup_pct", + "loop_count", + "report_table", + "benchmark_details", + } + assert expected_keys == set(result.keys()) + + def test_to_json_humanizes_runtimes(self, pr_comment): + """ + to_json converts runtimes via humanize_runtime. + """ + result = pr_comment.to_json() + assert humanize_runtime(50_000_000) == result["best_runtime"] + assert humanize_runtime(100_000_000) == result["original_runtime"] + + def test_to_json_file_path_key(self, pr_comment): + """ + to_json maps relative_file_path to the 'file_path' key. + """ + result = pr_comment.to_json() + assert "src/compute.py" == result["file_path"] + + def test_to_json_no_benchmark_details(self, pr_comment): + """ + to_json sets benchmark_details to None when unset. + """ + result = pr_comment.to_json() + assert result["benchmark_details"] is None + + def test_to_json_with_benchmark_details(self): + """ + to_json includes benchmark_details when provided. + """ + bd = BenchmarkDetail( + benchmark_name="suite", + test_function="test_fn", + original_timing="10ms", + expected_new_timing="5ms", + speedup_percent=50.0, + ) + pc = PrComment( + optimization_explanation="optimized", + best_runtime=5_000_000, + original_runtime=10_000_000, + function_name="fn", + relative_file_path="f.py", + speedup_x="2x", + speedup_pct="50%", + loop_count=1, + report_table={}, + benchmark_details=(bd,), + ) + result = pc.to_json() + assert (bd,) == result["benchmark_details"] + + def test_to_json_empty_benchmark_details_tuple(self): + """ + to_json converts an empty tuple of benchmark_details to None. + """ + pc = PrComment( + optimization_explanation="opt", + best_runtime=1000, + original_runtime=2000, + function_name="fn", + relative_file_path="f.py", + speedup_x="2x", + speedup_pct="50%", + loop_count=1, + report_table={}, + benchmark_details=(), + ) + result = pc.to_json() + assert result["benchmark_details"] is None + + def test_to_json_excludes_async_when_none(self, pr_comment): + """ + to_json omits async throughput keys when both are None. + """ + result = pr_comment.to_json() + assert "original_async_throughput" not in result + assert "best_async_throughput" not in result + + def test_to_json_includes_async_when_set(self): + """ + to_json includes async throughput keys when both are set. + """ + pc = PrComment( + optimization_explanation="opt", + best_runtime=1000, + original_runtime=2000, + function_name="fn", + relative_file_path="f.py", + speedup_x="2x", + speedup_pct="50%", + loop_count=1, + report_table={}, + original_async_throughput=500, + best_async_throughput=1000, + ) + result = pc.to_json() + assert 500 == result["original_async_throughput"] + assert 1000 == result["best_async_throughput"] + + def test_to_json_excludes_async_when_only_original_set(self): + """ + to_json omits async keys when only original_async_throughput is set. + """ + pc = PrComment( + optimization_explanation="opt", + best_runtime=1000, + original_runtime=2000, + function_name="fn", + relative_file_path="f.py", + speedup_x="2x", + speedup_pct="50%", + loop_count=1, + report_table={}, + original_async_throughput=500, + ) + result = pc.to_json() + assert "original_async_throughput" not in result + assert "best_async_throughput" not in result + + +class TestHumanizeRuntime: + """Tests for humanize_runtime.""" + + def test_single_nanosecond(self): + """ + 1 ns uses the singular 'nanosecond' unit. + """ + assert "1.00 nanosecond" == humanize_runtime(1) + + def test_small_nanoseconds(self): + """ + Values below 1 microsecond use the 'nanoseconds' unit. + """ + assert "500 nanoseconds" == humanize_runtime(500) + + def test_microsecond_range(self): + """ + Values in the microsecond range display in microseconds. + """ + result = humanize_runtime(5_000) + assert "microseconds" in result + + def test_millisecond_range(self): + """ + Values in the millisecond range display in milliseconds. + """ + result = humanize_runtime(5_000_000) + assert "milliseconds" in result + + def test_second_range(self): + """ + Values in the second range display in seconds. + """ + result = humanize_runtime(2_000_000_000) + assert "seconds" in result + + def test_minute_range(self): + """ + Values in the minute range display in minutes. + """ + result = humanize_runtime(120_000_000_000) + assert "minutes" in result + + def test_hour_range(self): + """ + Values in the hour range display in hours. + """ + result = humanize_runtime(2 * _NS_PER_HOUR) + assert "hours" in result + + def test_day_range(self): + """ + Values in the day range display in days. + """ + result = humanize_runtime(2 * _NS_PER_DAY) + assert "days" in result + + def test_exact_one_microsecond(self): + """ + Exactly 1 microsecond uses the singular 'microsecond' unit. + """ + result = humanize_runtime(1_000) + assert "microsecond" in result + assert "microseconds" not in result + + def test_exact_one_second(self): + """ + Exactly 1 second uses the singular 'second' unit. + """ + result = humanize_runtime(_NS_PER_S) + assert "second" in result + assert "seconds" not in result + + def test_exact_one_minute(self): + """ + Exactly 1 minute uses the singular 'minute' unit. + """ + result = humanize_runtime(_NS_PER_MIN) + assert "minute" in result + assert "minutes" not in result + + def test_exact_one_hour(self): + """ + Exactly 1 hour uses the singular 'hour' unit. + """ + result = humanize_runtime(_NS_PER_HOUR) + assert "hour" in result + assert "hours" not in result + + def test_exact_one_day(self): + """ + Exactly 1 day uses the singular 'day' unit. + """ + result = humanize_runtime(_NS_PER_DAY) + assert "day" in result + assert "days" not in result + + def test_returns_string(self): + """ + The return type is always a string. + """ + assert isinstance(humanize_runtime(42), str) + + def test_contains_space_between_value_and_unit(self): + """ + The output always has a space between the numeric value and units. + """ + result = humanize_runtime(12345) + parts = result.split(" ") + assert 2 == len(parts) diff --git a/packages/codeflash-core/tests/test_pipeline.py b/packages/codeflash-core/tests/test_pipeline.py new file mode 100644 index 0000000..f583771 --- /dev/null +++ b/packages/codeflash-core/tests/test_pipeline.py @@ -0,0 +1,701 @@ +from __future__ import annotations + +import attrs +import pytest + +from codeflash_core._model import Candidate +from codeflash_core._pipeline import ( + CandidateForest, + CandidateNode, + EvaluationContext, + create_rank_dictionary, + dedup_candidates, + diff_length, + filter_refined_candidates, + performance_gain, + select_best, +) + + +@pytest.fixture(name="candidate") +def _candidate(): + """ + A minimal Candidate for use in tests. + """ + return Candidate( + code="def f(): pass", + explanation="no-op", + candidate_id="c1", + ) + + +@pytest.fixture(name="eval_ctx") +def _eval_ctx(): + """ + A fresh EvaluationContext. + """ + return EvaluationContext() + + +class TestPerformanceGain: + """Tests for performance_gain.""" + + def test_two_x_speedup(self): + """ + Original twice as slow yields a gain of 1.0. + """ + assert 1.0 == performance_gain( + original_runtime_ns=200, optimized_runtime_ns=100 + ) + + def test_no_speedup(self): + """ + Equal runtimes yield a gain of 0.0. + """ + assert 0.0 == performance_gain( + original_runtime_ns=100, optimized_runtime_ns=100 + ) + + def test_slowdown(self): + """ + Optimized slower than original yields a negative gain. + """ + assert -0.5 == performance_gain( + original_runtime_ns=100, optimized_runtime_ns=200 + ) + + def test_optimized_zero_returns_zero(self): + """ + Zero optimized runtime returns 0.0 to avoid division by zero. + """ + assert 0.0 == performance_gain( + original_runtime_ns=100, optimized_runtime_ns=0 + ) + + def test_both_zero(self): + """ + Both runtimes zero returns 0.0. + """ + assert 0.0 == performance_gain( + original_runtime_ns=0, optimized_runtime_ns=0 + ) + + def test_original_zero_optimized_nonzero(self): + """ + Zero original with nonzero optimized yields -1.0. + """ + assert -1.0 == performance_gain( + original_runtime_ns=0, optimized_runtime_ns=100 + ) + + def test_large_speedup(self): + """ + Large runtime difference produces correspondingly large gain. + """ + assert 999.0 == performance_gain( + original_runtime_ns=1_000_000, optimized_runtime_ns=1_000 + ) + + +class TestDiffLength: + """Tests for diff_length.""" + + def test_identical_strings(self): + """ + Identical strings produce a zero-length diff. + """ + assert 0 == diff_length("hello\n", "hello\n") + + def test_empty_strings(self): + """ + Two empty strings produce a zero-length diff. + """ + assert 0 == diff_length("", "") + + def test_one_line_changed(self): + """ + A single changed line produces a non-zero diff. + """ + result = diff_length("line1\nline2\n", "line1\nchanged\n") + assert result > 0 + + def test_addition(self): + """ + Adding a line produces a non-zero diff. + """ + result = diff_length("a\n", "a\nb\n") + assert result > 0 + + def test_deletion(self): + """ + Removing a line produces a non-zero diff. + """ + result = diff_length("a\nb\n", "a\n") + assert result > 0 + + def test_completely_different(self): + """ + Completely different strings produce a larger diff than a small edit. + """ + small = diff_length("aaa\n", "aab\n") + large = diff_length("aaa\nbbb\nccc\n", "xxx\nyyy\nzzz\n") + assert large > small + + +class TestCreateRankDictionary: + """Tests for create_rank_dictionary.""" + + def test_ascending_order(self): + """ + Already-sorted values get ranks 0, 1, 2. + """ + assert {0: 0, 1: 1, 2: 2} == create_rank_dictionary([10, 20, 30]) + + def test_descending_order(self): + """ + Reverse-sorted values get inverted ranks. + """ + assert {0: 2, 1: 1, 2: 0} == create_rank_dictionary([30, 20, 10]) + + def test_single_element(self): + """ + A single element gets rank 0. + """ + assert {0: 0} == create_rank_dictionary([42]) + + def test_empty_list(self): + """ + An empty list returns an empty dict. + """ + assert {} == create_rank_dictionary([]) + + def test_duplicate_values(self): + """ + Duplicates still get distinct ranks based on original index. + """ + result = create_rank_dictionary([5, 5, 5]) + assert {0, 1, 2} == set(result.values()) + + +class TestCandidateNode: + """Tests for CandidateNode.""" + + def test_candidate_id_delegates(self, candidate): + """ + candidate_id property returns the wrapped candidate's id. + """ + node = CandidateNode(candidate=candidate) + assert "c1" == node.candidate_id + + def test_is_leaf_true(self, candidate): + """ + A node with no children is a leaf. + """ + node = CandidateNode(candidate=candidate) + assert node.is_leaf() is True + + def test_is_leaf_false(self, candidate): + """ + A node with children is not a leaf. + """ + parent = CandidateNode(candidate=candidate) + child_cand = Candidate( + code="def g(): pass", + explanation="child", + candidate_id="c2", + ) + child = CandidateNode(candidate=child_cand, parent=parent) + parent.children.append(child) + assert parent.is_leaf() is False + + def test_path_to_root_single(self, candidate): + """ + A root node's path is just itself. + """ + node = CandidateNode(candidate=candidate) + path = node.path_to_root() + assert 1 == len(path) + assert candidate is path[0] + + def test_path_to_root_chain(self): + """ + A three-deep chain returns root-first order. + """ + c1 = Candidate(code="a", explanation="", candidate_id="c1") + c2 = Candidate(code="b", explanation="", candidate_id="c2") + c3 = Candidate(code="c", explanation="", candidate_id="c3") + root = CandidateNode(candidate=c1) + mid = CandidateNode(candidate=c2, parent=root) + leaf = CandidateNode(candidate=c3, parent=mid) + + path = leaf.path_to_root() + assert ["c1", "c2", "c3"] == [c.candidate_id for c in path] + + def test_default_children_empty(self, candidate): + """ + A new node starts with an empty children list. + """ + node = CandidateNode(candidate=candidate) + assert [] == node.children + + +class TestCandidateForest: + """Tests for CandidateForest.""" + + def test_add_and_get(self, candidate): + """ + Adding a candidate allows retrieval by id. + """ + forest = CandidateForest() + forest.add(candidate) + node = forest.get("c1") + assert node is not None + assert "c1" == node.candidate_id + + def test_len(self): + """ + __len__ returns the number of nodes in the forest. + """ + forest = CandidateForest() + assert 0 == len(forest) + forest.add(Candidate(code="a", explanation="", candidate_id="c1")) + assert 1 == len(forest) + + def test_get_missing(self): + """ + Getting a nonexistent id returns None. + """ + forest = CandidateForest() + assert forest.get("nope") is None + + def test_parent_child_linking(self): + """ + A candidate with parent_id links to its parent node. + """ + forest = CandidateForest() + parent = Candidate(code="a", explanation="", candidate_id="p1") + child = Candidate( + code="b", + explanation="", + candidate_id="c1", + parent_id="p1", + ) + forest.add(parent) + forest.add(child) + + child_node = forest.get("c1") + assert child_node is not None + assert child_node.parent is not None + assert "p1" == child_node.parent.candidate_id + + def test_child_added_before_parent(self): + """ + Adding a child before its parent creates a placeholder. + """ + forest = CandidateForest() + child = Candidate( + code="b", + explanation="", + candidate_id="c1", + parent_id="p1", + ) + forest.add(child) + assert 2 == len(forest) + + parent = Candidate(code="a", explanation="", candidate_id="p1") + forest.add(parent) + assert 2 == len(forest) + + parent_node = forest.get("p1") + assert parent_node is not None + assert "a" == parent_node.candidate.code + + def test_multiple_roots(self): + """ + Candidates without parent_id are independent roots. + """ + forest = CandidateForest() + forest.add(Candidate(code="a", explanation="", candidate_id="r1")) + forest.add(Candidate(code="b", explanation="", candidate_id="r2")) + assert 2 == len(forest) + assert forest.get("r1").parent is None + assert forest.get("r2").parent is None + + +class TestEvaluationContext: + """Tests for EvaluationContext.""" + + def test_record_failed(self, eval_ctx): + """ + record_failed sets correct flags and None values. + """ + eval_ctx.record_failed("c1") + assert eval_ctx.is_correct["c1"] is False + assert eval_ctx.optimized_runtimes["c1"] is None + assert eval_ctx.speedup_ratios["c1"] is None + + def test_record_success(self, eval_ctx): + """ + record_success stores runtime, speedup, and correctness. + """ + eval_ctx.record_success("c1", runtime=500.0, speedup=1.5) + assert eval_ctx.is_correct["c1"] is True + assert 500.0 == eval_ctx.optimized_runtimes["c1"] + assert 1.5 == eval_ctx.speedup_ratios["c1"] + + def test_get_speedup_missing(self, eval_ctx): + """ + get_speedup returns None for unknown candidates. + """ + assert eval_ctx.get_speedup("unknown") is None + + def test_get_runtime_missing(self, eval_ctx): + """ + get_runtime returns None for unknown candidates. + """ + assert eval_ctx.get_runtime("unknown") is None + + def test_get_speedup_after_record(self, eval_ctx): + """ + get_speedup returns the recorded value. + """ + eval_ctx.record_success("c1", runtime=100.0, speedup=2.0) + assert 2.0 == eval_ctx.get_speedup("c1") + + def test_get_runtime_after_record(self, eval_ctx): + """ + get_runtime returns the recorded value. + """ + eval_ctx.record_success("c1", runtime=100.0, speedup=2.0) + assert 100.0 == eval_ctx.get_runtime("c1") + + def test_record_line_profile(self, eval_ctx): + """ + record_line_profile stores the result string. + """ + eval_ctx.record_line_profile("c1", "profile output") + assert "profile output" == eval_ctx.line_profiler_results["c1"] + + def test_register_new(self, eval_ctx): + """ + register_new stores code-to-id mapping with diff info. + """ + eval_ctx.register_new( + normalized_code="norm", + candidate_id="c1", + flat_code="def f(): pass", + original_flat_code="def f(): return 1", + ) + entry = eval_ctx.code_to_id["norm"] + assert "c1" == entry["candidate_id"] + assert "def f(): pass" == entry["shorter_code"] + assert entry["diff_len"] > 0 + + def test_handle_duplicate_copies_prior_results(self, eval_ctx): + """ + handle_duplicate propagates prior correctness and runtime. + """ + eval_ctx.record_success("c1", runtime=100.0, speedup=2.0) + eval_ctx.register_new( + normalized_code="norm", + candidate_id="c1", + flat_code="def f(): pass", + original_flat_code="def g(): pass", + ) + + eval_ctx.handle_duplicate( + candidate_id="c2", + normalized_code="norm", + original_flat_code="def g(): pass", + flat_code="def f(): pass", + ) + + assert eval_ctx.is_correct["c2"] is True + assert 100.0 == eval_ctx.optimized_runtimes["c2"] + assert 2.0 == eval_ctx.speedup_ratios["c2"] + + def test_handle_duplicate_updates_shorter_code(self, eval_ctx): + """ + handle_duplicate replaces shorter_code when the new diff is smaller. + """ + eval_ctx.register_new( + normalized_code="norm", + candidate_id="c1", + flat_code="def f():\n x = 1\n return x\n", + original_flat_code="def f():\n return 1\n", + ) + eval_ctx.record_success("c1", runtime=100.0, speedup=2.0) + + eval_ctx.handle_duplicate( + candidate_id="c2", + normalized_code="norm", + original_flat_code="def f():\n return 1\n", + flat_code="def f():\n return 1\n", + ) + + assert ( + "def f():\n return 1\n" + == eval_ctx.code_to_id["norm"]["shorter_code"] + ) + + def test_starts_empty(self, eval_ctx): + """ + A new EvaluationContext has empty containers. + """ + assert {} == eval_ctx.speedup_ratios + assert {} == eval_ctx.optimized_runtimes + assert {} == eval_ctx.is_correct + assert [] == eval_ctx.valid_candidates + + +class TestDedupCandidates: + """Tests for dedup_candidates.""" + + def test_empty_input(self): + """ + An empty list returns an empty list. + """ + result = dedup_candidates( + [], + normalize_fn=lambda c: c, + original_normalized="original", + ) + assert [] == result + + def test_removes_identical_to_original(self): + """ + Candidates matching the original are removed. + """ + c1 = Candidate(code="original", explanation="", candidate_id="c1") + result = dedup_candidates( + [c1], + normalize_fn=lambda c: c, + original_normalized="original", + ) + assert [] == result + + def test_removes_intra_batch_duplicates(self): + """ + Only the first occurrence of a normalized form is kept. + """ + c1 = Candidate(code="opt", explanation="", candidate_id="c1") + c2 = Candidate(code="opt", explanation="", candidate_id="c2") + result = dedup_candidates( + [c1, c2], + normalize_fn=lambda c: c, + original_normalized="original", + ) + assert 1 == len(result) + assert "c1" == result[0].candidate_id + + def test_removes_cross_batch_duplicates(self): + """ + Candidates already in cross_batch are removed. + """ + c1 = Candidate(code="opt", explanation="", candidate_id="c1") + result = dedup_candidates( + [c1], + normalize_fn=lambda c: c, + original_normalized="original", + cross_batch={"opt": {"candidate_id": "prior"}}, + ) + assert [] == result + + def test_keeps_unique_candidates(self): + """ + All unique candidates are preserved in order. + """ + c1 = Candidate(code="a", explanation="", candidate_id="c1") + c2 = Candidate(code="b", explanation="", candidate_id="c2") + result = dedup_candidates( + [c1, c2], + normalize_fn=lambda c: c, + original_normalized="original", + ) + assert 2 == len(result) + assert ["c1", "c2"] == [c.candidate_id for c in result] + + def test_normalize_fn_failure_keeps_candidate(self): + """ + A candidate whose normalization raises is still kept. + """ + + def bad_normalize(code): + raise ValueError("boom") + + c1 = Candidate(code="a", explanation="", candidate_id="c1") + result = dedup_candidates( + [c1], + normalize_fn=bad_normalize, + original_normalized="original", + ) + assert 1 == len(result) + assert "c1" == result[0].candidate_id + + def test_seen_set_is_populated(self): + """ + The seen set accumulates normalized codes across calls. + """ + seen: set[str] = set() + c1 = Candidate(code="a", explanation="", candidate_id="c1") + dedup_candidates( + [c1], + normalize_fn=lambda c: c, + original_normalized="original", + seen=seen, + ) + assert "a" in seen + + def test_seen_set_prevents_second_batch(self): + """ + A code seen in a prior call is treated as a duplicate. + """ + seen: set[str] = {"a"} + c1 = Candidate(code="a", explanation="", candidate_id="c1") + result = dedup_candidates( + [c1], + normalize_fn=lambda c: c, + original_normalized="original", + seen=seen, + ) + assert [] == result + + +class TestFilterRefinedCandidates: + """Tests for filter_refined_candidates.""" + + def test_under_limit_returns_all(self, eval_ctx): + """ + Fewer candidates than max_candidates returns all of them. + """ + forest = CandidateForest() + candidates = [ + Candidate(code="a", explanation="", candidate_id="c1"), + Candidate(code="b", explanation="", candidate_id="c2"), + ] + result = filter_refined_candidates( + candidates, eval_ctx, forest, "original", max_candidates=5 + ) + assert 2 == len(result) + + def test_at_limit_returns_all(self, eval_ctx): + """ + Exactly max_candidates returns all of them. + """ + forest = CandidateForest() + candidates = [ + Candidate(code="a", explanation="", candidate_id=f"c{i}") + for i in range(5) + ] + result = filter_refined_candidates( + candidates, eval_ctx, forest, "original", max_candidates=5 + ) + assert 5 == len(result) + + def test_over_limit_trims(self, eval_ctx): + """ + More candidates than max_candidates trims the result. + """ + forest = CandidateForest() + candidates = [ + Candidate(code=f"code{i}", explanation="", candidate_id=f"c{i}") + for i in range(10) + ] + result = filter_refined_candidates( + candidates, eval_ctx, forest, "original", max_candidates=3 + ) + assert 3 == len(result) + + def test_prefers_lower_parent_runtime(self): + """ + Candidates with faster parents rank higher. + """ + eval_ctx = EvaluationContext() + forest = CandidateForest() + + parent_fast = Candidate(code="pf", explanation="", candidate_id="pf") + parent_slow = Candidate(code="ps", explanation="", candidate_id="ps") + forest.add(parent_fast) + forest.add(parent_slow) + eval_ctx.record_success("pf", runtime=10.0, speedup=5.0) + eval_ctx.record_success("ps", runtime=1000.0, speedup=0.1) + + candidates = [ + Candidate( + code="slow_child", + explanation="", + candidate_id="c_slow", + parent_id="ps", + ), + Candidate( + code="fast_child", + explanation="", + candidate_id="c_fast", + parent_id="pf", + ), + ] * 4 + + for i, c in enumerate(candidates): + candidates[i] = attrs.evolve( + c, candidate_id=f"{c.candidate_id}_{i}" + ) + + result = filter_refined_candidates( + candidates, eval_ctx, forest, "original", max_candidates=2 + ) + assert 2 == len(result) + assert all("c_fast" in r.candidate_id for r in result) + + +class TestSelectBest: + """Tests for select_best.""" + + def test_empty_returns_none(self, eval_ctx): + """ + No candidates returns None. + """ + assert select_best(eval_ctx, 1000, [], []) is None + + def test_single_candidate(self, eval_ctx): + """ + A single candidate is always selected. + """ + eval_ctx.record_success("c1", runtime=50.0, speedup=1.0) + assert "c1" == select_best(eval_ctx, 1000, [10], ["c1"]) + + def test_picks_faster_candidate(self, eval_ctx): + """ + The faster candidate wins when diff lengths are equal. + """ + eval_ctx.record_success("c1", runtime=500.0, speedup=1.0) + eval_ctx.record_success("c2", runtime=100.0, speedup=4.0) + result = select_best(eval_ctx, 1000, [10, 10], ["c1", "c2"]) + assert "c2" == result + + def test_diff_length_contributes_to_score(self, eval_ctx): + """ + Candidate with shorter diff and equal runtime wins when listed first. + """ + eval_ctx.record_success("short_diff", runtime=100.0, speedup=1.0) + eval_ctx.record_success("long_diff", runtime=100.0, speedup=1.0) + result = select_best( + eval_ctx, 1000, [5, 500], ["short_diff", "long_diff"] + ) + assert "short_diff" == result + + def test_runtime_weighted_more_than_diff(self, eval_ctx): + """ + A faster candidate wins even with a longer diff. + """ + eval_ctx.record_success("c1", runtime=500.0, speedup=1.0) + eval_ctx.record_success("c2", runtime=50.0, speedup=9.0) + result = select_best(eval_ctx, 1000, [5, 100], ["c1", "c2"]) + assert "c2" == result + + def test_missing_runtime_uses_original(self, eval_ctx): + """ + A candidate without recorded runtime uses original_runtime_ns. + """ + eval_ctx.record_success("c1", runtime=50.0, speedup=1.0) + result = select_best(eval_ctx, 1000, [10, 10], ["c1", "c2"]) + assert "c1" == result diff --git a/packages/codeflash-python/tests/test_normalizer.py b/packages/codeflash-python/tests/test_normalizer.py new file mode 100644 index 0000000..1fe2225 --- /dev/null +++ b/packages/codeflash-python/tests/test_normalizer.py @@ -0,0 +1,336 @@ +from __future__ import annotations + +import textwrap + +from codeflash_python.analysis._normalizer import normalize_python_code + + +class TestBasicNormalization: + """Tests for local variable renaming to canonical forms.""" + + def test_single_variable_renamed(self): + """ + A single local variable is renamed to var_0. + """ + code = textwrap.dedent("""\ + def f(): + x = 1 + return x + """) + result = normalize_python_code(code) + assert "var_0 = 1" in result + assert "return var_0" in result + + def test_multiple_variables_renamed_sequentially(self): + """ + Multiple local variables are renamed to var_0, var_1, etc. in order of first assignment. + """ + code = textwrap.dedent("""\ + def f(): + alpha = 1 + beta = 2 + return alpha + beta + """) + result = normalize_python_code(code) + assert "var_0 = 1" in result + assert "var_1 = 2" in result + assert "return var_0 + var_1" in result + + def test_same_variable_reused_gets_same_canonical_name(self): + """ + A variable assigned and then loaded multiple times uses the same canonical name throughout. + """ + code = textwrap.dedent("""\ + def f(): + total = 0 + total = total + 1 + return total + """) + result = normalize_python_code(code) + assert "var_0 = 0" in result + assert "var_0 = var_0 + 1" in result + assert "return var_0" in result + + +class TestFunctionNamePreservation: + """Tests that function and class names are not renamed.""" + + def test_function_name_preserved(self): + """ + The function name itself is not renamed. + """ + code = textwrap.dedent("""\ + def compute_sum(): + x = 1 + return x + """) + result = normalize_python_code(code) + assert "def compute_sum():" in result + + def test_class_name_preserved(self): + """ + The class name is not renamed. + """ + code = textwrap.dedent("""\ + class MyClass: + def method(self): + x = 1 + return x + """) + result = normalize_python_code(code) + assert "class MyClass:" in result + + +class TestParameterPreservation: + """Tests that function parameters are not renamed.""" + + def test_positional_parameter_preserved(self): + """ + Positional parameters keep their original names. + """ + code = textwrap.dedent("""\ + def f(value): + x = value + 1 + return x + """) + result = normalize_python_code(code) + assert "value" in result + assert "var_0 = value + 1" in result + + def test_args_and_kwargs_preserved(self): + """ + *args and **kwargs parameters keep their original names. + """ + code = textwrap.dedent("""\ + def f(*args, **kwargs): + x = args + y = kwargs + return x, y + """) + result = normalize_python_code(code) + assert "args" in result + assert "kwargs" in result + + def test_keyword_only_parameter_preserved(self): + """ + Keyword-only parameters keep their original names. + """ + code = textwrap.dedent("""\ + def f(*, key): + x = key + return x + """) + result = normalize_python_code(code) + assert "key" in result + assert "var_0 = key" in result + + def test_self_parameter_preserved(self): + """ + The self parameter on methods is not renamed. + """ + code = textwrap.dedent("""\ + class C: + def method(self): + x = self + return x + """) + result = normalize_python_code(code) + assert "self" in result + + +class TestImportPreservation: + """Tests that imported names are not renamed.""" + + def test_import_name_preserved(self): + """ + Names brought in by import statements are not renamed. + """ + code = textwrap.dedent("""\ + import os + def f(): + x = os.path.join('a', 'b') + return x + """) + result = normalize_python_code(code) + assert "os" in result + + def test_from_import_name_preserved(self): + """ + Names brought in by from-import statements are not renamed. + """ + code = textwrap.dedent("""\ + from os.path import join + def f(): + x = join('a', 'b') + return x + """) + result = normalize_python_code(code) + assert "join" in result + + def test_aliased_import_preserved(self): + """ + Aliased import names are preserved using the alias. + """ + code = textwrap.dedent("""\ + import numpy as np + def f(): + x = np.array([1]) + return x + """) + result = normalize_python_code(code) + assert "np" in result + + +class TestDocstringRemoval: + """Tests for docstring handling.""" + + def test_docstring_removed_by_default(self): + """ + Docstrings are removed when remove_docstrings=True (the default). + """ + code = textwrap.dedent('''\ + def f(): + """This is a docstring.""" + x = 1 + return x + ''') + result = normalize_python_code(code) + assert "docstring" not in result + + def test_module_docstring_removed(self): + """ + Module-level docstrings are removed. + """ + code = textwrap.dedent('''\ + """Module docstring.""" + def f(): + x = 1 + return x + ''') + result = normalize_python_code(code) + assert "Module docstring" not in result + + def test_class_docstring_removed(self): + """ + Class-level docstrings are removed. + """ + code = textwrap.dedent('''\ + class C: + """Class docstring.""" + def method(self): + x = 1 + return x + ''') + result = normalize_python_code(code) + assert "Class docstring" not in result + + +class TestDocstringPreservation: + """Tests for docstring preservation when remove_docstrings=False.""" + + def test_docstring_preserved_when_flag_false(self): + """ + Docstrings are kept when remove_docstrings=False. + """ + code = textwrap.dedent('''\ + def f(): + """This is a docstring.""" + x = 1 + return x + ''') + result = normalize_python_code(code, remove_docstrings=False) + assert "This is a docstring." in result + + def test_all_docstrings_preserved_when_flag_false(self): + """ + Module, class, and function docstrings are all kept when remove_docstrings=False. + """ + code = textwrap.dedent('''\ + """Module doc.""" + class C: + """Class doc.""" + def method(self): + """Method doc.""" + x = 1 + return x + ''') + result = normalize_python_code(code, remove_docstrings=False) + assert "Module doc." in result + assert "Class doc." in result + assert "Method doc." in result + + +class TestSyntaxErrorHandling: + """Tests that invalid Python input is returned unchanged.""" + + def test_syntax_error_returns_original(self): + """ + Code with a syntax error is returned as-is without modification. + """ + code = "def f(\n x = :\n" + assert code == normalize_python_code(code) + + def test_incomplete_code_returns_original(self): + """ + Incomplete code that cannot be parsed is returned unchanged. + """ + code = "def f(" + assert code == normalize_python_code(code) + + +class TestEmptyInput: + """Tests for empty and whitespace-only input.""" + + def test_empty_string(self): + """ + An empty string produces an empty string. + """ + assert "" == normalize_python_code("") + + def test_whitespace_only(self): + """ + Whitespace-only input normalizes to an empty string. + """ + result = normalize_python_code(" \n\n ") + assert "" == result + + +class TestMultipleFunctions: + """Tests for code containing multiple function definitions.""" + + def test_variables_in_separate_functions_get_independent_names(self): + """ + Local variables in different functions are normalized independently per scope. + """ + code = textwrap.dedent("""\ + def first(): + alpha = 1 + return alpha + + def second(): + beta = 2 + return beta + """) + result = normalize_python_code(code) + assert "def first():" in result + assert "def second():" in result + assert "var_0 = 1" in result + assert "var_0 = 2" in result + + def test_nested_function_variables_normalized(self): + """ + Variables in nested functions are normalized, inheriting the parent scope counter. + """ + code = textwrap.dedent("""\ + def outer(): + x = 1 + def inner(): + y = 2 + return y + return x + """) + result = normalize_python_code(code) + assert "def outer():" in result + assert "def inner():" in result + assert "var_0 = 1" in result + assert "var_1 = 2" in result