test: add 262 tests for previously untested core modules
- test_danom_result.py: 58 tests for the Ok/Err Result monad
- test_danom_stream.py: 65 tests for Stream pipeline operations
- test_model.py: 57 tests for core data models and serialization
- test_pipeline.py: 59 tests for pipeline utilities and candidate evaluation
- test_normalizer.py: 23 tests for code normalization, including SyntaxError handling
This commit is contained in:
parent
90a46d732c
commit
fd88580ac8
5 changed files with 2953 additions and 0 deletions
490
packages/codeflash-core/tests/test_danom_result.py
Normal file
490
packages/codeflash-core/tests/test_danom_result.py
Normal file
|
|
@ -0,0 +1,490 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
from codeflash_core.danom.result import Err, Ok, Result
|
||||
|
||||
|
||||
class TestOk:
    """Behavioral tests for the Ok variant of Result."""

    def test_construction_with_value(self):
        """Ok stores the supplied value in ``inner``."""
        wrapped = Ok(42)

        assert wrapped.inner == 42

    def test_construction_default_none(self):
        """Constructing Ok with no argument stores None."""
        assert Ok().inner is None

    def test_is_ok_returns_true(self):
        """``is_ok`` is True for any Ok."""
        assert Ok(1).is_ok() is True

    def test_unwrap_returns_inner(self):
        """``unwrap`` hands back the wrapped value."""
        assert Ok("hello").unwrap() == "hello"

    def test_unwrap_none_value(self):
        """``unwrap`` on Ok(None) yields None."""
        assert Ok(None).unwrap() is None

    def test_map_applies_func(self):
        """``map`` transforms the inner value and re-wraps it in Ok."""
        doubled = Ok(3).map(lambda x: x * 2)

        assert doubled == Ok(6)

    def test_map_with_extra_args(self):
        """``map`` forwards extra positional and keyword arguments."""
        combined = Ok(10).map(lambda x, y, z=0: x + y + z, 5, z=3)

        assert combined == Ok(18)

    def test_map_err_is_noop(self):
        """``map_err`` on Ok returns the very same instance."""
        original = Ok(42)

        assert original.map_err(lambda e: "transformed") is original

    def test_and_then_returns_func_result(self):
        """``and_then`` returns the function's result unmodified."""
        chained = Ok(5).and_then(lambda x: Ok(x + 1))

        assert chained == Ok(6)

    def test_and_then_with_extra_args(self):
        """``and_then`` forwards extra positional and keyword arguments."""
        chained = Ok(1).and_then(lambda x, y, z=0: Ok(x + y + z), 2, z=3)

        assert chained == Ok(6)

    def test_and_then_can_return_err(self):
        """``and_then`` may produce an Err from the callback."""
        chained = Ok(5).and_then(lambda x: Err("fail"))

        assert chained == Err("fail")

    def test_or_else_is_noop(self):
        """``or_else`` on Ok returns the very same instance."""
        original = Ok(42)

        assert original.or_else(lambda e: Ok(0)) is original
|
||||
|
||||
|
||||
class TestErr:
    """Behavioral tests for the Err variant of Result."""

    def test_construction_with_error(self):
        """Err stores the supplied error value in ``error``."""
        failure = Err("bad")

        assert failure.error == "bad"

    def test_construction_default_none(self):
        """Constructing Err with no argument stores a None error."""
        assert Err().error is None

    def test_is_ok_returns_false(self):
        """``is_ok`` is False for any Err."""
        assert Err("x").is_ok() is False

    def test_unwrap_raises_wrapped_exception(self):
        """``unwrap`` re-raises a wrapped Exception instance."""
        caught = ValueError("boom")

        with pytest.raises(ValueError, match="boom"):
            Err(caught).unwrap()

    def test_unwrap_non_exception_raises_valueerror(self):
        """``unwrap`` raises ValueError when the error is not an Exception."""
        with pytest.raises(
            ValueError, match="Err does not have a caught error"
        ):
            Err("not an exception").unwrap()

    def test_unwrap_none_error_raises_valueerror(self):
        """``unwrap`` raises ValueError when the error is None."""
        with pytest.raises(
            ValueError, match="Err does not have a caught error"
        ):
            Err(None).unwrap()

    def test_map_is_noop(self):
        """``map`` on Err returns the very same instance."""
        failure = Err("fail")

        assert failure.map(lambda x: x * 2) is failure

    def test_map_err_applies_func(self):
        """``map_err`` transforms the error value."""
        transformed = Err("bad").map_err(lambda e: e.upper())

        assert transformed == Err("BAD")

    def test_map_err_with_extra_args(self):
        """``map_err`` forwards extra positional and keyword arguments."""
        transformed = Err(10).map_err(lambda e, x, y=0: e + x + y, 5, y=3)

        assert transformed == Err(18)

    def test_and_then_is_noop(self):
        """``and_then`` on Err returns the very same instance."""
        failure = Err("fail")

        assert failure.and_then(lambda x: Ok(x + 1)) is failure

    def test_or_else_applies_func(self):
        """``or_else`` runs the callback on the error and returns its result."""
        recovered = Err("fail").or_else(lambda e: Ok("recovered"))

        assert recovered == Ok("recovered")

    def test_or_else_with_extra_args(self):
        """``or_else`` forwards extra positional and keyword arguments."""
        recovered = Err("x").or_else(
            lambda e, suffix, sep="-": Ok(e + sep + suffix), "y", sep="+"
        )

        assert recovered == Ok("x+y")

    def test_or_else_can_return_err(self):
        """``or_else`` may produce another Err from the callback."""
        chained = Err("first").or_else(lambda e: Err("second"))

        assert chained == Err("second")

    def test_input_args_stored(self):
        """``input_args`` is kept on the Err instance as given."""
        call_args = (("a", "b"), {"key": "val"})
        failure = Err("fail", input_args=call_args)

        assert failure.input_args == call_args

    def test_input_args_default_empty_tuple(self):
        """``input_args`` defaults to an empty tuple."""
        assert Err("fail").input_args == ()

    def test_traceback_stored(self):
        """``traceback`` text is kept on the Err instance as given."""
        failure = Err("fail", traceback="line 1\nline 2")

        assert failure.traceback == "line 1\nline 2"

    def test_traceback_default_empty_string(self):
        """``traceback`` defaults to an empty string."""
        assert Err("fail").traceback == ""
|
||||
|
||||
|
||||
class TestErrEquality:
    """Tests covering Err.__eq__ and Err.__hash__."""

    def test_equal_errs_with_same_exception_type_and_message(self):
        """Same exception type, message, and input_args compare equal."""
        assert Err(ValueError("x")) == Err(ValueError("x"))

    def test_unequal_errs_different_message(self):
        """Different exception messages compare unequal."""
        assert Err(ValueError("x")) != Err(ValueError("y"))

    def test_unequal_errs_different_type(self):
        """Different exception types compare unequal."""
        assert Err(ValueError("x")) != Err(TypeError("x"))

    def test_unequal_errs_different_input_args(self):
        """Different input_args make otherwise-identical Errs unequal."""
        lhs = Err("x", input_args=(("a",), {}))
        rhs = Err("x", input_args=(("b",), {}))

        assert lhs != rhs

    def test_err_not_equal_to_non_err(self):
        """An Err never equals a bare (non-Err) value."""
        assert Err("x") != "x"
        assert Err(42) != 42

    def test_equal_errs_with_string_errors(self):
        """Identical string errors compare equal."""
        assert Err("fail") == Err("fail")

    def test_hash_consistent_with_equality(self):
        """Equal Errs hash to the same value."""
        lhs = Err(ValueError("x"))
        rhs = Err(ValueError("x"))

        assert hash(lhs) == hash(rhs)

    def test_hash_differs_for_unequal(self):
        """Distinct Errs are expected not to collide on hash."""
        assert hash(Err("a")) != hash(Err("b"))
|
||||
|
||||
|
||||
class TestOkEquality:
    """Tests covering Ok equality and hashing (attrs-generated)."""

    def test_equal_ok_values(self):
        """Ok instances with the same inner value compare equal."""
        assert Ok(42) == Ok(42)

    def test_unequal_ok_values(self):
        """Ok instances with different inner values compare unequal."""
        assert Ok(1) != Ok(2)

    def test_ok_not_equal_to_err(self):
        """Ok and Err never compare equal, even with the same payload."""
        assert Ok(1) != Err(1)

    def test_ok_not_equal_to_plain_value(self):
        """Ok never equals its bare unwrapped value."""
        assert Ok(42) != 42

    def test_hash_consistent_with_equality(self):
        """Equal Ok instances hash to the same value."""
        assert hash(Ok("a")) == hash(Ok("a"))

    def test_ok_usable_in_sets(self):
        """Ok instances deduplicate correctly inside a set."""
        unique = {Ok(1), Ok(1), Ok(2)}

        assert len(unique) == 2
