# codeflash/tests/test_is_numerical_code.py
#
# 1014 lines
# 27 KiB
# Python

"""Comprehensive unit tests for is_numerical_code function."""
from unittest.mock import patch
from codeflash.languages.python.static_analysis.code_extractor import is_numerical_code
# Patch has_numba in the module is_numerical_code is imported from: mock must
# target the namespace where the function looks the name up at call time.
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestBasicNumpyUsage:
    """Test basic numpy library detection (with numba available)."""

    def test_numpy_with_standard_alias(self):
        code = """
import numpy as np
def process_data(x):
    return np.sum(x)
"""
        assert is_numerical_code(code, "process_data") is True

    def test_numpy_without_alias(self):
        code = """
import numpy
def process_data(x):
    return numpy.array(x)
"""
        assert is_numerical_code(code, "process_data") is True

    def test_numpy_from_import(self):
        code = """
from numpy import array, zeros
def create_array():
    return array([1, 2, 3])
"""
        assert is_numerical_code(code, "create_array") is True

    def test_numpy_from_import_with_alias(self):
        code = """
from numpy import array as arr
def create_array():
    return arr([1, 2, 3])
"""
        assert is_numerical_code(code, "create_array") is True

    def test_numpy_custom_alias(self):
        code = """
import numpy as custom_name
def func(x):
    return custom_name.array(x)
"""
        assert is_numerical_code(code, "func") is True
# Patch has_numba in the module is_numerical_code is imported from: mock must
# target the namespace where the function looks the name up at call time.
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestNumpySubmodules:
    """Test numpy submodule imports (with numba available)."""

    def test_numpy_linalg_direct(self):
        code = """
import numpy.linalg
def func(x):
    return numpy.linalg.norm(x)
"""
        assert is_numerical_code(code, "func") is True

    def test_numpy_linalg_aliased(self):
        code = """
import numpy.linalg as la
def func(x):
    return la.norm(x)
"""
        assert is_numerical_code(code, "func") is True

    def test_numpy_random_aliased(self):
        code = """
import numpy.random as rng
def func():
    return rng.randint(0, 10)
"""
        assert is_numerical_code(code, "func") is True

    def test_from_numpy_import_submodule(self):
        code = """
from numpy import linalg
def func(x):
    return linalg.norm(x)
"""
        assert is_numerical_code(code, "func") is True

    def test_from_numpy_linalg_import_function(self):
        code = """
from numpy.linalg import norm
def func(x):
    return norm(x)
"""
        assert is_numerical_code(code, "func") is True
class TestTorchUsage:
    """Test PyTorch library detection."""

    def test_torch_basic(self):
        code = """
import torch
def train_model(model):
    return torch.nn.functional.relu(model)
"""
        assert is_numerical_code(code, "train_model") is True

    def test_torch_standard_alias(self):
        code = """
import torch as th
def func(x):
    return th.tensor(x)
"""
        assert is_numerical_code(code, "func") is True

    def test_torch_nn_alias(self):
        code = """
import torch.nn as nn
def func():
    return nn.Linear(10, 10)
"""
        assert is_numerical_code(code, "func") is True

    def test_torch_functional_alias(self):
        code = """
import torch.nn.functional as F
def func(x):
    return F.relu(x)
"""
        assert is_numerical_code(code, "func") is True

    def test_torch_from_import(self):
        code = """
from torch.nn.functional import relu
def func(x):
    return relu(x)
"""
        assert is_numerical_code(code, "func") is True

    def test_torch_from_import_aliased(self):
        code = """
from torch.nn.functional import softmax as sm
def func(x):
    return sm(x)
"""
        assert is_numerical_code(code, "func") is True

    def test_torch_utils_data(self):
        code = """
import torch.utils.data as data
def func():
    return data.DataLoader([])
"""
        assert is_numerical_code(code, "func") is True
class TestTensorflowUsage:
    """Test TensorFlow library detection."""

    def test_tensorflow_basic(self):
        code = """
import tensorflow
def func():
    return tensorflow.Variable(1)
"""
        assert is_numerical_code(code, "func") is True

    def test_tensorflow_standard_alias(self):
        code = """
import tensorflow as tf
def build_model():
    return tf.keras.Sequential()
"""
        assert is_numerical_code(code, "build_model") is True

    def test_tensorflow_keras_alias(self):
        code = """
import tensorflow.keras as keras
def func():
    return keras.Sequential()
"""
        assert is_numerical_code(code, "func") is True

    def test_tensorflow_keras_layers_alias(self):
        code = """
import tensorflow.keras.layers as layers
def func():
    return layers.Dense(10)
"""
        assert is_numerical_code(code, "func") is True

    def test_tensorflow_from_import(self):
        code = """
from tensorflow import keras
def func():
    return keras.Model()
"""
        assert is_numerical_code(code, "func") is True
