mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
* fix: resolve all ruff lint errors across repo Auto-fixed 31 errors (unused imports, formatting, simplifications). Manually fixed 14 remaining: - EXE001: removed shebangs from non-executable bench scripts - C417: replaced map(lambda) with generator expression - C901/PLR0915: extracted _write_and_instrument_tests from generate_ai_tests - C901/PLR0912: extracted _parse_toml_addopts and _ini_section_name from modify_addopts - RUF001/RUF002: replaced ambiguous Unicode chars (en dash, multiplication sign) - FBT002: made boolean params keyword-only in report functions - E402: moved `import re` to top of file in security reports * fix: resolve pre-existing mypy errors across packages - _testgen.py: annotate `generated` as `str` to avoid no-any-return - _test_runner.py: use str() for TimeoutExpired stdout/stderr (bytes|str), remove unused type: ignore on proc.kill() - _candidate_eval.py: annotate `speedup` as `float` to avoid no-any-return from lazy-loaded performance_gain
198 lines
6.7 KiB
Python
"""Tests for per-function optimization utilities (stage 23b)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import textwrap
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from codeflash_python._model import FunctionParent
|
|
from codeflash_python.pipeline._context import OptimizationContext
|
|
from codeflash_python.pipeline._function_optimizer import (
|
|
NUMBA_REQUIRED_MODULES,
|
|
NUMERICAL_MODULES,
|
|
PythonFunctionOptimizer,
|
|
is_numerical_code,
|
|
resolve_function_ast,
|
|
)
|
|
|
|
|
|
class TestIsNumericalCode:
    """Behavioural tests for the is_numerical_code heuristic."""

    def test_torch_import(self) -> None:
        """A module that imports torch counts as numerical code."""
        snippet = textwrap.dedent("""\
            import torch

            def compute():
                return torch.tensor([1, 2])
            """)
        assert is_numerical_code(snippet) is True

    def test_no_numerical_imports(self) -> None:
        """A module with only non-numerical imports is not flagged."""
        snippet = textwrap.dedent("""\
            import os

            def compute():
                return os.getcwd()
            """)
        assert is_numerical_code(snippet) is False

    def test_with_function_name_torch(self) -> None:
        """Passing function_name restricts detection to that one function."""
        snippet = textwrap.dedent("""\
            import torch

            def uses_torch():
                return torch.zeros(5)

            def plain():
                return 42
            """)
        assert is_numerical_code(snippet, function_name="uses_torch") is True
        assert is_numerical_code(snippet, function_name="plain") is False

    def test_numba_required_modules_without_numba(self) -> None:
        """numpy/scipy/math count as numerical only when numba is installed."""
        from codeflash_python.pipeline._function_optimizer import _HAS_NUMBA

        snippet = "import numpy\ndef f(): return numpy.array([1])\n"
        # The expected verdict tracks numba availability exactly.
        assert is_numerical_code(snippet) is bool(_HAS_NUMBA)

    def test_syntax_error(self) -> None:
        """Unparseable source is treated as non-numerical, not an error."""
        assert is_numerical_code("def foo(:") is False

    def test_nonexistent_function(self) -> None:
        """Naming a function absent from the source yields False."""
        snippet = "import torch\ndef foo(): return 1\n"
        assert is_numerical_code(snippet, function_name="nonexistent") is False

    def test_method_in_class(self) -> None:
        """A dotted Class.method name resolves into the class body."""
        snippet = textwrap.dedent("""\
            import torch

            class Model:
                def forward(self):
                    return torch.zeros(3)
            """)
        assert is_numerical_code(snippet, function_name="Model.forward") is True

    def test_constants(self) -> None:
        """The exported module-name constants contain the expected entries."""
        for module in ("numpy", "torch"):
            assert module in NUMERICAL_MODULES
        assert "math" in NUMBA_REQUIRED_MODULES
|
|
|
|
|
|
class TestResolveFunctionAst:
    """Tests for resolve_function_ast."""

    def test_top_level(self) -> None:
        """A module-level function resolves to its AST node."""
        node = resolve_function_ast("def foo():\n    return 1\n", "foo", [])
        assert node is not None and node.name == "foo"

    def test_method(self) -> None:
        """A method is found by walking the supplied parent chain."""
        snippet = textwrap.dedent("""\
            class MyClass:
                def run(self):
                    return 42
            """)
        chain = [FunctionParent(name="MyClass", type="ClassDef")]
        node = resolve_function_ast(snippet, "run", chain)
        assert node is not None and node.name == "run"

    def test_missing_returns_none(self) -> None:
        """An unknown function name yields None rather than raising."""
        assert resolve_function_ast("x = 1\n", "nope", []) is None
|
|
|
|
|
|
class TestNoGenTests:
    """Tests for --no-gen-tests flag wiring.

    Fixes over the previous revision: removed the unused ``_cls`` local
    (ruff F841) and extracted the thrice-duplicated all-mock context
    construction into :meth:`_make_ctx`.
    """

    @staticmethod
    def _make_ctx() -> OptimizationContext:
        """Build an OptimizationContext whose collaborators are all mocks."""
        return OptimizationContext(
            plugin=MagicMock(),
            project_root=MagicMock(),
            test_cfg=MagicMock(),
            ai_client=MagicMock(),
        )

    def test_field_defaults_to_false(self) -> None:
        """no_gen_tests defaults to False."""
        opt = PythonFunctionOptimizer(ctx=self._make_ctx())
        assert opt.no_gen_tests is False

    def test_field_accepts_true(self) -> None:
        """no_gen_tests=True is stored on the instance."""
        opt = PythonFunctionOptimizer(ctx=self._make_ctx(), no_gen_tests=True)
        assert opt.no_gen_tests is True

    async def test_skips_ai_test_generation(self) -> None:
        """When no_gen_tests=True, generate_ai_tests is never called."""
        _mod = "codeflash_python.pipeline._function_optimizer"

        # Minimal fake function metadata: just the attributes optimize()
        # reads before (and while) deciding whether to generate tests.
        fn_input = MagicMock()
        fn_input.function.qualified_name = "mod.func"
        fn_input.function.parents = []
        fn_input.function.function_name = "func"
        fn_input.function.is_async = False

        with (
            patch(
                "codeflash_python.pipeline._test_orchestrator.generate_ai_tests"
            ) as mock_gen,
            patch(
                "codeflash_python.pipeline._test_orchestrator.instrument_tests_for_function",
                return_value=None,
            ),
            patch(
                "codeflash_python.pipeline._test_orchestrator.generate_concolic_tests",
                return_value=({}, "", None),
            ),
            patch(
                "codeflash_python.context.pipeline.get_code_optimization_context",
                return_value=MagicMock(),
            ),
            patch(
                "codeflash_python.pipeline._module_prep.resolve_python_function_ast",
                return_value=None,
            ),
            patch(f"{_mod}.is_numerical_code", return_value=False),
            patch(
                "codeflash_python.verification._baseline.establish_original_code_baseline"
            ),
        ):
            opt = PythonFunctionOptimizer(
                ctx=self._make_ctx(),
                no_gen_tests=True,
            )
            # optimize() will exit early at the baseline step since
            # test_files is None, but the generate_ai_tests guard
            # is checked before that.
            await opt.optimize(fn_input)

        mock_gen.assert_not_called()