|
||||
|
||||
|
||||
class TestResultStaticMethods:
    """Tests for Result.unit, Result.result_is_ok, and Result.result_unwrap."""

    def test_unit_wraps_in_ok(self):
        """``Result.unit`` lifts a plain value into Ok."""
        lifted = Result.unit(99)

        assert lifted == Ok(99)

    def test_result_is_ok_for_ok(self):
        """``result_is_ok`` reports True for an Ok."""
        assert Result.result_is_ok(Ok(1)) is True

    def test_result_is_ok_for_err(self):
        """``result_is_ok`` reports False for an Err."""
        assert Result.result_is_ok(Err("fail")) is False

    def test_result_unwrap_ok(self):
        """``result_unwrap`` extracts the inner value from an Ok."""
        assert Result.result_unwrap(Ok(42)) == 42

    def test_result_unwrap_err_raises(self):
        """``result_unwrap`` re-raises the Exception held by an Err."""
        with pytest.raises(RuntimeError, match="boom"):
            Result.result_unwrap(Err(RuntimeError("boom")))
|
||||
|
||||
|
||||
class TestNestedResult:
    """Tests for Result values nested inside other Results."""

    def test_ok_wrapping_ok(self):
        """An Ok may wrap another Ok; each layer unwraps independently."""
        outer = Ok(Ok(42))

        assert outer.unwrap() == Ok(42)
        assert outer.unwrap().unwrap() == 42

    def test_ok_wrapping_err(self):
        """An Ok may wrap an Err; unwrapping exposes the inner Err."""
        outer = Ok(Err("inner"))
        wrapped_err = outer.unwrap()

        assert wrapped_err.is_ok() is False
        assert wrapped_err.error == "inner"

    def test_map_on_nested_ok(self):
        """``map`` on a nested Ok operates on the inner Result."""
        outer = Ok(Ok(5))
        mapped = outer.map(lambda inner_ok: inner_ok.map(lambda x: x * 2))

        assert mapped == Ok(Ok(10))
|
||||
|
||||
|
||||
class TestErrTracebackExtraction:
    """Tests for the traceback-detail extraction performed by Err."""

    def test_details_populated_from_exception_traceback(self):
        """A raised Exception yields populated frame details on the Err."""
        try:
            msg = "test error"
            raise ValueError(msg)  # noqa: TRY301
        except ValueError as exc:
            failure = Err(exc)

        assert len(failure.details) > 0
        first_frame = failure.details[0]
        assert "file" in first_frame
        assert "func" in first_frame
        assert "line_no" in first_frame
        assert "locals" in first_frame

    def test_details_empty_for_non_exception(self):
        """A non-Exception error leaves the details list empty."""
        failure = Err("just a string")

        assert failure.details == []

    def test_details_empty_for_exception_without_traceback(self):
        """An Exception that was never raised carries no frame details."""
        failure = Err(ValueError("no traceback"))

        assert failure.details == []
|
||||
|
||||
|
||||
class TestErrImmutability:
    """Tests for the frozen (attrs) behavior of Ok and Err."""

    def test_ok_is_frozen(self):
        """Attribute assignment on an Ok instance is rejected."""
        frozen_ok = Ok(42)

        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            frozen_ok.inner = 99  # type: ignore[misc]

    def test_err_is_frozen(self):
        """Attribute assignment on an Err instance is rejected."""
        frozen_err = Err("fail")

        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            frozen_err.error = "other"  # type: ignore[misc]
|
||||
|
||||
|
||||
class TestErrValidators:
    """Tests for the attrs field validators on Err."""

    def test_input_args_must_be_tuple(self):
        """A non-tuple ``input_args`` is rejected with TypeError."""
        with pytest.raises(TypeError):
            Err("fail", input_args=["not", "a", "tuple"])  # type: ignore[arg-type]

    def test_traceback_must_be_string(self):
        """A non-string ``traceback`` is rejected with TypeError."""
        with pytest.raises(TypeError):
            Err("fail", traceback=123)  # type: ignore[arg-type]
|
||||
761
packages/codeflash-core/tests/test_danom_stream.py
Normal file
761
packages/codeflash-core/tests/test_danom_stream.py
Normal file
|
|
@ -0,0 +1,761 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
from codeflash_core.danom.stream import (
|
||||
_FILTER,
|
||||
_MAP,
|
||||
_TAP,
|
||||
Stream,
|
||||
_apply_fns,
|
||||
_Nothing,
|
||||
_par_apply_fns,
|
||||
)
|
||||
|
||||
|
||||
class TestStreamFromIterable:
    """Tests for the Stream.from_iterable constructor."""

    def test_from_list(self):
        """A list becomes a Stream whose seq is a tuple."""
        stream = Stream.from_iterable([1, 2, 3])

        assert stream.seq == (1, 2, 3)

    def test_from_tuple(self):
        """A tuple is accepted directly."""
        stream = Stream.from_iterable((4, 5))

        assert stream.seq == (4, 5)

    def test_from_generator(self):
        """A generator is fully consumed into a tuple."""
        stream = Stream.from_iterable(x * 2 for x in range(3))

        assert stream.seq == (0, 2, 4)

    def test_from_empty(self):
        """An empty iterable yields an empty stream."""
        stream = Stream.from_iterable([])

        assert stream.seq == ()

    def test_from_string(self):
        """A string is split into individual characters."""
        stream = Stream.from_iterable("abc")

        assert stream.seq == ("a", "b", "c")

    def test_from_range(self):
        """A range object is accepted as an iterable."""
        stream = Stream.from_iterable(range(4))

        assert stream.seq == (0, 1, 2, 3)

    def test_ops_default_empty(self):
        """A freshly built stream has no queued operations."""
        stream = Stream.from_iterable([1])

        assert stream.ops == ()
|
||||
|
||||
|
||||
class TestStreamBool:
    """Tests for Stream.__bool__."""

    def test_truthy_when_nonempty(self):
        """A stream holding elements evaluates truthy."""
        assert bool(Stream.from_iterable([1])) is True

    def test_falsy_when_empty(self):
        """A stream holding nothing evaluates falsy."""
        assert bool(Stream.from_iterable([])) is False
|
||||
|
||||
|
||||
class TestStreamMap:
    """Tests for Stream.map."""

    def test_single_map(self):
        """One map function is applied to every element."""
        collected = (
            Stream.from_iterable([1, 2, 3]).map(lambda x: x * 10).collect()
        )

        assert collected == (10, 20, 30)

    def test_multiple_fns_in_one_call(self):
        """Several functions in one map call compose left to right."""
        collected = (
            Stream.from_iterable([1, 2])
            .map(lambda x: x + 1, lambda x: x * 3)
            .collect()
        )

        assert collected == (6, 9)

    def test_map_empty_stream(self):
        """Mapping an empty stream yields an empty tuple."""
        collected = Stream.from_iterable([]).map(lambda x: x + 1).collect()

        assert collected == ()

    def test_map_preserves_immutability(self):
        """map returns a new stream; the source stream is untouched."""
        source = Stream.from_iterable([1, 2])
        derived = source.map(lambda x: x * 2)

        assert source.ops == ()
        assert len(derived.ops) == 1
        assert derived is not source
|
||||
|
||||
|
||||
class TestStreamFilter:
    """Tests for Stream.filter."""

    def test_single_filter(self):
        """One predicate drops non-matching elements."""
        collected = (
            Stream.from_iterable([1, 2, 3, 4])
            .filter(lambda x: x % 2 == 0)
            .collect()
        )

        assert collected == (2, 4)

    def test_filter_removes_all(self):
        """A never-true predicate empties the stream."""
        collected = (
            Stream.from_iterable([1, 3, 5]).filter(lambda x: x > 100).collect()
        )

        assert collected == ()

    def test_filter_keeps_all(self):
        """An always-true predicate keeps every element."""
        collected = (
            Stream.from_iterable([2, 4, 6])
            .filter(lambda x: x % 2 == 0)
            .collect()
        )

        assert collected == (2, 4, 6)

    def test_multiple_filters(self):
        """Several predicates in one filter call apply in sequence."""
        collected = (
            Stream.from_iterable(range(20))
            .filter(lambda x: x % 2 == 0, lambda x: x > 5)
            .collect()
        )

        assert collected == (6, 8, 10, 12, 14, 16, 18)

    def test_filter_empty_stream(self):
        """Filtering an empty stream yields an empty tuple."""
        collected = Stream.from_iterable([]).filter(lambda x: True).collect()

        assert collected == ()
|
||||
|
||||
|
||||
class TestStreamTap:
    """Tests for Stream.tap."""

    def test_tap_side_effect(self):
        """tap runs a side effect and passes elements through unchanged."""
        observed = []
        collected = Stream.from_iterable([1, 2, 3]).tap(observed.append).collect()

        assert collected == (1, 2, 3)
        assert observed == [1, 2, 3]

    def test_tap_receives_deepcopy(self):
        """tap gets deep copies, so mutating them never touches the stream."""
        source_dicts = [{"a": 1}, {"a": 2}]
        mutated_copies = []

        def mutating_tap(x):
            x["a"] = 999
            mutated_copies.append(x)

        collected = Stream.from_iterable(source_dicts).tap(mutating_tap).collect()

        assert collected == ({"a": 1}, {"a": 2})
        assert mutated_copies == [{"a": 999}, {"a": 999}]