class TestJaxUsage:
    """Test JAX library detection."""

    def test_jax_basic(self):
        code = """
import jax
def func(x):
    return jax.grad(x)
"""
        assert is_numerical_code(code, "func") is True

    def test_jax_numpy_alias(self):
        code = """
import jax.numpy as jnp
def func(x):
    return jnp.sum(x)
"""
        assert is_numerical_code(code, "func") is True

    def test_from_jax_import_numpy(self):
        code = """
from jax import numpy as jnp
def func(x):
    return jnp.array(x)
"""
        assert is_numerical_code(code, "func") is True

    def test_jax_from_import(self):
        code = """
from jax import grad, jit
def func(f):
    return grad(f)
"""
        assert is_numerical_code(code, "func") is True
class TestNumbaUsage:
    """Test Numba library detection."""

    def test_numba_jit_decorator(self):
        code = """
from numba import jit
@jit
def fast_func(x):
    return x * 2
"""
        assert is_numerical_code(code, "fast_func") is True

    def test_numba_cuda(self):
        code = """
import numba.cuda as cuda
def func():
    return cuda.device_array(10)
"""
        assert is_numerical_code(code, "func") is True

    def test_numba_basic(self):
        code = """
import numba
@numba.njit
def func(x):
    return x + 1
"""
        assert is_numerical_code(code, "func") is True
# Patch has_numba in the module is_numerical_code is imported from: mock must
# target the namespace where the function looks the name up at call time.
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestScipyUsage:
    """Test SciPy library detection (with numba available)."""

    def test_scipy_basic(self):
        code = """
import scipy
def func(x):
    return scipy.integrate.quad(x, 0, 1)
"""
        assert is_numerical_code(code, "func") is True

    def test_scipy_stats(self):
        code = """
from scipy import stats
def analyze(data):
    return stats.describe(data)
"""
        assert is_numerical_code(code, "analyze") is True

    def test_scipy_stats_from_import(self):
        code = """
from scipy.stats import norm
def func(x):
    return norm.pdf(x)
"""
        assert is_numerical_code(code, "func") is True

    def test_scipy_optimize_alias(self):
        code = """
import scipy.optimize as opt
def func(f, x0):
    return opt.minimize(f, x0)
"""
        assert is_numerical_code(code, "func") is True
# Patch has_numba in the module is_numerical_code is imported from: mock must
# target the namespace where the function looks the name up at call time.
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestMathUsage:
    """Test math standard library detection (with numba available)."""

    def test_math_basic(self):
        code = """
import math
def calculate(x):
    return math.sqrt(x)
"""
        assert is_numerical_code(code, "calculate") is True

    def test_math_from_import(self):
        code = """
from math import sqrt, sin, cos
def calculate(x):
    return sqrt(sin(x) ** 2 + cos(x) ** 2)
"""
        assert is_numerical_code(code, "calculate") is True

    def test_math_aliased(self):
        code = """
import math as m
def calculate(x):
    return m.pi * x
"""
        assert is_numerical_code(code, "calculate") is True
# Patch has_numba in the module is_numerical_code is imported from: mock must
# target the namespace where the function looks the name up at call time.
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestClassMethods:
    """Test detection in class methods, staticmethods, and classmethods (with numba available)."""

    def test_regular_method_with_numpy(self):
        code = """
import numpy as np
class DataProcessor:
    def process(self, data):
        return np.mean(data)
"""
        assert is_numerical_code(code, "DataProcessor.process") is True

    def test_regular_method_without_numerical(self):
        # Only the requested method's body matters, not its sibling's.
        code = """
import numpy as np
class DataProcessor:
    def process(self, data):
        return np.mean(data)
    def other(self, x):
        return x + 1
"""
        assert is_numerical_code(code, "DataProcessor.other") is False

    def test_staticmethod_with_numpy(self):
        code = """
import numpy as np
class Calculator:
    @staticmethod
    def compute(x):
        return np.dot(x, x)
"""
        assert is_numerical_code(code, "Calculator.compute") is True

    def test_classmethod_with_torch(self):
        code = """
import torch
class Model:
    @classmethod
    def from_pretrained(cls, path):
        return torch.load(path)
"""
        assert is_numerical_code(code, "Model.from_pretrained") is True

    def test_multiple_decorators(self):
        code = """
import functools
import numpy as np
class MyClass:
    @staticmethod
    @functools.lru_cache
    def cached_compute(x):
        return np.sum(x)
"""
        assert is_numerical_code(code, "MyClass.cached_compute") is True
