2025-04-01 20:51:27 +00:00
|
|
|
import os
|
|
|
|
|
from pathlib import Path
|
2025-04-01 23:18:21 +00:00
|
|
|
from tempfile import TemporaryDirectory
|
2025-04-01 20:51:27 +00:00
|
|
|
|
|
|
|
|
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
2026-03-02 20:25:50 +00:00
|
|
|
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
2026-03-04 06:52:32 +00:00
|
|
|
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
|
|
|
|
|
from codeflash.models.models import CodeOptimizationContext
|
2025-04-01 20:51:27 +00:00
|
|
|
from codeflash.verification.verification_utils import TestConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_add_decorator_imports_helper_in_class():
|
|
|
|
|
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_classmethod.py").resolve()
|
|
|
|
|
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
|
|
|
|
|
project_root_path = (Path(__file__).parent / "..").resolve()
|
|
|
|
|
run_cwd = Path(__file__).parent.parent.resolve()
|
|
|
|
|
test_config = TestConfig(
|
|
|
|
|
tests_root=tests_root,
|
|
|
|
|
tests_project_rootdir=project_root_path,
|
|
|
|
|
project_root_path=project_root_path,
|
|
|
|
|
test_framework="pytest",
|
|
|
|
|
pytest_cmd="pytest",
|
|
|
|
|
)
|
|
|
|
|
func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path)
|
2026-03-02 20:25:50 +00:00
|
|
|
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
2025-04-01 20:51:27 +00:00
|
|
|
os.chdir(run_cwd)
|
2026-01-29 09:39:48 +00:00
|
|
|
# func_optimizer = pass
|
2025-04-01 20:51:27 +00:00
|
|
|
try:
|
|
|
|
|
ctx_result = func_optimizer.get_code_optimization_context()
|
|
|
|
|
code_context: CodeOptimizationContext = ctx_result.unwrap()
|
|
|
|
|
original_helper_code: dict[Path, str] = {}
|
|
|
|
|
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
|
|
|
|
for helper_function_path in helper_function_paths:
|
|
|
|
|
with helper_function_path.open(encoding="utf8") as f:
|
|
|
|
|
helper_code = f.read()
|
|
|
|
|
original_helper_code[helper_function_path] = helper_code
|
2026-01-29 09:39:48 +00:00
|
|
|
line_profiler_output_file = add_decorator_imports(func_optimizer.function_to_optimize, code_context)
|
2025-04-01 20:51:27 +00:00
|
|
|
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
|
2025-09-28 07:43:33 +00:00
|
|
|
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')
|
2025-04-01 20:51:27 +00:00
|
|
|
|
|
|
|
|
from code_to_optimize.bubble_sort_in_class import BubbleSortClass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@codeflash_line_profile
|
|
|
|
|
def sort_classmethod(x):
|
|
|
|
|
y = BubbleSortClass()
|
|
|
|
|
return y.sorter(x)
|
|
|
|
|
"""
|
|
|
|
|
expected_code_helper = """from line_profiler import profile as codeflash_line_profile
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def hi():
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BubbleSortClass:
|
2026-01-01 17:34:40 +00:00
|
|
|
@codeflash_line_profile
|
2025-04-01 20:51:27 +00:00
|
|
|
def __init__(self):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@codeflash_line_profile
|
|
|
|
|
def sorter(self, arr):
|
|
|
|
|
n = len(arr)
|
|
|
|
|
for i in range(n):
|
|
|
|
|
for j in range(0, n - i - 1):
|
|
|
|
|
if arr[j] > arr[j + 1]:
|
|
|
|
|
arr[j], arr[j + 1] = arr[j + 1], arr[j]
|
|
|
|
|
return arr
|
|
|
|
|
|
|
|
|
|
def helper(self, arr, j):
|
|
|
|
|
return arr[j] > arr[j + 1]
|
|
|
|
|
"""
|
|
|
|
|
assert code_path.read_text("utf-8") == expected_code_main
|
|
|
|
|
assert code_context.helper_functions[0].file_path.read_text("utf-8") == expected_code_helper
|
|
|
|
|
finally:
|
|
|
|
|
func_optimizer.write_code_and_helpers(
|
2026-01-29 09:39:48 +00:00
|
|
|
func_optimizer.function_to_optimize_source_code,
|
|
|
|
|
original_helper_code,
|
|
|
|
|
func_optimizer.function_to_optimize.file_path,
|
2025-04-01 20:51:27 +00:00
|
|
|
)
|
|
|
|
|
|
2026-01-29 09:39:48 +00:00
|
|
|
|
2025-04-01 20:51:27 +00:00
|
|
|
def test_add_decorator_imports_helper_in_nested_class():
|
2026-01-29 09:39:48 +00:00
|
|
|
# Need to invert the assert once the helper detection is fixed
|
2025-04-01 20:51:27 +00:00
|
|
|
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_nested_classmethod.py").resolve()
|
|
|
|
|
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
|
|
|
|
|
project_root_path = (Path(__file__).parent / "..").resolve()
|
|
|
|
|
run_cwd = Path(__file__).parent.parent.resolve()
|
|
|
|
|
test_config = TestConfig(
|
|
|
|
|
tests_root=tests_root,
|
|
|
|
|
tests_project_rootdir=project_root_path,
|
|
|
|
|
project_root_path=project_root_path,
|
|
|
|
|
test_framework="pytest",
|
|
|
|
|
pytest_cmd="pytest",
|
|
|
|
|
)
|
|
|
|
|
func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path)
|
2026-03-02 20:25:50 +00:00
|
|
|
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
2025-04-01 20:51:27 +00:00
|
|
|
os.chdir(run_cwd)
|
2026-01-29 09:39:48 +00:00
|
|
|
# func_optimizer = pass
|
2025-04-01 20:51:27 +00:00
|
|
|
try:
|
|
|
|
|
ctx_result = func_optimizer.get_code_optimization_context()
|
|
|
|
|
code_context: CodeOptimizationContext = ctx_result.unwrap()
|
|
|
|
|
original_helper_code: dict[Path, str] = {}
|
|
|
|
|
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
|
|
|
|
for helper_function_path in helper_function_paths:
|
|
|
|
|
with helper_function_path.open(encoding="utf8") as f:
|
|
|
|
|
helper_code = f.read()
|
|
|
|
|
original_helper_code[helper_function_path] = helper_code
|
2026-01-29 09:39:48 +00:00
|
|
|
line_profiler_output_file = add_decorator_imports(func_optimizer.function_to_optimize, code_context)
|
2025-04-01 20:51:27 +00:00
|
|
|
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
|
2025-09-28 07:43:33 +00:00
|
|
|
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')
|
2025-04-01 20:51:27 +00:00
|
|
|
|
|
|
|
|
from code_to_optimize.bubble_sort_in_nested_class import WrapperClass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@codeflash_line_profile
|
|
|
|
|
def sort_classmethod(x):
|
|
|
|
|
y = WrapperClass.BubbleSortClass()
|
|
|
|
|
return y.sorter(x)
|
|
|
|
|
"""
|
|
|
|
|
assert code_path.read_text("utf-8") == expected_code_main
|
2026-01-01 17:34:40 +00:00
|
|
|
# WrapperClass.__init__ is now detected as a helper since WrapperClass.BubbleSortClass() instantiates it
|
|
|
|
|
assert len(code_context.helper_functions) == 1
|
|
|
|
|
assert code_context.helper_functions[0].qualified_name == "WrapperClass.__init__"
|
2025-04-01 20:51:27 +00:00
|
|
|
finally:
|
|
|
|
|
func_optimizer.write_code_and_helpers(
|
2026-01-29 09:39:48 +00:00
|
|
|
func_optimizer.function_to_optimize_source_code,
|
|
|
|
|
original_helper_code,
|
|
|
|
|
func_optimizer.function_to_optimize.file_path,
|
2025-04-01 20:51:27 +00:00
|
|
|
)
|
2025-04-01 23:18:21 +00:00
|
|
|
|
2026-01-29 09:39:48 +00:00
|
|
|
|
2025-04-01 23:18:21 +00:00
|
|
|
def test_add_decorator_imports_nodeps():
|
|
|
|
|
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
|
|
|
|
|
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
|
|
|
|
|
project_root_path = (Path(__file__).parent / "..").resolve()
|
|
|
|
|
run_cwd = Path(__file__).parent.parent.resolve()
|
|
|
|
|
test_config = TestConfig(
|
|
|
|
|
tests_root=tests_root,
|
|
|
|
|
tests_project_rootdir=project_root_path,
|
|
|
|
|
project_root_path=project_root_path,
|
|
|
|
|
test_framework="pytest",
|
|
|
|
|
pytest_cmd="pytest",
|
|
|
|
|
)
|
|
|
|
|
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
|
2026-03-02 20:25:50 +00:00
|
|
|
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
2025-04-01 23:18:21 +00:00
|
|
|
os.chdir(run_cwd)
|
2026-01-29 09:39:48 +00:00
|
|
|
# func_optimizer = pass
|
2025-04-01 23:18:21 +00:00
|
|
|
try:
|
|
|
|
|
ctx_result = func_optimizer.get_code_optimization_context()
|
|
|
|
|
code_context: CodeOptimizationContext = ctx_result.unwrap()
|
|
|
|
|
original_helper_code: dict[Path, str] = {}
|
|
|
|
|
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
|
|
|
|
for helper_function_path in helper_function_paths:
|
|
|
|
|
with helper_function_path.open(encoding="utf8") as f:
|
|
|
|
|
helper_code = f.read()
|
|
|
|
|
original_helper_code[helper_function_path] = helper_code
|
2026-01-29 09:39:48 +00:00
|
|
|
line_profiler_output_file = add_decorator_imports(func_optimizer.function_to_optimize, code_context)
|
2025-04-01 23:18:21 +00:00
|
|
|
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
|
2025-09-28 07:43:33 +00:00
|
|
|
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')
|
2025-04-01 23:18:21 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@codeflash_line_profile
|
|
|
|
|
def sorter(arr):
|
|
|
|
|
print("codeflash stdout: Sorting list")
|
|
|
|
|
for i in range(len(arr)):
|
|
|
|
|
for j in range(len(arr) - 1):
|
|
|
|
|
if arr[j] > arr[j + 1]:
|
|
|
|
|
temp = arr[j]
|
|
|
|
|
arr[j] = arr[j + 1]
|
|
|
|
|
arr[j + 1] = temp
|
|
|
|
|
print(f"result: {{arr}}")
|
|
|
|
|
return arr
|
|
|
|
|
"""
|
|
|
|
|
assert code_path.read_text("utf-8") == expected_code_main
|
|
|
|
|
finally:
|
|
|
|
|
func_optimizer.write_code_and_helpers(
|
2026-01-29 09:39:48 +00:00
|
|
|
func_optimizer.function_to_optimize_source_code,
|
|
|
|
|
original_helper_code,
|
|
|
|
|
func_optimizer.function_to_optimize.file_path,
|
2025-04-01 23:18:21 +00:00
|
|
|
)
|
|
|
|
|
|
2026-01-29 09:39:48 +00:00
|
|
|
|
2025-04-01 23:18:21 +00:00
|
|
|
def test_add_decorator_imports_helper_outside():
|
|
|
|
|
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_deps.py").resolve()
|
|
|
|
|
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
|
|
|
|
|
project_root_path = (Path(__file__).parent / "..").resolve()
|
|
|
|
|
run_cwd = Path(__file__).parent.parent.resolve()
|
|
|
|
|
test_config = TestConfig(
|
|
|
|
|
tests_root=tests_root,
|
|
|
|
|
tests_project_rootdir=project_root_path,
|
|
|
|
|
project_root_path=project_root_path,
|
|
|
|
|
test_framework="pytest",
|
|
|
|
|
pytest_cmd="pytest",
|
|
|
|
|
)
|
|
|
|
|
func = FunctionToOptimize(function_name="sorter_deps", parents=[], file_path=code_path)
|
2026-03-02 20:25:50 +00:00
|
|
|
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
2025-04-01 23:18:21 +00:00
|
|
|
os.chdir(run_cwd)
|
2026-01-29 09:39:48 +00:00
|
|
|
# func_optimizer = pass
|
2025-04-01 23:18:21 +00:00
|
|
|
try:
|
|
|
|
|
ctx_result = func_optimizer.get_code_optimization_context()
|
|
|
|
|
code_context: CodeOptimizationContext = ctx_result.unwrap()
|
|
|
|
|
original_helper_code: dict[Path, str] = {}
|
|
|
|
|
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
|
|
|
|
for helper_function_path in helper_function_paths:
|
|
|
|
|
with helper_function_path.open(encoding="utf8") as f:
|
|
|
|
|
helper_code = f.read()
|
|
|
|
|
original_helper_code[helper_function_path] = helper_code
|
2026-01-29 09:39:48 +00:00
|
|
|
line_profiler_output_file = add_decorator_imports(func_optimizer.function_to_optimize, code_context)
|
2025-04-01 23:18:21 +00:00
|
|
|
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
|
2025-09-28 07:43:33 +00:00
|
|
|
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')
|
2025-04-01 23:18:21 +00:00
|
|
|
|
|
|
|
|
from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
|
|
|
|
|
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@codeflash_line_profile
|
|
|
|
|
def sorter_deps(arr):
|
|
|
|
|
for i in range(len(arr)):
|
|
|
|
|
for j in range(len(arr) - 1):
|
|
|
|
|
if dep1_comparer(arr, j):
|
|
|
|
|
dep2_swap(arr, j)
|
|
|
|
|
return arr
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
expected_code_helper1 = """from line_profiler import profile as codeflash_line_profile
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@codeflash_line_profile
|
|
|
|
|
def dep1_comparer(arr, j: int) -> bool:
|
|
|
|
|
return arr[j] > arr[j + 1]
|
|
|
|
|
"""
|
2026-01-29 09:39:48 +00:00
|
|
|
expected_code_helper2 = """from line_profiler import profile as codeflash_line_profile
|
2025-04-01 23:18:21 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@codeflash_line_profile
|
|
|
|
|
def dep2_swap(arr, j):
|
|
|
|
|
temp = arr[j]
|
|
|
|
|
arr[j] = arr[j + 1]
|
|
|
|
|
arr[j + 1] = temp
|
|
|
|
|
"""
|
|
|
|
|
assert code_path.read_text("utf-8") == expected_code_main
|
|
|
|
|
assert code_context.helper_functions[0].file_path.read_text("utf-8") == expected_code_helper1
|
|
|
|
|
assert code_context.helper_functions[1].file_path.read_text("utf-8") == expected_code_helper2
|
|
|
|
|
finally:
|
|
|
|
|
func_optimizer.write_code_and_helpers(
|
2026-01-29 09:39:48 +00:00
|
|
|
func_optimizer.function_to_optimize_source_code,
|
|
|
|
|
original_helper_code,
|
|
|
|
|
func_optimizer.function_to_optimize.file_path,
|
2025-04-01 23:18:21 +00:00
|
|
|
)
|
|
|
|
|
|
2026-01-29 09:39:48 +00:00
|
|
|
|
2025-04-01 23:18:21 +00:00
|
|
|
def test_add_decorator_imports_helper_in_dunder_class():
|
|
|
|
|
code_str = """def sorter(arr):
|
|
|
|
|
ans = helper(arr)
|
|
|
|
|
return ans
|
|
|
|
|
class helper:
|
|
|
|
|
def __init__(self, arr):
|
|
|
|
|
return arr.sort()"""
|
|
|
|
|
code_path = TemporaryDirectory()
|
|
|
|
|
code_write_path = Path(code_path.name) / "dunder_class.py"
|
2026-01-29 09:39:48 +00:00
|
|
|
code_write_path.write_text(code_str, "utf-8")
|
2025-04-01 23:18:21 +00:00
|
|
|
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
|
|
|
|
|
project_root_path = Path(code_path.name)
|
|
|
|
|
run_cwd = Path(__file__).parent.parent.resolve()
|
|
|
|
|
test_config = TestConfig(
|
|
|
|
|
tests_root=tests_root,
|
|
|
|
|
tests_project_rootdir=project_root_path,
|
|
|
|
|
project_root_path=project_root_path,
|
|
|
|
|
test_framework="pytest",
|
|
|
|
|
pytest_cmd="pytest",
|
|
|
|
|
)
|
|
|
|
|
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_write_path)
|
2026-03-02 20:25:50 +00:00
|
|
|
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
2025-04-01 23:18:21 +00:00
|
|
|
os.chdir(run_cwd)
|
2026-01-29 09:39:48 +00:00
|
|
|
# func_optimizer = pass
|
2025-04-01 23:18:21 +00:00
|
|
|
try:
|
|
|
|
|
ctx_result = func_optimizer.get_code_optimization_context()
|
|
|
|
|
code_context: CodeOptimizationContext = ctx_result.unwrap()
|
|
|
|
|
original_helper_code: dict[Path, str] = {}
|
|
|
|
|
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
|
|
|
|
for helper_function_path in helper_function_paths:
|
|
|
|
|
with helper_function_path.open(encoding="utf8") as f:
|
|
|
|
|
helper_code = f.read()
|
|
|
|
|
original_helper_code[helper_function_path] = helper_code
|
2026-01-29 09:39:48 +00:00
|
|
|
line_profiler_output_file = add_decorator_imports(func_optimizer.function_to_optimize, code_context)
|
2025-04-01 23:18:21 +00:00
|
|
|
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
|
2025-09-28 07:43:33 +00:00
|
|
|
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')
|
2025-04-01 23:18:21 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@codeflash_line_profile
|
|
|
|
|
def sorter(arr):
|
|
|
|
|
ans = helper(arr)
|
|
|
|
|
return ans
|
|
|
|
|
class helper:
|
2026-01-01 17:34:40 +00:00
|
|
|
@codeflash_line_profile
|
2025-04-01 23:18:21 +00:00
|
|
|
def __init__(self, arr):
|
|
|
|
|
return arr.sort()
|
|
|
|
|
"""
|
|
|
|
|
assert code_write_path.read_text("utf-8") == expected_code_main
|
|
|
|
|
finally:
|
|
|
|
|
pass
|
2026-01-14 01:34:13 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# ============================================================================
|
|
|
|
|
# Tests for contains_jit_decorator
|
|
|
|
|
# ============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestContainsJitDecoratorNumba:
|
|
|
|
|
"""Tests for numba JIT decorator detection."""
|
|
|
|
|
|
|
|
|
|
def test_numba_jit_with_module_prefix(self):
|
|
|
|
|
code = """
|
|
|
|
|
import numba
|
|
|
|
|
|
|
|
|
|
@numba.jit
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_numba_jit_with_alias(self):
|
|
|
|
|
code = """
|
|
|
|
|
import numba as nb
|
|
|
|
|
|
|
|
|
|
@nb.jit
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_numba_jit_direct_import(self):
|
|
|
|
|
code = """
|
|
|
|
|
from numba import jit
|
|
|
|
|
|
|
|
|
|
@jit
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_numba_jit_direct_import_with_alias(self):
|
|
|
|
|
code = """
|
|
|
|
|
from numba import jit as my_jit
|
|
|
|
|
|
|
|
|
|
@my_jit
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_numba_jit_with_arguments(self):
|
|
|
|
|
code = """
|
|
|
|
|
import numba
|
|
|
|
|
|
|
|
|
|
@numba.jit(nopython=True)
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_numba_jit_direct_import_with_arguments(self):
|
|
|
|
|
code = """
|
|
|
|
|
from numba import jit
|
|
|
|
|
|
|
|
|
|
@jit(nopython=True, cache=True)
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_numba_njit(self):
|
|
|
|
|
code = """
|
|
|
|
|
from numba import njit
|
|
|
|
|
|
|
|
|
|
@njit
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_numba_njit_with_module_prefix(self):
|
|
|
|
|
code = """
|
|
|
|
|
import numba
|
|
|
|
|
|
|
|
|
|
@numba.njit
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_numba_vectorize(self):
|
|
|
|
|
code = """
|
|
|
|
|
from numba import vectorize
|
|
|
|
|
|
|
|
|
|
@vectorize
|
|
|
|
|
def my_func(x):
|
|
|
|
|
return x * 2
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_numba_guvectorize(self):
|
|
|
|
|
code = """
|
|
|
|
|
import numba
|
|
|
|
|
|
|
|
|
|
@numba.guvectorize(['void(float64[:], float64[:])'], '(n)->(n)')
|
|
|
|
|
def my_func(x, res):
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_numba_stencil(self):
|
|
|
|
|
code = """
|
|
|
|
|
from numba import stencil
|
|
|
|
|
|
|
|
|
|
@stencil
|
|
|
|
|
def my_kernel(a):
|
|
|
|
|
return a[0, 0] + a[0, 1]
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_numba_cfunc(self):
|
|
|
|
|
code = """
|
|
|
|
|
from numba import cfunc
|
|
|
|
|
|
|
|
|
|
@cfunc("float64(float64)")
|
|
|
|
|
def my_func(x):
|
|
|
|
|
return x * 2
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_numba_generated_jit(self):
|
|
|
|
|
code = """
|
|
|
|
|
from numba import generated_jit
|
|
|
|
|
|
|
|
|
|
@generated_jit
|
|
|
|
|
def my_func(x):
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_numba_cuda_jit(self):
|
|
|
|
|
code = """
|
|
|
|
|
import numba
|
|
|
|
|
|
|
|
|
|
@numba.cuda.jit
|
|
|
|
|
def my_kernel():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_numba_cuda_jit_with_alias(self):
|
|
|
|
|
code = """
|
|
|
|
|
import numba as nb
|
|
|
|
|
|
|
|
|
|
@nb.cuda.jit
|
|
|
|
|
def my_kernel():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestContainsJitDecoratorTorch:
|
|
|
|
|
"""Tests for torch JIT decorator detection."""
|
|
|
|
|
|
|
|
|
|
def test_torch_compile(self):
|
|
|
|
|
code = """
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
@torch.compile
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_torch_compile_with_alias(self):
|
|
|
|
|
code = """
|
|
|
|
|
import torch as th
|
|
|
|
|
|
|
|
|
|
@th.compile
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_torch_compile_direct_import(self):
|
|
|
|
|
code = """
|
|
|
|
|
from torch import compile
|
|
|
|
|
|
|
|
|
|
@compile
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_torch_compile_with_arguments(self):
|
|
|
|
|
code = """
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
@torch.compile(mode="reduce-overhead")
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_torch_jit_script(self):
|
|
|
|
|
code = """
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
@torch.jit.script
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_torch_jit_script_with_alias(self):
|
|
|
|
|
code = """
|
|
|
|
|
import torch as th
|
|
|
|
|
|
|
|
|
|
@th.jit.script
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_torch_jit_trace(self):
|
|
|
|
|
code = """
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
@torch.jit.trace
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_torch_jit_imported_then_script(self):
|
|
|
|
|
code = """
|
|
|
|
|
from torch import jit
|
|
|
|
|
|
|
|
|
|
@jit.script
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_torch_jit_imported_then_trace(self):
|
|
|
|
|
code = """
|
|
|
|
|
from torch import jit
|
|
|
|
|
|
|
|
|
|
@jit.trace
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestContainsJitDecoratorTensorFlow:
|
|
|
|
|
"""Tests for TensorFlow JIT decorator detection."""
|
|
|
|
|
|
|
|
|
|
def test_tensorflow_function_with_tf_alias(self):
|
|
|
|
|
code = """
|
|
|
|
|
import tensorflow as tf
|
|
|
|
|
|
|
|
|
|
@tf.function
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_tensorflow_function_full_name(self):
|
|
|
|
|
code = """
|
|
|
|
|
import tensorflow
|
|
|
|
|
|
|
|
|
|
@tensorflow.function
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_tensorflow_function_direct_import(self):
|
|
|
|
|
code = """
|
|
|
|
|
from tensorflow import function
|
|
|
|
|
|
|
|
|
|
@function
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_tensorflow_function_with_arguments(self):
|
|
|
|
|
code = """
|
|
|
|
|
import tensorflow as tf
|
|
|
|
|
|
|
|
|
|
@tf.function(jit_compile=True)
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_tf_function_direct_import_alias(self):
|
|
|
|
|
code = """
|
|
|
|
|
from tensorflow import function as tf_func
|
|
|
|
|
|
|
|
|
|
@tf_func
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestContainsJitDecoratorJax:
|
|
|
|
|
"""Tests for JAX JIT decorator detection."""
|
|
|
|
|
|
|
|
|
|
def test_jax_jit(self):
|
|
|
|
|
code = """
|
|
|
|
|
import jax
|
|
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_jax_jit_with_alias(self):
|
|
|
|
|
code = """
|
|
|
|
|
import jax as j
|
|
|
|
|
|
|
|
|
|
@j.jit
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_jax_jit_direct_import(self):
|
|
|
|
|
code = """
|
|
|
|
|
from jax import jit
|
|
|
|
|
|
|
|
|
|
@jit
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_jax_jit_direct_import_with_alias(self):
|
|
|
|
|
code = """
|
|
|
|
|
from jax import jit as jax_jit
|
|
|
|
|
|
|
|
|
|
@jax_jit
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_jax_jit_with_arguments(self):
|
|
|
|
|
code = """
|
|
|
|
|
import jax
|
|
|
|
|
|
|
|
|
|
@jax.jit(static_argnums=(0,))
|
|
|
|
|
def my_func(x, y):
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestContainsJitDecoratorNegativeCases:
|
|
|
|
|
"""Tests that should NOT detect JIT decorators."""
|
|
|
|
|
|
|
|
|
|
def test_no_decorators(self):
|
|
|
|
|
code = """
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert not contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_other_decorator(self):
|
|
|
|
|
code = """
|
|
|
|
|
import functools
|
|
|
|
|
|
|
|
|
|
@functools.lru_cache
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert not contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_custom_decorator(self):
|
|
|
|
|
code = """
|
|
|
|
|
def my_decorator(func):
|
|
|
|
|
return func
|
|
|
|
|
|
|
|
|
|
@my_decorator
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert not contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_property_decorator(self):
|
|
|
|
|
code = """
|
|
|
|
|
class MyClass:
|
|
|
|
|
@property
|
|
|
|
|
def my_prop(self):
|
|
|
|
|
return self._value
|
|
|
|
|
"""
|
|
|
|
|
assert not contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_staticmethod_decorator(self):
|
|
|
|
|
code = """
|
|
|
|
|
class MyClass:
|
|
|
|
|
@staticmethod
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert not contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_classmethod_decorator(self):
|
|
|
|
|
code = """
|
|
|
|
|
class MyClass:
|
|
|
|
|
@classmethod
|
|
|
|
|
def my_func(cls):
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert not contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_jit_in_comment(self):
|
|
|
|
|
code = """
|
|
|
|
|
# @numba.jit
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert not contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_jit_in_string(self):
|
|
|
|
|
code = '''
|
|
|
|
|
def my_func():
|
|
|
|
|
"""This function could use @numba.jit decorator."""
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
assert not contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_unrelated_jit_name(self):
|
|
|
|
|
code = """
|
|
|
|
|
def jit():
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@jit
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert not contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_unrelated_module_with_jit_attribute(self):
|
|
|
|
|
code = """
|
|
|
|
|
import my_module
|
|
|
|
|
|
|
|
|
|
@my_module.jit
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert not contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_numba_import_but_no_decorator(self):
|
|
|
|
|
code = """
|
|
|
|
|
import numba
|
|
|
|
|
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert not contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_jit_variable_not_decorator(self):
|
|
|
|
|
code = """
|
|
|
|
|
from numba import jit
|
|
|
|
|
|
|
|
|
|
def my_func():
|
|
|
|
|
x = jit
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert not contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestContainsJitDecoratorEdgeCases:
|
|
|
|
|
"""Edge case tests for JIT decorator detection."""
|
|
|
|
|
|
|
|
|
|
def test_multiple_decorators_with_jit(self):
|
|
|
|
|
code = """
|
|
|
|
|
import numba
|
|
|
|
|
import functools
|
|
|
|
|
|
|
|
|
|
@functools.lru_cache
|
|
|
|
|
@numba.jit
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_multiple_decorators_jit_first(self):
|
|
|
|
|
code = """
|
|
|
|
|
import numba
|
|
|
|
|
import functools
|
|
|
|
|
|
|
|
|
|
@numba.jit
|
|
|
|
|
@functools.lru_cache
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_async_function_with_jit(self):
|
|
|
|
|
code = """
|
|
|
|
|
import numba
|
|
|
|
|
|
|
|
|
|
@numba.jit
|
|
|
|
|
async def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
2026-01-14 02:30:45 +00:00
|
|
|
assert contains_jit_decorator(code) is False
|
2026-01-14 01:34:13 +00:00
|
|
|
|
|
|
|
|
def test_method_in_class_with_jit(self):
|
|
|
|
|
code = """
|
|
|
|
|
import numba
|
|
|
|
|
|
|
|
|
|
class MyClass:
|
|
|
|
|
@numba.jit
|
|
|
|
|
def my_method(self):
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_nested_class_method_with_jit(self):
|
|
|
|
|
code = """
|
|
|
|
|
import numba
|
|
|
|
|
|
|
|
|
|
class Outer:
|
|
|
|
|
class Inner:
|
|
|
|
|
@numba.jit
|
|
|
|
|
def my_method(self):
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_multiple_functions_one_with_jit(self):
|
|
|
|
|
code = """
|
|
|
|
|
import numba
|
|
|
|
|
|
|
|
|
|
def func_a():
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@numba.jit
|
|
|
|
|
def func_b():
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def func_c():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_multiple_jit_functions(self):
|
|
|
|
|
code = """
|
|
|
|
|
import numba
|
|
|
|
|
import jax
|
|
|
|
|
|
|
|
|
|
@numba.jit
|
|
|
|
|
def func_a():
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@jax.jit
|
|
|
|
|
def func_b():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_empty_code(self):
|
|
|
|
|
code = ""
|
|
|
|
|
assert not contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_syntax_error_code(self):
|
|
|
|
|
code = """
|
|
|
|
|
def func(
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert not contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_whitespace_only(self):
|
|
|
|
|
code = " \n\n \t\t\n"
|
|
|
|
|
assert not contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_only_imports(self):
|
|
|
|
|
code = """
|
|
|
|
|
import numba
|
|
|
|
|
from jax import jit
|
|
|
|
|
"""
|
|
|
|
|
assert not contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_lambda_cannot_have_decorator(self):
|
|
|
|
|
# Lambdas cannot have decorators in Python
|
|
|
|
|
code = """
|
|
|
|
|
import numba
|
|
|
|
|
|
|
|
|
|
f = lambda x: x * 2
|
|
|
|
|
"""
|
|
|
|
|
assert not contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_mixed_imports_and_aliases(self):
|
|
|
|
|
code = """
|
|
|
|
|
import numba as nb
|
|
|
|
|
from torch import compile as torch_compile
|
|
|
|
|
import jax
|
|
|
|
|
|
|
|
|
|
@nb.jit
|
|
|
|
|
def func_a():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_decorator_in_different_module_context(self):
|
|
|
|
|
code = """
|
|
|
|
|
# Import numba for numeric computation
|
|
|
|
|
import numba
|
|
|
|
|
|
|
|
|
|
# Some other code
|
|
|
|
|
x = 5
|
|
|
|
|
|
|
|
|
|
class Processor:
|
|
|
|
|
@numba.njit
|
|
|
|
|
def process(self, data):
|
|
|
|
|
return data * 2
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_from_import_star_not_tracked(self):
|
|
|
|
|
# Star imports are not tracked, so @jit won't be detected
|
|
|
|
|
code = """
|
|
|
|
|
from numba import *
|
|
|
|
|
|
|
|
|
|
@jit
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
# Star imports are not tracked, so this returns False
|
|
|
|
|
assert not contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_multiple_from_imports_same_module(self):
|
|
|
|
|
code = """
|
|
|
|
|
from numba import jit
|
|
|
|
|
from numba import njit
|
|
|
|
|
|
|
|
|
|
@njit
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_reimport_with_different_alias(self):
|
|
|
|
|
code = """
|
|
|
|
|
from numba import jit
|
|
|
|
|
from numba import jit as fast_jit
|
|
|
|
|
|
|
|
|
|
@fast_jit
|
|
|
|
|
def my_func():
|
|
|
|
|
pass
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestContainsJitDecoratorComplexCases:
|
|
|
|
|
"""Complex real-world scenarios for JIT decorator detection."""
|
|
|
|
|
|
|
|
|
|
def test_realistic_numba_code(self):
|
|
|
|
|
code = """
|
|
|
|
|
import numpy as np
|
|
|
|
|
from numba import jit, prange
|
|
|
|
|
|
|
|
|
|
@jit(nopython=True, parallel=True)
|
|
|
|
|
def compute_sum(arr):
|
|
|
|
|
total = 0.0
|
|
|
|
|
for i in prange(len(arr)):
|
|
|
|
|
total += arr[i]
|
|
|
|
|
return total
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
|
data = np.random.rand(1000000)
|
|
|
|
|
result = compute_sum(data)
|
|
|
|
|
print(result)
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_realistic_torch_code(self):
|
|
|
|
|
code = """
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
class MyModel(nn.Module):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.linear = nn.Linear(10, 5)
|
|
|
|
|
|
|
|
|
|
@torch.compile
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
return self.linear(x)
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_realistic_jax_code(self):
|
|
|
|
|
code = """
|
|
|
|
|
import jax
|
|
|
|
|
import jax.numpy as jnp
|
|
|
|
|
from jax import jit, grad
|
|
|
|
|
|
|
|
|
|
@jit
|
|
|
|
|
def loss_fn(params, x, y):
|
|
|
|
|
pred = jnp.dot(x, params)
|
|
|
|
|
return jnp.mean((pred - y) ** 2)
|
|
|
|
|
|
|
|
|
|
grad_fn = grad(loss_fn)
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_realistic_tensorflow_code(self):
|
|
|
|
|
code = """
|
|
|
|
|
import tensorflow as tf
|
|
|
|
|
|
|
|
|
|
@tf.function(jit_compile=True)
|
|
|
|
|
def train_step(model, x, y):
|
|
|
|
|
with tf.GradientTape() as tape:
|
|
|
|
|
predictions = model(x)
|
|
|
|
|
loss = tf.reduce_mean(tf.square(predictions - y))
|
|
|
|
|
gradients = tape.gradient(loss, model.trainable_variables)
|
|
|
|
|
return loss, gradients
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|
|
|
|
|
|
|
|
|
|
def test_file_with_many_functions_one_jit(self):
|
|
|
|
|
code = """
|
|
|
|
|
import os
|
|
|
|
|
import sys
|
|
|
|
|
import numpy as np
|
|
|
|
|
from numba import njit
|
|
|
|
|
|
|
|
|
|
def helper_a():
|
|
|
|
|
return 1
|
|
|
|
|
|
|
|
|
|
def helper_b():
|
|
|
|
|
return 2
|
|
|
|
|
|
|
|
|
|
class DataProcessor:
|
|
|
|
|
def __init__(self):
|
|
|
|
|
self.data = []
|
|
|
|
|
|
|
|
|
|
def process(self):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@njit
|
|
|
|
|
def fast_compute(x, y):
|
|
|
|
|
return x + y
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
|
result = fast_compute(1, 2)
|
|
|
|
|
print(result)
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
main()
|
|
|
|
|
"""
|
|
|
|
|
assert contains_jit_decorator(code)
|