|
||||
|
||||
|
||||
class TestStreamChaining:
    """Tests for chaining several operations on one stream."""

    def test_map_then_filter(self):
        """map before filter applies both in declaration order."""
        collected = (
            Stream.from_iterable([1, 2, 3, 4, 5])
            .map(lambda x: x * 2)
            .filter(lambda x: x > 4)
            .collect()
        )

        assert collected == (6, 8, 10)

    def test_filter_then_map(self):
        """filter before map applies both in declaration order."""
        collected = (
            Stream.from_iterable([1, 2, 3, 4, 5])
            .filter(lambda x: x > 2)
            .map(lambda x: x**2)
            .collect()
        )

        assert collected == (9, 16, 25)

    def test_map_filter_tap_chain(self):
        """map, filter, and tap compose in a single pipeline."""
        observed = []
        collected = (
            Stream.from_iterable([1, 2, 3, 4])
            .map(lambda x: x + 10)
            .filter(lambda x: x % 2 == 0)
            .tap(observed.append)
            .collect()
        )

        assert collected == (12, 14)
        assert observed == [12, 14]

    def test_multiple_map_calls(self):
        """Separate map calls compose sequentially."""
        collected = (
            Stream.from_iterable([1, 2])
            .map(lambda x: x + 1)
            .map(lambda x: x * 10)
            .collect()
        )

        assert collected == (20, 30)
|
||||
|
||||
|
||||
class TestStreamFold:
    """Tests for Stream.fold."""

    def test_sum(self):
        """Folding with addition accumulates a sum."""
        total = Stream.from_iterable([1, 2, 3, 4]).fold(
            0, lambda acc, x: acc + x
        )

        assert total == 10

    def test_fold_with_operations(self):
        """Queued operations run before the reduction step."""
        total = (
            Stream.from_iterable([1, 2, 3, 4])
            .filter(lambda x: x % 2 == 0)
            .map(lambda x: x * 10)
            .fold(0, lambda acc, x: acc + x)
        )

        assert total == 60

    def test_fold_empty_stream(self):
        """Folding an empty stream returns the initial accumulator."""
        total = Stream.from_iterable([]).fold(42, lambda acc, x: acc + x)

        assert total == 42

    def test_fold_string_concat(self):
        """Folding works for string concatenation too."""
        joined = Stream.from_iterable(["a", "b", "c"]).fold(
            "", lambda acc, x: acc + x
        )

        assert joined == "abc"
|
||||
|
||||
|
||||
class TestStreamPartition:
    """Tests for Stream.partition."""

    def test_basic_partition(self):
        """partition splits elements into matching and non-matching streams."""
        matched, unmatched = Stream.from_iterable([1, 2, 3, 4, 5]).partition(
            lambda x: x > 3
        )

        assert matched.collect() == (4, 5)
        assert unmatched.collect() == (1, 2, 3)

    def test_partition_all_true(self):
        """If everything matches, the false side is empty."""
        matched, unmatched = Stream.from_iterable([2, 4, 6]).partition(
            lambda x: x % 2 == 0
        )

        assert matched.collect() == (2, 4, 6)
        assert unmatched.collect() == ()

    def test_partition_all_false(self):
        """If nothing matches, the true side is empty."""
        matched, unmatched = Stream.from_iterable([1, 3, 5]).partition(
            lambda x: x % 2 == 0
        )

        assert matched.collect() == ()
        assert unmatched.collect() == (1, 3, 5)

    def test_partition_empty(self):
        """Partitioning an empty stream gives two empty streams."""
        matched, unmatched = Stream.from_iterable([]).partition(lambda x: x > 0)

        assert matched.collect() == ()
        assert unmatched.collect() == ()

    def test_partition_applies_pending_ops(self):
        """Pending operations are materialized before the split."""
        matched, unmatched = (
            Stream.from_iterable([1, 2, 3, 4])
            .map(lambda x: x * 10)
            .partition(lambda x: x >= 30)
        )

        assert matched.collect() == (30, 40)
        assert unmatched.collect() == (10, 20)
|
||||
|
||||
|
||||
class TestStreamCollect:
    """Tests for Stream.collect."""

    def test_collect_no_ops(self):
        """With no queued operations, collect returns the original elements."""
        assert Stream.from_iterable([1, 2, 3]).collect() == (1, 2, 3)

    def test_collect_returns_tuple(self):
        """collect always produces a tuple."""
        assert isinstance(Stream.from_iterable([1]).collect(), tuple)

    def test_collect_single_element(self):
        """A one-element stream collects into a one-element tuple."""
        collected = Stream.from_iterable([42]).map(lambda x: x + 1).collect()

        assert collected == (43,)
|
||||
|
||||
|
||||
class TestStreamParCollect:
    """Tests for Stream.par_collect."""

    def test_par_collect_threads(self):
        """Thread-based collection matches the sequential result."""
        collected = (
            Stream.from_iterable(range(20))
            .map(lambda x: x * 2)
            .par_collect(workers=2, use_threads=True)
        )

        assert collected == tuple(x * 2 for x in range(20))

    def test_par_collect_empty(self):
        """Parallel collection of an empty stream is an empty tuple."""
        collected = Stream.from_iterable([]).par_collect(
            workers=2, use_threads=True
        )

        assert collected == ()

    def test_par_collect_with_filter(self):
        """Filter operations are honored in the parallel path."""
        collected = (
            Stream.from_iterable(range(10))
            .filter(lambda x: x % 2 == 0)
            .par_collect(workers=2, use_threads=True)
        )

        assert collected == (0, 2, 4, 6, 8)
|
||||
|
||||
|
||||
class TestStreamAsyncCollect:
    """Tests for Stream.async_collect."""

    @pytest.mark.asyncio
    async def test_async_map(self):
        """Coroutine map functions are awaited during collection."""

        async def times_ten(x):
            return x * 10

        collected = await (
            Stream.from_iterable([1, 2, 3]).map(times_ten).async_collect()
        )

        assert collected == (10, 20, 30)

    @pytest.mark.asyncio
    async def test_async_collect_no_ops(self):
        """With no queued operations, async_collect matches collect."""
        collected = await Stream.from_iterable([1, 2]).async_collect()

        assert collected == (1, 2)

    @pytest.mark.asyncio
    async def test_async_filter(self):
        """Coroutine filter predicates drop non-matching elements."""

        async def gt_two(x):
            return x > 2

        collected = await (
            Stream.from_iterable([1, 2, 3, 4]).filter(gt_two).async_collect()
        )

        assert collected == (3, 4)

    @pytest.mark.asyncio
    async def test_async_tap(self):
        """Coroutine taps run side effects and keep elements intact."""
        observed = []

        async def record(x):
            observed.append(x)

        collected = await (
            Stream.from_iterable([10, 20]).tap(record).async_collect()
        )

        assert collected == (10, 20)
        assert observed == [10, 20]
|
||||
|
||||
|
||||
class TestApplyFns:
    """Tests for the _apply_fns helper."""

    def test_map_op(self):
        """A MAP op transforms each element."""
        pipeline = ((_MAP, lambda x: x + 1),)

        assert tuple(_apply_fns((1, 2, 3), pipeline)) == (2, 3, 4)

    def test_filter_op(self):
        """A FILTER op drops non-matching elements."""
        pipeline = ((_FILTER, lambda x: x > 2),)

        assert tuple(_apply_fns((1, 2, 3, 4), pipeline)) == (3, 4)

    def test_tap_op(self):
        """A TAP op runs a side effect and keeps elements intact."""
        observed = []
        pipeline = ((_TAP, observed.append),)
        produced = tuple(_apply_fns((1, 2), pipeline))

        assert produced == (1, 2)
        assert len(observed) == 2

    def test_combined_ops(self):
        """MAP, FILTER, and TAP ops run in declaration order."""
        observed = []
        pipeline = (
            (_MAP, lambda x: x * 2),
            (_FILTER, lambda x: x > 4),
            (_TAP, observed.append),
        )
        produced = tuple(_apply_fns((1, 2, 3, 4), pipeline))

        assert produced == (6, 8)
        assert observed == [6, 8]

    def test_empty_ops(self):
        """With no ops, elements pass through unchanged."""
        assert tuple(_apply_fns((1, 2), ())) == (1, 2)

    def test_empty_elements(self):
        """No elements means no output, whatever the ops are."""
        pipeline = ((_MAP, lambda x: x + 1),)

        assert tuple(_apply_fns((), pipeline)) == ()

    def test_filter_short_circuits(self):
        """Later ops are skipped once a filter rejects an element."""
        seen_by_map = []

        def tracking_map(x):
            seen_by_map.append(x)
            return x

        pipeline = (
            (_FILTER, lambda x: x > 5),
            (_MAP, tracking_map),
        )
        tuple(_apply_fns((1, 10), pipeline))

        assert seen_by_map == [10]
|
||||
|
||||
|
||||
class TestParApplyFns:
    """Tests for the eager _par_apply_fns helper."""

    def test_returns_tuple(self):
        """_par_apply_fns yields a materialized tuple, not a generator."""
        produced = _par_apply_fns((1, 2, 3), ((_MAP, lambda x: x + 1),))

        assert isinstance(produced, tuple)
        assert produced == (2, 3, 4)

    def test_empty_elements(self):
        """No input elements means an empty tuple out."""
        assert _par_apply_fns((), ((_MAP, lambda x: x),)) == ()

    def test_filter_and_map(self):
        """FILTER and MAP ops compose correctly in the eager variant."""
        pipeline = (
            (_FILTER, lambda x: x % 2 == 0),
            (_MAP, lambda x: x * 10),
        )

        assert _par_apply_fns((1, 2, 3, 4), pipeline) == (20, 40)
|
||||
|
||||
|
||||
class TestNothing:
    """Tests for the _Nothing sentinel enum."""

    def test_nothing_is_singleton(self):
        """NOTHING is the only member of the _Nothing enum."""
        assert len(_Nothing) == 1
        assert _Nothing.NOTHING is _Nothing.NOTHING

    def test_nothing_is_not_equal_to_none(self):
        """NOTHING must never be mistaken for None."""
        assert _Nothing.NOTHING != None  # noqa: E711