class TestNoNumericalUsage:
    """Test that non-numerical code returns False."""

    def test_simple_function(self):
        code = """
def simple_func(x):
    return x + 1
"""
        assert is_numerical_code(code, "simple_func") is False

    def test_string_manipulation(self):
        code = """
def process_string(s):
    return s.upper().strip()
"""
        assert is_numerical_code(code, "process_string") is False

    def test_list_operations(self):
        code = """
def process_list(lst):
    return [x * 2 for x in lst]
"""
        assert is_numerical_code(code, "process_list") is False

    def test_with_non_numerical_imports(self):
        code = """
import os
import json
from pathlib import Path
def process_file(path):
    return Path(path).read_text()
"""
        assert is_numerical_code(code, "process_file") is False

    def test_class_method_without_numerical(self):
        code = """
class Helper:
    def format(self, data):
        return str(data)
"""
        assert is_numerical_code(code, "Helper.format") is False
class TestFalsePositivePrevention:
    """Test that false positives are avoided."""

    def test_function_named_numpy(self):
        # A local definition that merely shares a library's name is not an import.
        code = """
def numpy():
    return 1
def func():
    return numpy()
"""
        assert is_numerical_code(code, "func") is False

    def test_function_named_torch(self):
        code = """
def torch():
    return "fire"
def func():
    return torch()
"""
        assert is_numerical_code(code, "func") is False

    def test_variable_named_np(self):
        code = """
def func():
    np = 5
    return np + 1
"""
        assert is_numerical_code(code, "func") is False

    def test_class_named_math(self):
        code = """
class math:
    pass
def func():
    return math()
"""
        assert is_numerical_code(code, "func") is False
# Patch has_numba in the module is_numerical_code is imported from: mock must
# target the namespace where the function looks the name up at call time.
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestEdgeCases:
    """Test edge cases and special scenarios (with numba available)."""

    def test_nonexistent_function(self):
        code = """
import numpy as np
def process_data(x):
    return np.sum(x)
"""
        assert is_numerical_code(code, "nonexistent") is False

    def test_empty_function(self):
        # numpy is imported at module level but never used in the target body.
        code = """
import numpy as np
def empty_func():
    pass
"""
        assert is_numerical_code(code, "empty_func") is False

    def test_syntax_error_code(self):
        # Intentionally unparseable: the parameter list is never closed.
        code = """
def broken_func(
    return 1
"""
        assert is_numerical_code(code, "broken_func") is False

    def test_empty_code_string(self):
        assert is_numerical_code("", "func") is False

    def test_type_annotation_with_numpy(self):
        code = """
import numpy as np
def func(x: np.ndarray):
    return x + 1
"""
        assert is_numerical_code(code, "func") is True

    def test_default_argument_with_numpy(self):
        code = """
import numpy as np
def func(dtype=np.float32):
    return dtype
"""
        assert is_numerical_code(code, "func") is True

    def test_numpy_in_docstring_only(self):
        code = """
def func(x):
    '''Uses numpy internally.'''
    return x + 1
"""
        assert is_numerical_code(code, "func") is False

    def test_async_function_with_numpy(self):
        code = """
import numpy as np
async def async_process(x):
    return np.sum(x)
"""
        # NOTE(review): expected False presumably because the function lookup
        # only matches regular `def` nodes, so async bodies are never analyzed.
        # This pins a current limitation rather than a desired outcome.
        assert is_numerical_code(code, "async_process") is False
