codeflash-agent/packages/codeflash-python/tests/test_function_optimizer.py
Kevin Turcios 3ee9c22c8e
fix: resolve all ruff lint errors across repo (#38)
* fix: resolve all ruff lint errors across repo

Auto-fixed 31 errors (unused imports, formatting, simplifications).
Manually fixed 14 remaining:
- EXE001: removed shebangs from non-executable bench scripts
- C417: replaced map(lambda) with generator expression
- C901/PLR0915: extracted _write_and_instrument_tests from generate_ai_tests
- C901/PLR0912: extracted _parse_toml_addopts and _ini_section_name from modify_addopts
- RUF001/RUF002: replaced ambiguous Unicode chars (en dash, multiplication sign)
- FBT002: made boolean params keyword-only in report functions
- E402: moved `import re` to top of file in security reports

* fix: resolve pre-existing mypy errors across packages

- _testgen.py: annotate `generated` as `str` to avoid no-any-return
- _test_runner.py: use str() for TimeoutExpired stdout/stderr (bytes|str),
  remove unused type: ignore on proc.kill()
- _candidate_eval.py: annotate `speedup` as `float` to avoid no-any-return
  from lazy-loaded performance_gain
2026-04-23 10:22:42 -05:00

198 lines
6.7 KiB
Python

"""Tests for per-function optimization utilities (stage 23b)."""
from __future__ import annotations
import textwrap
from unittest.mock import MagicMock, patch
from codeflash_python._model import FunctionParent
from codeflash_python.pipeline._context import OptimizationContext
from codeflash_python.pipeline._function_optimizer import (
NUMBA_REQUIRED_MODULES,
NUMERICAL_MODULES,
PythonFunctionOptimizer,
is_numerical_code,
resolve_function_ast,
)
class TestIsNumericalCode:
    """Tests for is_numerical_code detection."""

    def test_torch_import(self) -> None:
        """Code importing torch is detected as numerical."""
        code = textwrap.dedent("""\
            import torch

            def compute():
                return torch.tensor([1, 2])
        """)
        assert is_numerical_code(code) is True

    def test_no_numerical_imports(self) -> None:
        """Code without numerical imports returns False."""
        code = textwrap.dedent("""\
            import os

            def compute():
                return os.getcwd()
        """)
        assert is_numerical_code(code) is False

    def test_with_function_name_torch(self) -> None:
        """When given a function name, only that function is checked."""
        code = textwrap.dedent("""\
            import torch

            def uses_torch():
                return torch.zeros(5)

            def plain():
                return 42
        """)
        assert is_numerical_code(code, function_name="uses_torch") is True
        assert is_numerical_code(code, function_name="plain") is False

    def test_numba_required_modules_without_numba(self) -> None:
        """numpy-only code counts as numerical exactly when numba is available.

        numpy/scipy/math alone return False when numba is not installed and
        True when it is, so the expected result tracks ``_HAS_NUMBA`` and the
        test is meaningful in either environment.
        """
        from codeflash_python.pipeline._function_optimizer import _HAS_NUMBA

        code = "import numpy\ndef f(): return numpy.array([1])\n"
        # Equivalent to asserting False without numba and True with it.
        assert is_numerical_code(code) is bool(_HAS_NUMBA)

    def test_syntax_error(self) -> None:
        """Syntax errors in code return False."""
        assert is_numerical_code("def foo(:") is False

    def test_nonexistent_function(self) -> None:
        """A function name not in the code returns False."""
        code = "import torch\ndef foo(): return 1\n"
        assert is_numerical_code(code, function_name="nonexistent") is False

    def test_method_in_class(self) -> None:
        """Class methods using numerical code are detected."""
        code = textwrap.dedent("""\
            import torch

            class Model:
                def forward(self):
                    return torch.zeros(3)
        """)
        assert is_numerical_code(code, function_name="Model.forward") is True

    def test_constants(self) -> None:
        """Module constants are defined."""
        assert "numpy" in NUMERICAL_MODULES
        assert "torch" in NUMERICAL_MODULES
        assert "math" in NUMBA_REQUIRED_MODULES
class TestResolveFunctionAst:
    """Exercise resolve_function_ast lookups by name and parent chain."""

    def test_top_level(self) -> None:
        """A module-level function is found with an empty parent list."""
        node = resolve_function_ast("def foo():\n return 1\n", "foo", [])
        assert node is not None
        assert node.name == "foo"

    def test_method(self) -> None:
        """A method is found when its enclosing class is passed as a parent."""
        source = textwrap.dedent("""\
            class MyClass:
                def run(self):
                    return 42
        """)
        node = resolve_function_ast(
            source,
            "run",
            [FunctionParent(name="MyClass", type="ClassDef")],
        )
        assert node is not None
        assert node.name == "run"

    def test_missing_returns_none(self) -> None:
        """Looking up a name the source never defines yields None."""
        assert resolve_function_ast("x = 1\n", "nope", []) is None
class TestNoGenTests:
    """Tests for --no-gen-tests flag wiring."""

    @staticmethod
    def _make_ctx() -> OptimizationContext:
        """Build an OptimizationContext whose collaborators are all mocks."""
        return OptimizationContext(
            plugin=MagicMock(),
            project_root=MagicMock(),
            test_cfg=MagicMock(),
            ai_client=MagicMock(),
        )

    def test_field_defaults_to_false(self) -> None:
        """no_gen_tests defaults to False."""
        opt = PythonFunctionOptimizer(ctx=self._make_ctx())
        assert opt.no_gen_tests is False

    def test_field_accepts_true(self) -> None:
        """no_gen_tests=True is stored on the instance."""
        opt = PythonFunctionOptimizer(ctx=self._make_ctx(), no_gen_tests=True)
        assert opt.no_gen_tests is True

    async def test_skips_ai_test_generation(self) -> None:
        """When no_gen_tests=True, generate_ai_tests is never called."""
        _mod = "codeflash_python.pipeline._function_optimizer"
        # The test-orchestrator helpers are patched in the module that
        # (presumably) looks them up at call time.
        _orch = "codeflash_python.pipeline._test_orchestrator"

        fn_input = MagicMock()
        fn_input.function.qualified_name = "mod.func"
        fn_input.function.parents = []
        fn_input.function.function_name = "func"
        fn_input.function.is_async = False

        with (
            patch(f"{_orch}.generate_ai_tests") as mock_gen,
            patch(
                f"{_orch}.instrument_tests_for_function",
                return_value=None,
            ),
            patch(
                f"{_orch}.generate_concolic_tests",
                return_value=({}, "", None),
            ),
            patch(
                "codeflash_python.context.pipeline.get_code_optimization_context",
                return_value=MagicMock(),
            ),
            patch(
                "codeflash_python.pipeline._module_prep.resolve_python_function_ast",
                return_value=None,
            ),
            patch(f"{_mod}.is_numerical_code", return_value=False),
            patch(
                "codeflash_python.verification._baseline.establish_original_code_baseline"
            ),
        ):
            opt = PythonFunctionOptimizer(
                ctx=self._make_ctx(),
                no_gen_tests=True,
            )
            # instrument_tests_for_function is patched to return None, so
            # optimize() is expected to stop at the baseline step (test_files
            # is None) — NOTE(review): confirm against optimize(). The
            # generate_ai_tests guard runs before that point.
            await opt.optimize(fn_input)
            mock_gen.assert_not_called()