|
||||
|
||||
|
||||
class TestStreamImmutability:
    """Tests for Stream's frozen (attrs) behavior."""

    def test_cannot_set_seq(self):
        """Assigning to ``seq`` on a frozen stream raises."""
        frozen = Stream.from_iterable([1, 2])

        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            frozen.seq = (3, 4)

    def test_cannot_set_ops(self):
        """Assigning to ``ops`` on a frozen stream raises."""
        frozen = Stream.from_iterable([1, 2])

        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            frozen.ops = ()

    def test_seq_must_be_tuple(self):
        """The validator rejects a non-tuple ``seq`` at construction."""
        with pytest.raises(TypeError):
            Stream(seq=[1, 2, 3])
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Tests for edge cases and boundary conditions."""

    def test_single_element_map(self):
        """A stream with one element processes map correctly."""
        collected = Stream.from_iterable([5]).map(lambda v: v * 3).collect()

        assert collected == (15,)

    def test_single_element_filter_keeps(self):
        """A single element matching the filter is kept."""
        collected = Stream.from_iterable([5]).filter(lambda v: v > 0).collect()

        assert collected == (5,)

    def test_single_element_filter_removes(self):
        """A single element not matching the filter is removed."""
        collected = Stream.from_iterable([5]).filter(lambda v: v > 10).collect()

        assert collected == ()

    def test_none_values_in_stream(self):
        """None values are valid stream elements."""
        collected = Stream.from_iterable([None, 1, None]).collect()

        assert collected == (None, 1, None)

    def test_nested_streams(self):
        """Streams can contain other streams as elements."""
        inner = Stream.from_iterable([1, 2])

        collected = Stream.from_iterable([inner, inner]).collect()

        assert len(collected) == 2
        assert collected[0].collect() == (1, 2)

    def test_large_stream(self):
        """A large stream processes without error."""
        collected = (
            Stream.from_iterable(range(10000))
            .map(lambda v: v + 1)
            .filter(lambda v: v % 1000 == 0)
            .collect()
        )

        # Multiples of 1000 from 1000 through 10000 inclusive.
        assert collected == tuple(range(1000, 10001, 1000))

    def test_fold_with_workers_threads(self):
        """Fold with workers>1 and use_threads produces the same result."""
        total = Stream.from_iterable(range(100)).fold(
            0, lambda acc, v: acc + v, workers=2, use_threads=True
        )

        assert total == 4950
||||
665
packages/codeflash-core/tests/test_model.py
Normal file
665
packages/codeflash-core/tests/test_model.py
Normal file
|
|
@ -0,0 +1,665 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
from codeflash_core._model import (
|
||||
_NS_PER_DAY,
|
||||
_NS_PER_HOUR,
|
||||
_NS_PER_MIN,
|
||||
_NS_PER_MS,
|
||||
_NS_PER_S,
|
||||
_NS_PER_US,
|
||||
BenchmarkDetail,
|
||||
Candidate,
|
||||
FileDiffContent,
|
||||
OptimizationRequest,
|
||||
OptimizationReviewResult,
|
||||
PrComment,
|
||||
humanize_runtime,
|
||||
)
|
||||
|
||||
|
||||
class TestTimeConstants:
    """Tests for nanosecond conversion constants."""

    def test_ns_per_us(self):
        """One microsecond equals 1,000 nanoseconds."""
        assert _NS_PER_US == 1_000

    def test_ns_per_ms(self):
        """One millisecond equals 1,000,000 nanoseconds."""
        assert _NS_PER_MS == 1_000_000

    def test_ns_per_s(self):
        """One second equals 1,000,000,000 nanoseconds."""
        assert _NS_PER_S == 1_000_000_000

    def test_ns_per_min(self):
        """One minute equals 60,000,000,000 nanoseconds."""
        assert _NS_PER_MIN == 60_000_000_000

    def test_ns_per_hour(self):
        """One hour equals 3,600,000,000,000 nanoseconds."""
        assert _NS_PER_HOUR == 3_600_000_000_000

    def test_ns_per_day(self):
        """One day equals 86,400,000,000,000 nanoseconds."""
        assert _NS_PER_DAY == 86_400_000_000_000

    def test_constants_are_consistent(self):
        """Each constant is derivable from the previous one."""
        assert _NS_PER_US * 1_000 == _NS_PER_MS
        assert _NS_PER_MS * 1_000 == _NS_PER_S
        assert _NS_PER_S * 60 == _NS_PER_MIN
        assert _NS_PER_MIN * 60 == _NS_PER_HOUR
        assert _NS_PER_HOUR * 24 == _NS_PER_DAY
||||
class TestOptimizationRequest:
    """Tests for OptimizationRequest."""

    def test_required_fields(self):
        """Construction with only required fields succeeds."""
        request = OptimizationRequest(
            source_code="def f(): pass",
            language="python",
            language_version="3.12",
        )

        assert request.source_code == "def f(): pass"
        assert request.language == "python"
        assert request.language_version == "3.12"

    def test_default_values(self):
        """Optional fields use their documented defaults."""
        request = OptimizationRequest(
            source_code="x",
            language="python",
            language_version="3.12",
        )

        assert request.context_code == ""
        assert request.is_async is False
        assert request.is_numerical_code is None
        assert request.codeflash_version == ""
        assert request.baseline_runtime_ns is None
        assert request.loop_count is None
        assert request.line_profiler_results is None
        assert request.test_input_examples is None

    def test_all_fields(self):
        """All fields can be set explicitly."""
        request = OptimizationRequest(
            source_code="def f(): pass",
            language="javascript",
            language_version="ES2022",
            context_code="import os",
            is_async=True,
            is_numerical_code=True,
            codeflash_version="1.0.0",
            baseline_runtime_ns=5000,
            loop_count=100,
            line_profiler_results="Line # Hits",
            test_input_examples="example()",
        )

        assert request.language == "javascript"
        assert request.language_version == "ES2022"
        assert request.context_code == "import os"
        assert request.is_async is True
        assert request.is_numerical_code is True
        assert request.codeflash_version == "1.0.0"
        assert request.baseline_runtime_ns == 5000
        assert request.loop_count == 100
        assert request.line_profiler_results == "Line # Hits"
        assert request.test_input_examples == "example()"

    def test_frozen(self):
        """Mutating a field raises FrozenInstanceError."""
        request = OptimizationRequest(
            source_code="x",
            language="python",
            language_version="3.12",
        )

        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            request.source_code = "y"

    def test_empty_source_code(self):
        """An empty source_code string is allowed."""
        request = OptimizationRequest(
            source_code="",
            language="python",
            language_version="3.12",
        )

        assert request.source_code == ""
||||
class TestCandidate:
    """Tests for Candidate."""

    def test_required_fields(self):
        """Construction with only required fields succeeds."""
        cand = Candidate(code="def f(): pass", explanation="simplified")

        assert cand.code == "def f(): pass"
        assert cand.explanation == "simplified"

    def test_default_values(self):
        """Optional fields default to empty strings."""
        cand = Candidate(code="x", explanation="y")

        assert cand.candidate_id == ""
        assert cand.source == ""
        assert cand.parent_id == ""
        assert cand.code_markdown == ""

    def test_all_fields(self):
        """All fields can be set explicitly."""
        cand = Candidate(
            code="def f(): pass",
            explanation="optimized",
            candidate_id="abc-123",
            source="ai-service",
            parent_id="parent-0",
            code_markdown="```python\ndef f(): pass\n```",
        )

        assert cand.candidate_id == "abc-123"
        assert cand.source == "ai-service"
        assert cand.parent_id == "parent-0"
        assert cand.code_markdown == "```python\ndef f(): pass\n```"

    def test_frozen(self):
        """Mutating a field raises FrozenInstanceError."""
        cand = Candidate(code="x", explanation="y")

        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            cand.code = "z"

    def test_empty_strings(self):
        """Empty strings are valid for required fields."""
        cand = Candidate(code="", explanation="")

        assert cand.code == ""
        assert cand.explanation == ""
||||
class TestOptimizationReviewResult:
    """Tests for OptimizationReviewResult."""

    def test_construction(self):
        """Both fields are stored correctly."""
        result = OptimizationReviewResult(
            review="high",
            explanation="Well-tested optimization.",
        )

        assert result.review == "high"
        assert result.explanation == "Well-tested optimization."

    def test_frozen(self):
        """Mutating a field raises FrozenInstanceError."""
        result = OptimizationReviewResult(review="low", explanation="risky")

        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            result.review = "high"

    def test_empty_strings(self):
        """Empty strings are valid for both fields."""
        result = OptimizationReviewResult(review="", explanation="")

        assert result.review == ""
        assert result.explanation == ""
||||
class TestFileDiffContent:
    """Tests for FileDiffContent."""

    def test_construction(self):
        """Both old and new content are stored."""
        diff = FileDiffContent(old_content="before", new_content="after")

        assert diff.old_content == "before"
        assert diff.new_content == "after"

    def test_frozen(self):
        """Mutating a field raises FrozenInstanceError."""
        diff = FileDiffContent(old_content="a", new_content="b")

        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            diff.old_content = "c"

    def test_empty_strings(self):
        """Empty strings are valid for both fields."""
        diff = FileDiffContent(old_content="", new_content="")

        assert diff.old_content == ""
        assert diff.new_content == ""
||||
class TestBenchmarkDetail:
    """Tests for BenchmarkDetail."""

    @pytest.fixture(name="detail")
    def _detail(self):
        """A sample BenchmarkDetail for testing."""
        return BenchmarkDetail(
            benchmark_name="test_suite",
            test_function="test_compute",
            original_timing="100ms",
            expected_new_timing="50ms",
            speedup_percent=50.0,
        )

    def test_construction(self, detail):
        """All fields are stored correctly."""
        assert detail.benchmark_name == "test_suite"
        assert detail.test_function == "test_compute"
        assert detail.original_timing == "100ms"
        assert detail.expected_new_timing == "50ms"
        assert detail.speedup_percent == 50.0

    def test_frozen(self, detail):
        """Mutating a field raises FrozenInstanceError."""
        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            detail.benchmark_name = "other"

    def test_to_string(self, detail):
        """to_string returns a multi-line summary with the correct format."""
        expected = (
            "Original timing for test_suite::test_compute: 100ms\n"
            "Expected new timing for test_suite::test_compute: 50ms\n"
            "Benchmark speedup for test_suite::test_compute: 50.00%\n"
        )

        assert detail.to_string() == expected

    def test_to_string_fractional_speedup(self):
        """to_string formats fractional speedup percentages to two decimals."""
        detail = BenchmarkDetail(
            benchmark_name="bench",
            test_function="test_fn",
            original_timing="200ns",
            expected_new_timing="133ns",
            speedup_percent=33.333,
        )

        assert "33.33%" in detail.to_string()

    def test_to_dict(self, detail):
        """to_dict returns all fields as a plain dictionary."""
        assert detail.to_dict() == {
            "benchmark_name": "test_suite",
            "test_function": "test_compute",
            "original_timing": "100ms",
            "expected_new_timing": "50ms",
            "speedup_percent": 50.0,
        }

    def test_to_dict_roundtrip(self, detail):
        """A BenchmarkDetail can be reconstructed from its to_dict output."""
        reconstructed = BenchmarkDetail(**detail.to_dict())

        assert reconstructed == detail

    def test_empty_strings(self):
        """Empty strings are valid for name and timing fields."""
        detail = BenchmarkDetail(
            benchmark_name="",
            test_function="",
            original_timing="",
            expected_new_timing="",
            speedup_percent=0.0,
        )

        assert detail.benchmark_name == ""
        assert detail.test_function == ""
||||
class TestPrComment:
    """Tests for PrComment."""

    @staticmethod
    def _build(**overrides):
        """Construct a PrComment from small defaults, overridden per test."""
        fields = {
            "optimization_explanation": "opt",
            "best_runtime": 1000,
            "original_runtime": 2000,
            "function_name": "fn",
            "relative_file_path": "f.py",
            "speedup_x": "2x",
            "speedup_pct": "50%",
            "loop_count": 1,
            "report_table": {},
        }
        fields.update(overrides)
        return PrComment(**fields)

    @pytest.fixture(name="pr_comment")
    def _pr_comment(self):
        """A sample PrComment for testing."""
        return PrComment(
            optimization_explanation="Replaced loop with vectorized op.",
            best_runtime=50_000_000,
            original_runtime=100_000_000,
            function_name="compute",
            relative_file_path="src/compute.py",
            speedup_x="2.0x",
            speedup_pct="50%",
            loop_count=10,
            report_table={"test_a": {"before": 100, "after": 50}},
        )

    def test_construction(self, pr_comment):
        """All required fields are stored correctly."""
        assert pr_comment.optimization_explanation == (
            "Replaced loop with vectorized op."
        )
        assert pr_comment.best_runtime == 50_000_000
        assert pr_comment.original_runtime == 100_000_000
        assert pr_comment.function_name == "compute"
        assert pr_comment.relative_file_path == "src/compute.py"
        assert pr_comment.speedup_x == "2.0x"
        assert pr_comment.speedup_pct == "50%"
        assert pr_comment.loop_count == 10
        assert pr_comment.report_table == {
            "test_a": {"before": 100, "after": 50}
        }

    def test_optional_defaults(self, pr_comment):
        """Optional fields default to None."""
        assert pr_comment.benchmark_details is None
        assert pr_comment.original_async_throughput is None
        assert pr_comment.best_async_throughput is None

    def test_frozen(self, pr_comment):
        """Mutating a field raises FrozenInstanceError."""
        with pytest.raises(attrs.exceptions.FrozenInstanceError):
            pr_comment.function_name = "other"

    def test_to_json_keys(self, pr_comment):
        """to_json returns the expected set of keys."""
        payload = pr_comment.to_json()

        assert set(payload.keys()) == {
            "optimization_explanation",
            "best_runtime",
            "original_runtime",
            "function_name",
            "file_path",
            "speedup_x",
            "speedup_pct",
            "loop_count",
            "report_table",
            "benchmark_details",
        }

    def test_to_json_humanizes_runtimes(self, pr_comment):
        """to_json converts runtimes via humanize_runtime."""
        payload = pr_comment.to_json()

        assert payload["best_runtime"] == humanize_runtime(50_000_000)
        assert payload["original_runtime"] == humanize_runtime(100_000_000)

    def test_to_json_file_path_key(self, pr_comment):
        """to_json maps relative_file_path to the 'file_path' key."""
        assert pr_comment.to_json()["file_path"] == "src/compute.py"

    def test_to_json_no_benchmark_details(self, pr_comment):
        """to_json sets benchmark_details to None when unset."""
        assert pr_comment.to_json()["benchmark_details"] is None

    def test_to_json_with_benchmark_details(self):
        """to_json includes benchmark_details when provided."""
        bd = BenchmarkDetail(
            benchmark_name="suite",
            test_function="test_fn",
            original_timing="10ms",
            expected_new_timing="5ms",
            speedup_percent=50.0,
        )
        pc = self._build(
            optimization_explanation="optimized",
            best_runtime=5_000_000,
            original_runtime=10_000_000,
            benchmark_details=(bd,),
        )

        assert pc.to_json()["benchmark_details"] == (bd,)

    def test_to_json_empty_benchmark_details_tuple(self):
        """to_json converts an empty tuple of benchmark_details to None."""
        pc = self._build(benchmark_details=())

        assert pc.to_json()["benchmark_details"] is None

    def test_to_json_excludes_async_when_none(self, pr_comment):
        """to_json omits async throughput keys when both are None."""
        payload = pr_comment.to_json()

        assert "original_async_throughput" not in payload
        assert "best_async_throughput" not in payload

    def test_to_json_includes_async_when_set(self):
        """to_json includes async throughput keys when both are set."""
        pc = self._build(
            original_async_throughput=500,
            best_async_throughput=1000,
        )

        payload = pc.to_json()

        assert payload["original_async_throughput"] == 500
        assert payload["best_async_throughput"] == 1000

    def test_to_json_excludes_async_when_only_original_set(self):
        """to_json omits async keys when only original_async_throughput is set."""
        payload = self._build(original_async_throughput=500).to_json()

        assert "original_async_throughput" not in payload
        assert "best_async_throughput" not in payload
||||
class TestHumanizeRuntime:
    """Tests for humanize_runtime."""

    def test_single_nanosecond(self):
        """1 ns uses the singular 'nanosecond' unit."""
        assert humanize_runtime(1) == "1.00 nanosecond"

    def test_small_nanoseconds(self):
        """Values below 1 microsecond use the 'nanoseconds' unit."""
        assert humanize_runtime(500) == "500 nanoseconds"

    def test_microsecond_range(self):
        """Values in the microsecond range display in microseconds."""
        assert "microseconds" in humanize_runtime(5_000)

    def test_millisecond_range(self):
        """Values in the millisecond range display in milliseconds."""
        assert "milliseconds" in humanize_runtime(5_000_000)

    def test_second_range(self):
        """Values in the second range display in seconds."""
        assert "seconds" in humanize_runtime(2_000_000_000)

    def test_minute_range(self):
        """Values in the minute range display in minutes."""
        assert "minutes" in humanize_runtime(120_000_000_000)

    def test_hour_range(self):
        """Values in the hour range display in hours."""
        assert "hours" in humanize_runtime(2 * _NS_PER_HOUR)

    def test_day_range(self):
        """Values in the day range display in days."""
        assert "days" in humanize_runtime(2 * _NS_PER_DAY)

    def test_exact_one_microsecond(self):
        """Exactly 1 microsecond uses the singular 'microsecond' unit."""
        text = humanize_runtime(1_000)

        assert "microsecond" in text
        assert "microseconds" not in text

    def test_exact_one_second(self):
        """Exactly 1 second uses the singular 'second' unit."""
        text = humanize_runtime(_NS_PER_S)

        assert "second" in text
        assert "seconds" not in text

    def test_exact_one_minute(self):
        """Exactly 1 minute uses the singular 'minute' unit."""
        text = humanize_runtime(_NS_PER_MIN)

        assert "minute" in text
        assert "minutes" not in text

    def test_exact_one_hour(self):
        """Exactly 1 hour uses the singular 'hour' unit."""
        text = humanize_runtime(_NS_PER_HOUR)

        assert "hour" in text
        assert "hours" not in text

    def test_exact_one_day(self):
        """Exactly 1 day uses the singular 'day' unit."""
        text = humanize_runtime(_NS_PER_DAY)

        assert "day" in text
        assert "days" not in text

    def test_returns_string(self):
        """The return type is always a string."""
        assert isinstance(humanize_runtime(42), str)

    def test_contains_space_between_value_and_unit(self):
        """The output always has a space between the numeric value and units."""
        assert len(humanize_runtime(12345).split(" ")) == 2
701
packages/codeflash-core/tests/test_pipeline.py
Normal file
701
packages/codeflash-core/tests/test_pipeline.py
Normal file
|
|
@ -0,0 +1,701 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
from codeflash_core._model import Candidate
|
||||
from codeflash_core._pipeline import (
|
||||
CandidateForest,
|
||||
CandidateNode,
|
||||
EvaluationContext,
|
||||
create_rank_dictionary,
|
||||
dedup_candidates,
|
||||
diff_length,
|
||||
filter_refined_candidates,
|
||||
performance_gain,
|
||||
select_best,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="candidate")
def _candidate():
    """A minimal Candidate for use in tests."""
    return Candidate(code="def f(): pass", explanation="no-op", candidate_id="c1")
||||
@pytest.fixture(name="eval_ctx")
def _eval_ctx():
    """A fresh EvaluationContext."""
    return EvaluationContext()
||||
class TestPerformanceGain:
    """Tests for performance_gain."""

    def test_two_x_speedup(self):
        """Original twice as slow yields a gain of 1.0."""
        gain = performance_gain(original_runtime_ns=200, optimized_runtime_ns=100)

        assert gain == 1.0

    def test_no_speedup(self):
        """Equal runtimes yield a gain of 0.0."""
        gain = performance_gain(original_runtime_ns=100, optimized_runtime_ns=100)

        assert gain == 0.0

    def test_slowdown(self):
        """Optimized slower than original yields a negative gain."""
        gain = performance_gain(original_runtime_ns=100, optimized_runtime_ns=200)

        assert gain == -0.5

    def test_optimized_zero_returns_zero(self):
        """Zero optimized runtime returns 0.0 to avoid division by zero."""
        gain = performance_gain(original_runtime_ns=100, optimized_runtime_ns=0)

        assert gain == 0.0

    def test_both_zero(self):
        """Both runtimes zero returns 0.0."""
        gain = performance_gain(original_runtime_ns=0, optimized_runtime_ns=0)

        assert gain == 0.0

    def test_original_zero_optimized_nonzero(self):
        """Zero original with nonzero optimized yields -1.0."""
        gain = performance_gain(original_runtime_ns=0, optimized_runtime_ns=100)

        assert gain == -1.0

    def test_large_speedup(self):
        """Large runtime difference produces correspondingly large gain."""
        gain = performance_gain(
            original_runtime_ns=1_000_000, optimized_runtime_ns=1_000
        )

        assert gain == 999.0
||||
class TestDiffLength:
    """Tests for diff_length."""

    def test_identical_strings(self):
        """Identical strings produce a zero-length diff."""
        assert diff_length("hello\n", "hello\n") == 0

    def test_empty_strings(self):
        """Two empty strings produce a zero-length diff."""
        assert diff_length("", "") == 0

    def test_one_line_changed(self):
        """A single changed line produces a non-zero diff."""
        assert diff_length("line1\nline2\n", "line1\nchanged\n") > 0

    def test_addition(self):
        """Adding a line produces a non-zero diff."""
        assert diff_length("a\n", "a\nb\n") > 0

    def test_deletion(self):
        """Removing a line produces a non-zero diff."""
        assert diff_length("a\nb\n", "a\n") > 0

    def test_completely_different(self):
        """Completely different strings produce a larger diff than a small edit."""
        small_edit = diff_length("aaa\n", "aab\n")
        full_rewrite = diff_length("aaa\nbbb\nccc\n", "xxx\nyyy\nzzz\n")

        assert full_rewrite > small_edit
||||
class TestCreateRankDictionary:
    """Tests for create_rank_dictionary."""

    def test_ascending_order(self):
        """Already-sorted values get ranks 0, 1, 2."""
        assert create_rank_dictionary([10, 20, 30]) == {0: 0, 1: 1, 2: 2}

    def test_descending_order(self):
        """Reverse-sorted values get inverted ranks."""
        assert create_rank_dictionary([30, 20, 10]) == {0: 2, 1: 1, 2: 0}

    def test_single_element(self):
        """A single element gets rank 0."""
        assert create_rank_dictionary([42]) == {0: 0}

    def test_empty_list(self):
        """An empty list returns an empty dict."""
        assert create_rank_dictionary([]) == {}

    def test_duplicate_values(self):
        """Duplicates still get distinct ranks based on original index."""
        ranks = create_rank_dictionary([5, 5, 5])

        assert set(ranks.values()) == {0, 1, 2}
||||
class TestCandidateNode:
    """Tests for CandidateNode."""

    def test_candidate_id_delegates(self, candidate):
        """candidate_id property returns the wrapped candidate's id."""
        node = CandidateNode(candidate=candidate)

        assert node.candidate_id == "c1"

    def test_is_leaf_true(self, candidate):
        """A node with no children is a leaf."""
        node = CandidateNode(candidate=candidate)

        assert node.is_leaf() is True

    def test_is_leaf_false(self, candidate):
        """A node with children is not a leaf."""
        parent = CandidateNode(candidate=candidate)
        child = CandidateNode(
            candidate=Candidate(
                code="def g(): pass",
                explanation="child",
                candidate_id="c2",
            ),
            parent=parent,
        )
        parent.children.append(child)

        assert parent.is_leaf() is False

    def test_path_to_root_single(self, candidate):
        """A root node's path is just itself."""
        path = CandidateNode(candidate=candidate).path_to_root()

        assert len(path) == 1
        assert path[0] is candidate

    def test_path_to_root_chain(self):
        """A three-deep chain returns root-first order."""
        root_cand = Candidate(code="a", explanation="", candidate_id="c1")
        mid_cand = Candidate(code="b", explanation="", candidate_id="c2")
        leaf_cand = Candidate(code="c", explanation="", candidate_id="c3")
        root = CandidateNode(candidate=root_cand)
        mid = CandidateNode(candidate=mid_cand, parent=root)
        leaf = CandidateNode(candidate=leaf_cand, parent=mid)

        path_ids = [cand.candidate_id for cand in leaf.path_to_root()]

        assert path_ids == ["c1", "c2", "c3"]

    def test_default_children_empty(self, candidate):
        """A new node starts with an empty children list."""
        node = CandidateNode(candidate=candidate)

        assert node.children == []