# Patch has_numba in the module is_numerical_code is imported from: mock must
# target the namespace where the function looks the name up at call time.
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestStarImports:
    """Test handling of star imports (with numba available).

    Note: Star imports are difficult to track precisely since we'd need to
    resolve what names are actually imported from the module. The current
    implementation has limited support for star imports.
    """

    def test_star_import_with_module_reference(self):
        # Star imports are detected when the module name is still referenced
        code = """
from numpy import *
import numpy
def func(x):
    return numpy.array(x)
"""
        assert is_numerical_code(code, "func") is True

    def test_star_import_bare_name_not_detected(self):
        # Bare names from star imports are not tracked (limitation)
        code = """
from numpy import *
def func(x):
    return array(x)
"""
        # This is a known limitation - star import names aren't resolved
        assert is_numerical_code(code, "func") is False

    def test_star_import_math_bare_name_not_detected(self):
        # Same limitation applies to math
        code = """
from math import *
def func(x):
    return sqrt(x)
"""
        # Known limitation
        assert is_numerical_code(code, "func") is False
# Patch has_numba in the module is_numerical_code is imported from: mock must
# target the namespace where the function looks the name up at call time.
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestNestedUsage:
    """Test nested numerical library usage patterns (with numba available)."""

    def test_numpy_in_lambda(self):
        code = """
import numpy as np
def func():
    f = lambda x: np.sum(x)
    return f
"""
        assert is_numerical_code(code, "func") is True

    def test_numpy_in_list_comprehension(self):
        code = """
import numpy as np
def func(arrays):
    return [np.mean(arr) for arr in arrays]
"""
        assert is_numerical_code(code, "func") is True

    def test_numpy_in_conditional(self):
        code = """
import numpy as np
def func(x, use_numpy=True):
    if use_numpy:
        return np.sum(x)
    return sum(x)
"""
        assert is_numerical_code(code, "func") is True

    def test_numpy_in_try_except(self):
        code = """
import numpy as np
def func(x):
    try:
        return np.sum(x)
    except Exception:
        return 0
"""
        assert is_numerical_code(code, "func") is True
# Patch has_numba in the module is_numerical_code is imported from: mock must
# target the namespace where the function looks the name up at call time.
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestMultipleLibraries:
    """Test code using multiple numerical libraries (with numba available)."""

    def test_numpy_and_torch(self):
        code = """
import numpy as np
import torch
def func(x):
    arr = np.array(x)
    return torch.from_numpy(arr)
"""
        assert is_numerical_code(code, "func") is True

    def test_scipy_and_numpy(self):
        code = """
import numpy as np
from scipy import stats
def analyze(data):
    arr = np.array(data)
    return stats.describe(arr)
"""
        assert is_numerical_code(code, "analyze") is True
# Patch has_numba in the module is_numerical_code is imported from: mock must
# target the namespace where the function looks the name up at call time.
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestQualifiedNames:
    """Test various qualified name patterns (with numba available)."""

    def test_simple_function_name(self):
        code = """
import numpy as np
def my_func():
    return np.array([1])
"""
        assert is_numerical_code(code, "my_func") is True

    def test_class_dot_method(self):
        code = """
import numpy as np
class MyClass:
    def my_method(self):
        return np.sum([1, 2])
"""
        assert is_numerical_code(code, "MyClass.my_method") is True

    def test_invalid_qualified_name_too_deep(self):
        code = """
import numpy as np
class Outer:
    class Inner:
        def method(self):
            return np.sum([1])
"""
        # Nested classes are not supported
        assert is_numerical_code(code, "Outer.Inner.method") is False

    def test_method_in_wrong_class(self):
        # Same method name in two classes: the class prefix must disambiguate.
        code = """
import numpy as np
class ClassA:
    def method(self):
        return np.sum([1])
class ClassB:
    def method(self):
        return 1
"""
        assert is_numerical_code(code, "ClassA.method") is True
        assert is_numerical_code(code, "ClassB.method") is False