||||
class TestCandidateForest:
    """Tests for CandidateForest."""

    def test_add_and_get(self, candidate):
        """Adding a candidate allows retrieval by id."""
        forest = CandidateForest()
        forest.add(candidate)

        node = forest.get("c1")

        assert node is not None
        assert node.candidate_id == "c1"

    def test_len(self):
        """__len__ returns the number of nodes in the forest."""
        forest = CandidateForest()

        assert len(forest) == 0

        forest.add(Candidate(code="a", explanation="", candidate_id="c1"))

        assert len(forest) == 1

    def test_get_missing(self):
        """Getting a nonexistent id returns None."""
        assert CandidateForest().get("nope") is None

    def test_parent_child_linking(self):
        """A candidate with parent_id links to its parent node."""
        forest = CandidateForest()
        forest.add(Candidate(code="a", explanation="", candidate_id="p1"))
        forest.add(
            Candidate(
                code="b",
                explanation="",
                candidate_id="c1",
                parent_id="p1",
            )
        )

        child_node = forest.get("c1")

        assert child_node is not None
        assert child_node.parent is not None
        assert child_node.parent.candidate_id == "p1"

    def test_child_added_before_parent(self):
        """Adding a child before its parent creates a placeholder."""
        forest = CandidateForest()
        forest.add(
            Candidate(
                code="b",
                explanation="",
                candidate_id="c1",
                parent_id="p1",
            )
        )

        # The unknown parent is represented by a placeholder node.
        assert len(forest) == 2

        forest.add(Candidate(code="a", explanation="", candidate_id="p1"))

        # Adding the real parent fills the placeholder without growing the forest.
        assert len(forest) == 2

        parent_node = forest.get("p1")

        assert parent_node is not None
        assert parent_node.candidate.code == "a"

    def test_multiple_roots(self):
        """Candidates without parent_id are independent roots."""
        forest = CandidateForest()
        forest.add(Candidate(code="a", explanation="", candidate_id="r1"))
        forest.add(Candidate(code="b", explanation="", candidate_id="r2"))

        assert len(forest) == 2
        assert forest.get("r1").parent is None
        assert forest.get("r2").parent is None
||||
class TestEvaluationContext:
    """Tests for EvaluationContext."""

    def test_record_failed(self, eval_ctx):
        """record_failed marks the candidate incorrect with no timings."""
        eval_ctx.record_failed("c1")

        assert eval_ctx.is_correct["c1"] is False
        assert eval_ctx.optimized_runtimes["c1"] is None
        assert eval_ctx.speedup_ratios["c1"] is None

    def test_record_success(self, eval_ctx):
        """record_success stores runtime, speedup, and correctness."""
        eval_ctx.record_success("c1", runtime=500.0, speedup=1.5)

        assert eval_ctx.is_correct["c1"] is True
        assert eval_ctx.optimized_runtimes["c1"] == 500.0
        assert eval_ctx.speedup_ratios["c1"] == 1.5

    def test_get_speedup_missing(self, eval_ctx):
        """get_speedup yields None for a candidate it has never seen."""
        assert eval_ctx.get_speedup("unknown") is None

    def test_get_runtime_missing(self, eval_ctx):
        """get_runtime yields None for a candidate it has never seen."""
        assert eval_ctx.get_runtime("unknown") is None

    def test_get_speedup_after_record(self, eval_ctx):
        """get_speedup echoes back the recorded speedup."""
        eval_ctx.record_success("c1", runtime=100.0, speedup=2.0)

        assert eval_ctx.get_speedup("c1") == 2.0

    def test_get_runtime_after_record(self, eval_ctx):
        """get_runtime echoes back the recorded runtime."""
        eval_ctx.record_success("c1", runtime=100.0, speedup=2.0)

        assert eval_ctx.get_runtime("c1") == 100.0

    def test_record_line_profile(self, eval_ctx):
        """record_line_profile keeps the raw profiler output string."""
        eval_ctx.record_line_profile("c1", "profile output")

        assert eval_ctx.line_profiler_results["c1"] == "profile output"

    def test_register_new(self, eval_ctx):
        """register_new maps normalized code to its id plus diff bookkeeping."""
        eval_ctx.register_new(
            normalized_code="norm",
            candidate_id="c1",
            flat_code="def f(): pass",
            original_flat_code="def f(): return 1",
        )

        entry = eval_ctx.code_to_id["norm"]
        assert entry["candidate_id"] == "c1"
        assert entry["shorter_code"] == "def f(): pass"
        assert entry["diff_len"] > 0

    def test_handle_duplicate_copies_prior_results(self, eval_ctx):
        """handle_duplicate propagates the first candidate's recorded results."""
        eval_ctx.record_success("c1", runtime=100.0, speedup=2.0)
        eval_ctx.register_new(
            normalized_code="norm",
            candidate_id="c1",
            flat_code="def f(): pass",
            original_flat_code="def g(): pass",
        )

        eval_ctx.handle_duplicate(
            candidate_id="c2",
            normalized_code="norm",
            original_flat_code="def g(): pass",
            flat_code="def f(): pass",
        )

        assert eval_ctx.is_correct["c2"] is True
        assert eval_ctx.optimized_runtimes["c2"] == 100.0
        assert eval_ctx.speedup_ratios["c2"] == 2.0

    def test_handle_duplicate_updates_shorter_code(self, eval_ctx):
        """handle_duplicate swaps in the code whose diff is smaller."""
        eval_ctx.register_new(
            normalized_code="norm",
            candidate_id="c1",
            flat_code="def f():\n x = 1\n return x\n",
            original_flat_code="def f():\n return 1\n",
        )
        eval_ctx.record_success("c1", runtime=100.0, speedup=2.0)

        eval_ctx.handle_duplicate(
            candidate_id="c2",
            normalized_code="norm",
            original_flat_code="def f():\n return 1\n",
            flat_code="def f():\n return 1\n",
        )

        shorter = eval_ctx.code_to_id["norm"]["shorter_code"]
        assert shorter == "def f():\n return 1\n"

    def test_starts_empty(self, eval_ctx):
        """A fresh EvaluationContext holds no results at all."""
        assert eval_ctx.speedup_ratios == {}
        assert eval_ctx.optimized_runtimes == {}
        assert eval_ctx.is_correct == {}
        assert eval_ctx.valid_candidates == []