# Patch has_numba in the module is_numerical_code is imported from: mock must
# target the namespace where the function looks the name up at call time.
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestEmptyFunctionName:
    """Test behavior when function_name is empty/None.

    When function_name is not provided, the function should just check for the
    presence of numerical imports without looking at a specific function body.
    """

    def test_empty_string_with_numpy_import(self):
        """Empty function_name with numpy import should return True."""
        code = """
import numpy as np
def some_func():
    pass
"""
        assert is_numerical_code(code, "") is True

    def test_none_with_numpy_import(self):
        """None function_name with numpy import should return True."""
        code = """
import numpy as np
def some_func():
    pass
"""
        assert is_numerical_code(code, None) is True

    def test_empty_string_with_torch_import(self):
        """Empty function_name with torch import should return True."""
        code = """
import torch
def some_func():
    pass
"""
        assert is_numerical_code(code, "") is True

    def test_empty_string_with_multiple_numerical_imports(self):
        """Empty function_name with multiple numerical imports should return True."""
        code = """
import numpy as np
import torch
from scipy import stats
def some_func():
    pass
"""
        assert is_numerical_code(code, "") is True

    def test_empty_string_without_numerical_imports(self):
        """Empty function_name without numerical imports should return False."""
        code = """
import os
import json
from pathlib import Path
def some_func():
    pass
"""
        assert is_numerical_code(code, "") is False

    def test_none_without_numerical_imports(self):
        """None function_name without numerical imports should return False."""
        code = """
import os
def some_func():
    pass
"""
        assert is_numerical_code(code, None) is False

    def test_empty_string_with_jax_import(self):
        """Empty function_name with jax import should return True."""
        code = """
import jax
import jax.numpy as jnp
"""
        assert is_numerical_code(code, "") is True

    def test_empty_string_with_tensorflow_import(self):
        """Empty function_name with tensorflow import should return True."""
        code = """
import tensorflow as tf
"""
        assert is_numerical_code(code, "") is True

    def test_empty_string_with_math_import(self):
        """Empty function_name with math import should return True (numba available)."""
        code = """
import math
def calculate(x):
    return math.sqrt(x)
"""
        assert is_numerical_code(code, "") is True

    def test_empty_string_with_scipy_submodule(self):
        """Empty function_name with scipy submodule import should return True."""
        code = """
from scipy.stats import norm
"""
        assert is_numerical_code(code, "") is True

    def test_empty_string_with_numba_import(self):
        """Empty function_name with numba import should return True."""
        code = """
from numba import jit
"""
        assert is_numerical_code(code, "") is True

    def test_empty_code_with_empty_function_name(self):
        """Empty code with empty function_name should return False."""
        assert is_numerical_code("", "") is False

    def test_syntax_error_with_empty_function_name(self):
        """Syntax error code with empty function_name should return False."""
        # Intentionally unparseable: the parameter list is never closed.
        code = """
def broken(
import numpy
"""
        assert is_numerical_code(code, "") is False
# Patch has_numba in the module is_numerical_code is imported from: mock must
# target the namespace where the function looks the name up at call time.
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", False)
class TestEmptyFunctionNameWithoutNumba:
    """Test empty function_name behavior when numba is NOT available.

    When numba is not installed, code using only math/numpy/scipy should return False,
    since numba is required to optimize such code. Code using torch/jax/tensorflow/numba
    should still return True.
    """

    def test_empty_string_numpy_returns_false_without_numba(self):
        """Empty function_name with numpy should return False when numba unavailable."""
        code = """
import numpy as np
def some_func():
    pass
"""
        assert is_numerical_code(code, "") is False

    def test_empty_string_math_returns_false_without_numba(self):
        """Empty function_name with math should return False when numba unavailable."""
        code = """
import math
"""
        assert is_numerical_code(code, "") is False

    def test_empty_string_scipy_returns_false_without_numba(self):
        """Empty function_name with scipy should return False when numba unavailable."""
        code = """
from scipy import stats
"""
        assert is_numerical_code(code, "") is False

    def test_empty_string_torch_returns_true_without_numba(self):
        """Empty function_name with torch should return True even without numba."""
        code = """
import torch
"""
        assert is_numerical_code(code, "") is True

    def test_empty_string_jax_returns_true_without_numba(self):
        """Empty function_name with jax should return True even without numba."""
        code = """
import jax
"""
        assert is_numerical_code(code, "") is True

    def test_empty_string_tensorflow_returns_true_without_numba(self):
        """Empty function_name with tensorflow should return True even without numba."""
        code = """
import tensorflow as tf
"""
        assert is_numerical_code(code, "") is True

    def test_empty_string_numba_import_returns_true_without_numba(self):
        """Empty function_name with numba import should return True."""
        code = """
from numba import jit
"""
        assert is_numerical_code(code, "") is True

    def test_empty_string_numpy_and_torch_returns_true_without_numba(self):
        """Empty function_name with numpy+torch should return True (torch doesn't need numba)."""
        code = """
import numpy as np
import torch
"""
        # Returns True because torch is in modules_used and doesn't require numba
        assert is_numerical_code(code, "") is True

    def test_empty_string_math_and_scipy_returns_false_without_numba(self):
        """Empty function_name with only math+scipy should return False without numba."""
        code = """
import math
from scipy import stats
"""
        # Both math and scipy are in NUMBA_REQUIRED_MODULES
        assert is_numerical_code(code, "") is False