|
||||
|
||||
|
||||
class TestDedupCandidates:
    """Tests for dedup_candidates."""

    def test_empty_input(self):
        """Deduplicating an empty batch yields an empty list."""
        kept = dedup_candidates(
            [],
            normalize_fn=lambda c: c,
            original_normalized="original",
        )

        assert kept == []

    def test_removes_identical_to_original(self):
        """A candidate that normalizes to the original code is dropped."""
        same = Candidate(code="original", explanation="", candidate_id="c1")

        kept = dedup_candidates(
            [same],
            normalize_fn=lambda c: c,
            original_normalized="original",
        )

        assert kept == []

    def test_removes_intra_batch_duplicates(self):
        """Within one batch only the first of each normalized form survives."""
        first = Candidate(code="opt", explanation="", candidate_id="c1")
        second = Candidate(code="opt", explanation="", candidate_id="c2")

        kept = dedup_candidates(
            [first, second],
            normalize_fn=lambda c: c,
            original_normalized="original",
        )

        assert len(kept) == 1
        assert kept[0].candidate_id == "c1"

    def test_removes_cross_batch_duplicates(self):
        """A candidate already present in cross_batch is dropped."""
        dup = Candidate(code="opt", explanation="", candidate_id="c1")

        kept = dedup_candidates(
            [dup],
            normalize_fn=lambda c: c,
            original_normalized="original",
            cross_batch={"opt": {"candidate_id": "prior"}},
        )

        assert kept == []

    def test_keeps_unique_candidates(self):
        """Distinct candidates all survive, in their original order."""
        batch = [
            Candidate(code="a", explanation="", candidate_id="c1"),
            Candidate(code="b", explanation="", candidate_id="c2"),
        ]

        kept = dedup_candidates(
            batch,
            normalize_fn=lambda c: c,
            original_normalized="original",
        )

        assert len(kept) == 2
        assert [c.candidate_id for c in kept] == ["c1", "c2"]

    def test_normalize_fn_failure_keeps_candidate(self):
        """If normalization raises, the candidate is conservatively kept."""

        def exploding_normalize(code):
            raise ValueError("boom")

        lone = Candidate(code="a", explanation="", candidate_id="c1")
        kept = dedup_candidates(
            [lone],
            normalize_fn=exploding_normalize,
            original_normalized="original",
        )

        assert len(kept) == 1
        assert kept[0].candidate_id == "c1"

    def test_seen_set_is_populated(self):
        """The caller-supplied seen set records each normalized form."""
        seen: set[str] = set()
        dedup_candidates(
            [Candidate(code="a", explanation="", candidate_id="c1")],
            normalize_fn=lambda c: c,
            original_normalized="original",
            seen=seen,
        )

        assert "a" in seen

    def test_seen_set_prevents_second_batch(self):
        """A form seen in an earlier call counts as a duplicate now."""
        kept = dedup_candidates(
            [Candidate(code="a", explanation="", candidate_id="c1")],
            normalize_fn=lambda c: c,
            original_normalized="original",
            seen={"a"},
        )

        assert kept == []
|
||||
|
||||
|
||||
class TestFilterRefinedCandidates:
    """Tests for filter_refined_candidates."""

    def test_under_limit_returns_all(self, eval_ctx):
        """Fewer candidates than max_candidates come back untouched."""
        pair = [
            Candidate(code="a", explanation="", candidate_id="c1"),
            Candidate(code="b", explanation="", candidate_id="c2"),
        ]

        kept = filter_refined_candidates(
            pair, eval_ctx, CandidateForest(), "original", max_candidates=5
        )

        assert len(kept) == 2

    def test_at_limit_returns_all(self, eval_ctx):
        """Exactly max_candidates candidates are all kept."""
        batch = [
            Candidate(code="a", explanation="", candidate_id=f"c{i}")
            for i in range(5)
        ]

        kept = filter_refined_candidates(
            batch, eval_ctx, CandidateForest(), "original", max_candidates=5
        )

        assert len(kept) == 5

    def test_over_limit_trims(self, eval_ctx):
        """More candidates than max_candidates are trimmed down to the limit."""
        batch = [
            Candidate(code=f"code{i}", explanation="", candidate_id=f"c{i}")
            for i in range(10)
        ]

        kept = filter_refined_candidates(
            batch, eval_ctx, CandidateForest(), "original", max_candidates=3
        )

        assert len(kept) == 3

    def test_prefers_lower_parent_runtime(self):
        """Children of faster parents are ranked ahead of the rest."""
        eval_ctx = EvaluationContext()
        forest = CandidateForest()

        forest.add(Candidate(code="pf", explanation="", candidate_id="pf"))
        forest.add(Candidate(code="ps", explanation="", candidate_id="ps"))
        eval_ctx.record_success("pf", runtime=10.0, speedup=5.0)
        eval_ctx.record_success("ps", runtime=1000.0, speedup=0.1)

        template = [
            Candidate(
                code="slow_child",
                explanation="",
                candidate_id="c_slow",
                parent_id="ps",
            ),
            Candidate(
                code="fast_child",
                explanation="",
                candidate_id="c_fast",
                parent_id="pf",
            ),
        ] * 4
        # give every copy a unique id while keeping its prefix
        batch = [
            attrs.evolve(cand, candidate_id=f"{cand.candidate_id}_{idx}")
            for idx, cand in enumerate(template)
        ]

        kept = filter_refined_candidates(
            batch, eval_ctx, forest, "original", max_candidates=2
        )

        assert len(kept) == 2
        assert all("c_fast" in c.candidate_id for c in kept)
|
||||
|
||||
|
||||
class TestSelectBest:
    """Tests for select_best."""

    def test_empty_returns_none(self, eval_ctx):
        """With no candidates there is nothing to select."""
        assert select_best(eval_ctx, 1000, [], []) is None

    def test_single_candidate(self, eval_ctx):
        """A lone candidate is always the winner."""
        eval_ctx.record_success("c1", runtime=50.0, speedup=1.0)

        assert select_best(eval_ctx, 1000, [10], ["c1"]) == "c1"

    def test_picks_faster_candidate(self, eval_ctx):
        """With equal diff lengths the lower runtime wins."""
        eval_ctx.record_success("c1", runtime=500.0, speedup=1.0)
        eval_ctx.record_success("c2", runtime=100.0, speedup=4.0)

        winner = select_best(eval_ctx, 1000, [10, 10], ["c1", "c2"])

        assert winner == "c2"

    def test_diff_length_contributes_to_score(self, eval_ctx):
        """At equal runtime the shorter diff wins when listed first."""
        eval_ctx.record_success("short_diff", runtime=100.0, speedup=1.0)
        eval_ctx.record_success("long_diff", runtime=100.0, speedup=1.0)

        winner = select_best(
            eval_ctx, 1000, [5, 500], ["short_diff", "long_diff"]
        )

        assert winner == "short_diff"

    def test_runtime_weighted_more_than_diff(self, eval_ctx):
        """Speed dominates: a much faster candidate beats a shorter diff."""
        eval_ctx.record_success("c1", runtime=500.0, speedup=1.0)
        eval_ctx.record_success("c2", runtime=50.0, speedup=9.0)

        winner = select_best(eval_ctx, 1000, [5, 100], ["c1", "c2"])

        assert winner == "c2"

    def test_missing_runtime_uses_original(self, eval_ctx):
        """A candidate without a recorded runtime falls back to original_runtime_ns."""
        eval_ctx.record_success("c1", runtime=50.0, speedup=1.0)

        winner = select_best(eval_ctx, 1000, [10, 10], ["c1", "c2"])

        assert winner == "c1"
|
||||
336
packages/codeflash-python/tests/test_normalizer.py
Normal file
336
packages/codeflash-python/tests/test_normalizer.py
Normal file
|
|
@@ -0,0 +1,336 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
|
||||
from codeflash_python.analysis._normalizer import normalize_python_code
|
||||
|
||||
|
||||
class TestBasicNormalization:
    """Tests for local variable renaming to canonical forms."""

    def test_single_variable_renamed(self):
        """A lone local variable becomes var_0 at every use."""
        src = textwrap.dedent("""\
            def f():
                x = 1
                return x
            """)

        out = normalize_python_code(src)

        assert "var_0 = 1" in out
        assert "return var_0" in out

    def test_multiple_variables_renamed_sequentially(self):
        """Locals become var_0, var_1, ... in first-assignment order."""
        src = textwrap.dedent("""\
            def f():
                alpha = 1
                beta = 2
                return alpha + beta
            """)

        out = normalize_python_code(src)

        assert "var_0 = 1" in out
        assert "var_1 = 2" in out
        assert "return var_0 + var_1" in out

    def test_same_variable_reused_gets_same_canonical_name(self):
        """Every store and load of one variable shares one canonical name."""
        src = textwrap.dedent("""\
            def f():
                total = 0
                total = total + 1
                return total
            """)

        out = normalize_python_code(src)

        assert "var_0 = 0" in out
        assert "var_0 = var_0 + 1" in out
        assert "return var_0" in out
|
||||
|
||||
|
||||
class TestFunctionNamePreservation:
    """Tests that function and class names are not renamed."""

    def test_function_name_preserved(self):
        """The def name itself survives normalization."""
        src = textwrap.dedent("""\
            def compute_sum():
                x = 1
                return x
            """)

        assert "def compute_sum():" in normalize_python_code(src)

    def test_class_name_preserved(self):
        """The class name survives normalization."""
        src = textwrap.dedent("""\
            class MyClass:
                def method(self):
                    x = 1
                    return x
            """)

        assert "class MyClass:" in normalize_python_code(src)
|
||||
|
||||
|
||||
class TestParameterPreservation:
    """Tests that function parameters are not renamed."""

    def test_positional_parameter_preserved(self):
        """Positional parameter names survive; locals are still renamed."""
        src = textwrap.dedent("""\
            def f(value):
                x = value + 1
                return x
            """)

        out = normalize_python_code(src)

        assert "value" in out
        assert "var_0 = value + 1" in out

    def test_args_and_kwargs_preserved(self):
        """*args and **kwargs keep their original names."""
        src = textwrap.dedent("""\
            def f(*args, **kwargs):
                x = args
                y = kwargs
                return x, y
            """)

        out = normalize_python_code(src)

        assert "args" in out
        assert "kwargs" in out

    def test_keyword_only_parameter_preserved(self):
        """Keyword-only parameters keep their original names."""
        src = textwrap.dedent("""\
            def f(*, key):
                x = key
                return x
            """)

        out = normalize_python_code(src)

        assert "key" in out
        assert "var_0 = key" in out

    def test_self_parameter_preserved(self):
        """The self parameter on methods is untouched."""
        src = textwrap.dedent("""\
            class C:
                def method(self):
                    x = self
                    return x
            """)

        assert "self" in normalize_python_code(src)
|
||||
|
||||
|
||||
class TestImportPreservation:
    """Tests that imported names are not renamed."""

    def test_import_name_preserved(self):
        """Plain import names survive normalization."""
        src = textwrap.dedent("""\
            import os
            def f():
                x = os.path.join('a', 'b')
                return x
            """)

        assert "os" in normalize_python_code(src)

    def test_from_import_name_preserved(self):
        """from-import names survive normalization."""
        src = textwrap.dedent("""\
            from os.path import join
            def f():
                x = join('a', 'b')
                return x
            """)

        assert "join" in normalize_python_code(src)

    def test_aliased_import_preserved(self):
        """Aliased imports are kept under the alias name."""
        src = textwrap.dedent("""\
            import numpy as np
            def f():
                x = np.array([1])
                return x
            """)

        assert "np" in normalize_python_code(src)
|
||||
|
||||
|
||||
class TestDocstringRemoval:
    """Tests for docstring handling."""

    def test_docstring_removed_by_default(self):
        """Function docstrings vanish under the default settings."""
        src = textwrap.dedent('''\
            def f():
                """This is a docstring."""
                x = 1
                return x
            ''')

        assert "docstring" not in normalize_python_code(src)

    def test_module_docstring_removed(self):
        """Module-level docstrings are removed as well."""
        src = textwrap.dedent('''\
            """Module docstring."""
            def f():
                x = 1
                return x
            ''')

        assert "Module docstring" not in normalize_python_code(src)

    def test_class_docstring_removed(self):
        """Class-level docstrings are removed as well."""
        src = textwrap.dedent('''\
            class C:
                """Class docstring."""
                def method(self):
                    x = 1
                    return x
            ''')

        assert "Class docstring" not in normalize_python_code(src)
|
||||
|
||||
|
||||
class TestDocstringPreservation:
    """Tests for docstring preservation when remove_docstrings=False."""

    def test_docstring_preserved_when_flag_false(self):
        """A function docstring is kept when removal is disabled."""
        src = textwrap.dedent('''\
            def f():
                """This is a docstring."""
                x = 1
                return x
            ''')

        out = normalize_python_code(src, remove_docstrings=False)

        assert "This is a docstring." in out

    def test_all_docstrings_preserved_when_flag_false(self):
        """Module, class, and method docstrings all survive when disabled."""
        src = textwrap.dedent('''\
            """Module doc."""
            class C:
                """Class doc."""
                def method(self):
                    """Method doc."""
                    x = 1
                    return x
            ''')

        out = normalize_python_code(src, remove_docstrings=False)

        assert "Module doc." in out
        assert "Class doc." in out
        assert "Method doc." in out
|
||||
|
||||
|
||||
class TestSyntaxErrorHandling:
    """Tests that invalid Python input is returned unchanged."""

    def test_syntax_error_returns_original(self):
        """Syntactically broken code comes back verbatim."""
        broken = "def f(\n x = :\n"

        assert normalize_python_code(broken) == broken

    def test_incomplete_code_returns_original(self):
        """Unparseable truncated code comes back verbatim."""
        stub = "def f("

        assert normalize_python_code(stub) == stub
|
||||
|
||||
|
||||
class TestEmptyInput:
    """Tests for empty and whitespace-only input."""

    def test_empty_string(self):
        """Empty in, empty out."""
        assert normalize_python_code("") == ""

    def test_whitespace_only(self):
        """Whitespace-only input collapses to the empty string."""
        out = normalize_python_code(" \n\n ")

        assert out == ""
|
||||
|
||||
|
||||
class TestMultipleFunctions:
    """Tests for code containing multiple function definitions."""

    def test_variables_in_separate_functions_get_independent_names(self):
        """Each function's locals restart from var_0 in its own scope."""
        src = textwrap.dedent("""\
            def first():
                alpha = 1
                return alpha

            def second():
                beta = 2
                return beta
            """)

        out = normalize_python_code(src)

        assert "def first():" in out
        assert "def second():" in out
        assert "var_0 = 1" in out
        assert "var_0 = 2" in out

    def test_nested_function_variables_normalized(self):
        """Nested-function locals continue the parent scope's counter."""
        src = textwrap.dedent("""\
            def outer():
                x = 1
                def inner():
                    y = 2
                    return y
                return x
            """)

        out = normalize_python_code(src)

        assert "def outer():" in out
        assert "def inner():" in out
        assert "var_0 = 1" in out
        assert "var_1 = 2" in out
|
||||
Loading…
Reference in a new issue