# Patch has_numba in the module is_numerical_code is imported from: mock must
# target the namespace where the function looks the name up at call time.
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", False)
class TestNumbaNotAvailable:
    """Test behavior when numba is NOT available in the environment.

    When numba is not installed, code using only math/numpy/scipy should return False,
    since numba is required to optimize such code. Code using torch/jax/tensorflow/numba
    should still return True as these libraries don't require numba for optimization.
    """

    def test_numpy_returns_false_without_numba(self):
        """Numpy usage should return False when numba is not available."""
        code = """
import numpy as np
def process_data(x):
    return np.sum(x)
"""
        assert is_numerical_code(code, "process_data") is False

    def test_scipy_returns_false_without_numba(self):
        """Scipy usage should return False when numba is not available."""
        code = """
from scipy import stats
def analyze(data):
    return stats.describe(data)
"""
        assert is_numerical_code(code, "analyze") is False

    def test_math_returns_false_without_numba(self):
        """Math usage should return False when numba is not available."""
        code = """
import math
def calculate(x):
    return math.sqrt(x)
"""
        assert is_numerical_code(code, "calculate") is False

    def test_torch_returns_true_without_numba(self):
        """Torch usage should return True even when numba is not available."""
        code = """
import torch
def train_model(model):
    return torch.nn.functional.relu(model)
"""
        assert is_numerical_code(code, "train_model") is True

    def test_jax_returns_true_without_numba(self):
        """JAX usage should return True even when numba is not available."""
        code = """
import jax
def func(x):
    return jax.grad(x)
"""
        assert is_numerical_code(code, "func") is True

    def test_tensorflow_returns_true_without_numba(self):
        """TensorFlow usage should return True even when numba is not available."""
        code = """
import tensorflow as tf
def build_model():
    return tf.keras.Sequential()
"""
        assert is_numerical_code(code, "build_model") is True

    def test_numba_import_returns_true_without_numba(self):
        """Code that imports numba should return True (numba is in modules_used)."""
        code = """
from numba import jit
@jit
def fast_func(x):
    return x * 2
"""
        assert is_numerical_code(code, "fast_func") is True

    def test_numpy_and_torch_returns_true_without_numba(self):
        """Mixed numpy+torch usage should return True since torch doesn't require numba."""
        code = """
import numpy as np
import torch
def func(x):
    arr = np.array(x)
    return torch.from_numpy(arr)
"""
        # Returns True because torch is in modules_used and torch doesn't require numba
        assert is_numerical_code(code, "func") is True

    def test_numpy_and_jax_returns_true_without_numba(self):
        """Mixed numpy+jax usage should return True since jax doesn't require numba."""
        code = """
import numpy as np
import jax.numpy as jnp
def func(x):
    arr = np.array(x)
    return jnp.sum(arr)
"""
        # Returns True because jax is in modules_used and jax doesn't require numba
        assert is_numerical_code(code, "func") is True

    def test_scipy_and_tensorflow_returns_true_without_numba(self):
        """Mixed scipy+tensorflow usage should return True since tensorflow doesn't require numba."""
        code = """
from scipy import stats
import tensorflow as tf
def analyze_and_build(data):
    result = stats.describe(data)
    return tf.keras.Sequential()
"""
        # Returns True because tensorflow is in modules_used and doesn't require numba
        assert is_numerical_code(code, "analyze_and_build") is True

    def test_numpy_submodule_returns_false_without_numba(self):
        """Numpy submodule usage should return False when numba is not available."""
        code = """
import numpy.linalg as la
def func(x):
    return la.norm(x)
"""
        assert is_numerical_code(code, "func") is False

    def test_math_from_import_returns_false_without_numba(self):
        """Math from import should return False when numba is not available."""
        code = """
from math import sqrt, sin, cos
def calculate(x):
    return sqrt(sin(x) ** 2 + cos(x) ** 2)
"""
        assert is_numerical_code(code, "calculate") is False