# codeflash/tests/test_inject_profiling_used_frameworks.py
"""Unit tests for inject_profiling_into_existing_test with different used_frameworks values.
These tests verify that the wrapper function is correctly generated with GPU device
synchronization code for different framework imports (torch, tensorflow, jax).
"""
from __future__ import annotations
2026-01-13 00:36:14 +00:00
import re
2026-01-12 23:45:29 +00:00
from pathlib import Path
import pytest
from codeflash.code_utils.instrument_existing_tests import (
detect_frameworks_from_code,
inject_profiling_into_existing_test,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, TestingMode
2026-01-13 00:36:14 +00:00
def normalize_instrumented_code(code: str) -> str:
    """Normalize instrumented code by replacing dynamic paths with placeholders.

    This allows comparing instrumented code across test runs where temp paths
    differ.

    Also normalizes f-string quoting differences between Python versions
    (Python 3.12+ allows single quotes inside single-quoted f-strings via
    PEP 701, but libcst generates double-quoted f-strings for compatibility
    with older versions).

    Args:
        code: Source text of an instrumented test module.

    Returns:
        ``code`` with every single-quoted f-string argument that opens a
        ``sqlite3.connect(`` call replaced by ``f'{CODEFLASH_DB_PATH}'`` and
        every double-quoted ``test_stdout_tag = f"..."`` assignment rewritten
        with single outer quotes. Text that matches neither pattern is
        returned unchanged.
    """
    # Normalize database path
    code = re.sub(
        r"sqlite3\.connect\(f'[^']+'",
        "sqlite3.connect(f'{CODEFLASH_DB_PATH}'",
        code,
    )
    # Normalize f-string that contains the test_stdout_tag assignment
    # This specific f-string has internal single quotes, so libcst uses double quotes
    # on Python < 3.12, but single quotes on Python 3.12+
    code = re.sub(
        r'test_stdout_tag = f"([^"]+)"',
        r"test_stdout_tag = f'\1'",
        code,
    )
    return code
2026-01-13 00:36:14 +00:00
# ============================================================================
# Expected instrumented code for BEHAVIOR mode
# ============================================================================
# Expected BEHAVIOR-mode instrumented module when the test imports no GPU
# framework: codeflash_wrap times the wrapped call with time.perf_counter_ns()
# while gc is disabled, pickles the return value (or the exception), and
# records one row in the test_results table. No device synchronization is
# emitted.
# NOTE(review): blank-line layout of the generated module is not represented
# here - confirm against the actual generator output.
EXPECTED_NO_FRAMEWORKS_BEHAVIOR = """import gc
import inspect
import os
import sqlite3
import time
import dill as pickle
from mymodule import my_function
def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
    test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}'
    if not hasattr(codeflash_wrap, 'index'):
        codeflash_wrap.index = {}
    if test_id in codeflash_wrap.index:
        codeflash_wrap.index[test_id] += 1
    else:
        codeflash_wrap.index[test_id] = 0
    codeflash_test_index = codeflash_wrap.index[test_id]
    invocation_id = f'{codeflash_line_id}_{codeflash_test_index}'
    test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}'
    print(f'!$######{test_stdout_tag}######$!')
    exception = None
    gc.disable()
    try:
        counter = time.perf_counter_ns()
        return_value = codeflash_wrapped(*args, **kwargs)
        codeflash_duration = time.perf_counter_ns() - counter
    except Exception as e:
        codeflash_duration = time.perf_counter_ns() - counter
        exception = e
    gc.enable()
    print(f'!######{test_stdout_tag}######!')
    pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
    codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
    codeflash_con.commit()
    if exception:
        raise exception
    return return_value
def test_my_function():
    codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
    codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
    codeflash_con = sqlite3.connect(f'{CODEFLASH_DB_PATH}')
    codeflash_cur = codeflash_con.cursor()
    codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)')
    _call__bound__arguments = inspect.signature(my_function).bind(1, 2)
    _call__bound__arguments.apply_defaults()
    result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs)
    assert result == 3
    codeflash_con.close()
"""
# Expected BEHAVIOR-mode instrumented module when the test imports `torch`:
# the wrapper computes CUDA / MPS sync flags once, then synchronizes the
# selected device both immediately before starting the timer and right after
# the wrapped call returns (before the duration is taken).
# NOTE(review): blank-line layout of the generated module is not represented
# here - confirm against the actual generator output.
EXPECTED_TORCH_BEHAVIOR = """import gc
import inspect
import os
import sqlite3
import time
import dill as pickle
import torch
from mymodule import my_function
def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
    test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}'
    if not hasattr(codeflash_wrap, 'index'):
        codeflash_wrap.index = {}
    if test_id in codeflash_wrap.index:
        codeflash_wrap.index[test_id] += 1
    else:
        codeflash_wrap.index[test_id] = 0
    codeflash_test_index = codeflash_wrap.index[test_id]
    invocation_id = f'{codeflash_line_id}_{codeflash_test_index}'
    test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}'
    print(f'!$######{test_stdout_tag}######$!')
    exception = None
    _codeflash_should_sync_cuda = torch.cuda.is_available() and torch.cuda.is_initialized()
    _codeflash_should_sync_mps = not _codeflash_should_sync_cuda and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and hasattr(torch.mps, 'synchronize')
    gc.disable()
    try:
        if _codeflash_should_sync_cuda:
            torch.cuda.synchronize()
        elif _codeflash_should_sync_mps:
            torch.mps.synchronize()
        counter = time.perf_counter_ns()
        return_value = codeflash_wrapped(*args, **kwargs)
        if _codeflash_should_sync_cuda:
            torch.cuda.synchronize()
        elif _codeflash_should_sync_mps:
            torch.mps.synchronize()
        codeflash_duration = time.perf_counter_ns() - counter
    except Exception as e:
        codeflash_duration = time.perf_counter_ns() - counter
        exception = e
    gc.enable()
    print(f'!######{test_stdout_tag}######!')
    pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
    codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
    codeflash_con.commit()
    if exception:
        raise exception
    return return_value
def test_my_function():
    codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
    codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
    codeflash_con = sqlite3.connect(f'{CODEFLASH_DB_PATH}')
    codeflash_cur = codeflash_con.cursor()
    codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)')
    _call__bound__arguments = inspect.signature(my_function).bind(1, 2)
    _call__bound__arguments.apply_defaults()
    result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs)
    assert result == 3
    codeflash_con.close()
"""
# Expected BEHAVIOR-mode instrumented module when the test imports torch
# under an alias (`import torch as th`): identical to the plain-torch
# expectation except that every synchronization reference goes through the
# alias `th`.
# NOTE(review): blank-line layout of the generated module is not represented
# here - confirm against the actual generator output.
EXPECTED_TORCH_ALIASED_BEHAVIOR = """import gc
import inspect
import os
import sqlite3
import time
import dill as pickle
import torch as th
from mymodule import my_function
def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
    test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}'
    if not hasattr(codeflash_wrap, 'index'):
        codeflash_wrap.index = {}
    if test_id in codeflash_wrap.index:
        codeflash_wrap.index[test_id] += 1
    else:
        codeflash_wrap.index[test_id] = 0
    codeflash_test_index = codeflash_wrap.index[test_id]
    invocation_id = f'{codeflash_line_id}_{codeflash_test_index}'
    test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}'
    print(f'!$######{test_stdout_tag}######$!')
    exception = None
    _codeflash_should_sync_cuda = th.cuda.is_available() and th.cuda.is_initialized()
    _codeflash_should_sync_mps = not _codeflash_should_sync_cuda and hasattr(th.backends, 'mps') and th.backends.mps.is_available() and hasattr(th.mps, 'synchronize')
    gc.disable()
    try:
        if _codeflash_should_sync_cuda:
            th.cuda.synchronize()
        elif _codeflash_should_sync_mps:
            th.mps.synchronize()
        counter = time.perf_counter_ns()
        return_value = codeflash_wrapped(*args, **kwargs)
        if _codeflash_should_sync_cuda:
            th.cuda.synchronize()
        elif _codeflash_should_sync_mps:
            th.mps.synchronize()
        codeflash_duration = time.perf_counter_ns() - counter
    except Exception as e:
        codeflash_duration = time.perf_counter_ns() - counter
        exception = e
    gc.enable()
    print(f'!######{test_stdout_tag}######!')
    pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
    codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
    codeflash_con.commit()
    if exception:
        raise exception
    return return_value
def test_my_function():
    codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
    codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
    codeflash_con = sqlite3.connect(f'{CODEFLASH_DB_PATH}')
    codeflash_cur = codeflash_con.cursor()
    codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)')
    _call__bound__arguments = inspect.signature(my_function).bind(1, 2)
    _call__bound__arguments.apply_defaults()
    result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs)
    assert result == 3
    codeflash_con.close()
"""
# Expected BEHAVIOR-mode instrumented module when the test imports a torch
# submodule (`from torch import nn`): the expected output additionally has a
# top-level `import torch`, and the wrapper synchronizes through the `torch`
# name exactly as in the plain-torch expectation.
# NOTE(review): blank-line layout of the generated module is not represented
# here - confirm against the actual generator output.
EXPECTED_TORCH_SUBMODULE_BEHAVIOR = """import gc
import inspect
import os
import sqlite3
import time
import dill as pickle
import torch
from mymodule import my_function
from torch import nn
def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
    test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}'
    if not hasattr(codeflash_wrap, 'index'):
        codeflash_wrap.index = {}
    if test_id in codeflash_wrap.index:
        codeflash_wrap.index[test_id] += 1
    else:
        codeflash_wrap.index[test_id] = 0
    codeflash_test_index = codeflash_wrap.index[test_id]
    invocation_id = f'{codeflash_line_id}_{codeflash_test_index}'
    test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}'
    print(f'!$######{test_stdout_tag}######$!')
    exception = None
    _codeflash_should_sync_cuda = torch.cuda.is_available() and torch.cuda.is_initialized()
    _codeflash_should_sync_mps = not _codeflash_should_sync_cuda and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and hasattr(torch.mps, 'synchronize')
    gc.disable()
    try:
        if _codeflash_should_sync_cuda:
            torch.cuda.synchronize()
        elif _codeflash_should_sync_mps:
            torch.mps.synchronize()
        counter = time.perf_counter_ns()
        return_value = codeflash_wrapped(*args, **kwargs)
        if _codeflash_should_sync_cuda:
            torch.cuda.synchronize()
        elif _codeflash_should_sync_mps:
            torch.mps.synchronize()
        codeflash_duration = time.perf_counter_ns() - counter
    except Exception as e:
        codeflash_duration = time.perf_counter_ns() - counter
        exception = e
    gc.enable()
    print(f'!######{test_stdout_tag}######!')
    pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
    codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
    codeflash_con.commit()
    if exception:
        raise exception
    return return_value
def test_my_function():
    codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
    codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
    codeflash_con = sqlite3.connect(f'{CODEFLASH_DB_PATH}')
    codeflash_cur = codeflash_con.cursor()
    codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)')
    _call__bound__arguments = inspect.signature(my_function).bind(1, 2)
    _call__bound__arguments.apply_defaults()
    result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs)
    assert result == 3
    codeflash_con.close()
"""
# Expected BEHAVIOR-mode instrumented module when the test imports
# `tensorflow`: the wrapper checks once for
# tensorflow.test.experimental.sync_devices and, when present, calls it both
# before starting the timer and after the wrapped call returns.
# NOTE(review): blank-line layout of the generated module is not represented
# here - confirm against the actual generator output.
EXPECTED_TENSORFLOW_BEHAVIOR = """import gc
import inspect
import os
import sqlite3
import time
import dill as pickle
import tensorflow
from mymodule import my_function
def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
    test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}'
    if not hasattr(codeflash_wrap, 'index'):
        codeflash_wrap.index = {}
    if test_id in codeflash_wrap.index:
        codeflash_wrap.index[test_id] += 1
    else:
        codeflash_wrap.index[test_id] = 0
    codeflash_test_index = codeflash_wrap.index[test_id]
    invocation_id = f'{codeflash_line_id}_{codeflash_test_index}'
    test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}'
    print(f'!$######{test_stdout_tag}######$!')
    exception = None
    _codeflash_should_sync_tf = hasattr(tensorflow.test.experimental, 'sync_devices')
    gc.disable()
    try:
        if _codeflash_should_sync_tf:
            tensorflow.test.experimental.sync_devices()
        counter = time.perf_counter_ns()
        return_value = codeflash_wrapped(*args, **kwargs)
        if _codeflash_should_sync_tf:
            tensorflow.test.experimental.sync_devices()
        codeflash_duration = time.perf_counter_ns() - counter
    except Exception as e:
        codeflash_duration = time.perf_counter_ns() - counter
        exception = e
    gc.enable()
    print(f'!######{test_stdout_tag}######!')
    pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
    codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
    codeflash_con.commit()
    if exception:
        raise exception
    return return_value
def test_my_function():
    codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
    codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
    codeflash_con = sqlite3.connect(f'{CODEFLASH_DB_PATH}')
    codeflash_cur = codeflash_con.cursor()
    codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)')
    _call__bound__arguments = inspect.signature(my_function).bind(1, 2)
    _call__bound__arguments.apply_defaults()
    result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs)
    assert result == 3
    codeflash_con.close()
"""
# Expected BEHAVIOR-mode instrumented module when the test imports tensorflow
# under an alias (`import tensorflow as tf`): identical to the plain
# tensorflow expectation except that sync_devices is reached through `tf`.
# NOTE(review): blank-line layout of the generated module is not represented
# here - confirm against the actual generator output.
EXPECTED_TENSORFLOW_ALIASED_BEHAVIOR = """import gc
import inspect
import os
import sqlite3
import time
import dill as pickle
import tensorflow as tf
from mymodule import my_function
def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
    test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}'
    if not hasattr(codeflash_wrap, 'index'):
        codeflash_wrap.index = {}
    if test_id in codeflash_wrap.index:
        codeflash_wrap.index[test_id] += 1
    else:
        codeflash_wrap.index[test_id] = 0
    codeflash_test_index = codeflash_wrap.index[test_id]
    invocation_id = f'{codeflash_line_id}_{codeflash_test_index}'
    test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}'
    print(f'!$######{test_stdout_tag}######$!')
    exception = None
    _codeflash_should_sync_tf = hasattr(tf.test.experimental, 'sync_devices')
    gc.disable()
    try:
        if _codeflash_should_sync_tf:
            tf.test.experimental.sync_devices()
        counter = time.perf_counter_ns()
        return_value = codeflash_wrapped(*args, **kwargs)
        if _codeflash_should_sync_tf:
            tf.test.experimental.sync_devices()
        codeflash_duration = time.perf_counter_ns() - counter
    except Exception as e:
        codeflash_duration = time.perf_counter_ns() - counter
        exception = e
    gc.enable()
    print(f'!######{test_stdout_tag}######!')
    pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
    codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
    codeflash_con.commit()
    if exception:
        raise exception
    return return_value
def test_my_function():
    codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
    codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
    codeflash_con = sqlite3.connect(f'{CODEFLASH_DB_PATH}')
    codeflash_cur = codeflash_con.cursor()
    codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)')
    _call__bound__arguments = inspect.signature(my_function).bind(1, 2)
    _call__bound__arguments.apply_defaults()
    result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs)
    assert result == 3
    codeflash_con.close()
"""
# Expected BEHAVIOR-mode instrumented module when the test imports `jax`:
# the wrapper checks once for jax.block_until_ready and, when present, calls
# it on the return value after the wrapped call and before the duration is
# taken. No synchronization is emitted before the timer starts.
# NOTE(review): blank-line layout of the generated module is not represented
# here - confirm against the actual generator output.
EXPECTED_JAX_BEHAVIOR = """import gc
import inspect
import os
import sqlite3
import time
import dill as pickle
import jax
from mymodule import my_function
def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
    test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}'
    if not hasattr(codeflash_wrap, 'index'):
        codeflash_wrap.index = {}
    if test_id in codeflash_wrap.index:
        codeflash_wrap.index[test_id] += 1
    else:
        codeflash_wrap.index[test_id] = 0
    codeflash_test_index = codeflash_wrap.index[test_id]
    invocation_id = f'{codeflash_line_id}_{codeflash_test_index}'
    test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}'
    print(f'!$######{test_stdout_tag}######$!')
    exception = None
    _codeflash_should_sync_jax = hasattr(jax, 'block_until_ready')
    gc.disable()
    try:
        counter = time.perf_counter_ns()
        return_value = codeflash_wrapped(*args, **kwargs)
        if _codeflash_should_sync_jax:
            jax.block_until_ready(return_value)
        codeflash_duration = time.perf_counter_ns() - counter
    except Exception as e:
        codeflash_duration = time.perf_counter_ns() - counter
        exception = e
    gc.enable()
    print(f'!######{test_stdout_tag}######!')
    pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
    codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
    codeflash_con.commit()
    if exception:
        raise exception
    return return_value
def test_my_function():
    codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
    codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
    codeflash_con = sqlite3.connect(f'{CODEFLASH_DB_PATH}')
    codeflash_cur = codeflash_con.cursor()
    codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)')
    _call__bound__arguments = inspect.signature(my_function).bind(1, 2)
    _call__bound__arguments.apply_defaults()
    result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs)
    assert result == 3
    codeflash_con.close()
"""
# Expected BEHAVIOR-mode instrumented module when the test imports jax under
# an alias (`import jax as jnp`): identical to the plain-jax expectation
# except that block_until_ready is reached through the alias `jnp`.
# NOTE(review): blank-line layout of the generated module is not represented
# here - confirm against the actual generator output.
EXPECTED_JAX_ALIASED_BEHAVIOR = """import gc
import inspect
import os
import sqlite3
import time
import dill as pickle
import jax as jnp
from mymodule import my_function
def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
    test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}'
    if not hasattr(codeflash_wrap, 'index'):
        codeflash_wrap.index = {}
    if test_id in codeflash_wrap.index:
        codeflash_wrap.index[test_id] += 1
    else:
        codeflash_wrap.index[test_id] = 0
    codeflash_test_index = codeflash_wrap.index[test_id]
    invocation_id = f'{codeflash_line_id}_{codeflash_test_index}'
    test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}'
    print(f'!$######{test_stdout_tag}######$!')
    exception = None
    _codeflash_should_sync_jax = hasattr(jnp, 'block_until_ready')
    gc.disable()
    try:
        counter = time.perf_counter_ns()
        return_value = codeflash_wrapped(*args, **kwargs)
        if _codeflash_should_sync_jax:
            jnp.block_until_ready(return_value)
        codeflash_duration = time.perf_counter_ns() - counter
    except Exception as e:
        codeflash_duration = time.perf_counter_ns() - counter
        exception = e
    gc.enable()
    print(f'!######{test_stdout_tag}######!')
    pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
    codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
    codeflash_con.commit()
    if exception:
        raise exception
    return return_value
def test_my_function():
    codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
    codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
    codeflash_con = sqlite3.connect(f'{CODEFLASH_DB_PATH}')
    codeflash_cur = codeflash_con.cursor()
    codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)')
    _call__bound__arguments = inspect.signature(my_function).bind(1, 2)
    _call__bound__arguments.apply_defaults()
    result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs)
    assert result == 3
    codeflash_con.close()
"""
# Expected BEHAVIOR-mode instrumented module when the test imports both
# `torch` and `tensorflow`: the wrapper computes the CUDA / MPS flags and the
# TensorFlow flag, then runs the torch synchronization followed by the
# TensorFlow sync_devices call both before the timer starts and after the
# wrapped call returns.
# NOTE(review): blank-line layout of the generated module is not represented
# here - confirm against the actual generator output.
EXPECTED_TORCH_TENSORFLOW_BEHAVIOR = """import gc
import inspect
import os
import sqlite3
import time
import dill as pickle
import tensorflow
import torch
from mymodule import my_function
def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
    test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}'
    if not hasattr(codeflash_wrap, 'index'):
        codeflash_wrap.index = {}
    if test_id in codeflash_wrap.index:
        codeflash_wrap.index[test_id] += 1
    else:
        codeflash_wrap.index[test_id] = 0
    codeflash_test_index = codeflash_wrap.index[test_id]
    invocation_id = f'{codeflash_line_id}_{codeflash_test_index}'
    test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}'
    print(f'!$######{test_stdout_tag}######$!')
    exception = None
    _codeflash_should_sync_cuda = torch.cuda.is_available() and torch.cuda.is_initialized()
    _codeflash_should_sync_mps = not _codeflash_should_sync_cuda and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and hasattr(torch.mps, 'synchronize')
    _codeflash_should_sync_tf = hasattr(tensorflow.test.experimental, 'sync_devices')
    gc.disable()
    try:
        if _codeflash_should_sync_cuda:
            torch.cuda.synchronize()
        elif _codeflash_should_sync_mps:
            torch.mps.synchronize()
        if _codeflash_should_sync_tf:
            tensorflow.test.experimental.sync_devices()
        counter = time.perf_counter_ns()
        return_value = codeflash_wrapped(*args, **kwargs)
        if _codeflash_should_sync_cuda:
            torch.cuda.synchronize()
        elif _codeflash_should_sync_mps:
            torch.mps.synchronize()
        if _codeflash_should_sync_tf:
            tensorflow.test.experimental.sync_devices()
        codeflash_duration = time.perf_counter_ns() - counter
    except Exception as e:
        codeflash_duration = time.perf_counter_ns() - counter
        exception = e
    gc.enable()
    print(f'!######{test_stdout_tag}######!')
    pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
    codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
    codeflash_con.commit()
    if exception:
        raise exception
    return return_value
def test_my_function():
    codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
    codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
    codeflash_con = sqlite3.connect(f'{CODEFLASH_DB_PATH}')
    codeflash_cur = codeflash_con.cursor()
    codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)')
    _call__bound__arguments = inspect.signature(my_function).bind(1, 2)
    _call__bound__arguments.apply_defaults()
    result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs)
    assert result == 3
    codeflash_con.close()
"""
# Expected BEHAVIOR-mode instrumented module when the test imports `jax`,
# `tensorflow` and `torch` together. Before the timer starts the wrapper runs
# the torch (CUDA or MPS) sync and the TensorFlow sync; after the wrapped call
# it runs the torch sync, jax.block_until_ready on the return value, and the
# TensorFlow sync, in that order.
# NOTE(review): blank-line layout of the generated module is not represented
# here - confirm against the actual generator output.
EXPECTED_ALL_FRAMEWORKS_BEHAVIOR = """import gc
import inspect
import os
import sqlite3
import time
import dill as pickle
import jax
import tensorflow
import torch
from mymodule import my_function
def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
    test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}'
    if not hasattr(codeflash_wrap, 'index'):
        codeflash_wrap.index = {}
    if test_id in codeflash_wrap.index:
        codeflash_wrap.index[test_id] += 1
    else:
        codeflash_wrap.index[test_id] = 0
    codeflash_test_index = codeflash_wrap.index[test_id]
    invocation_id = f'{codeflash_line_id}_{codeflash_test_index}'
    test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}'
    print(f'!$######{test_stdout_tag}######$!')
    exception = None
    _codeflash_should_sync_cuda = torch.cuda.is_available() and torch.cuda.is_initialized()
    _codeflash_should_sync_mps = not _codeflash_should_sync_cuda and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and hasattr(torch.mps, 'synchronize')
    _codeflash_should_sync_jax = hasattr(jax, 'block_until_ready')
    _codeflash_should_sync_tf = hasattr(tensorflow.test.experimental, 'sync_devices')
    gc.disable()
    try:
        if _codeflash_should_sync_cuda:
            torch.cuda.synchronize()
        elif _codeflash_should_sync_mps:
            torch.mps.synchronize()
        if _codeflash_should_sync_tf:
            tensorflow.test.experimental.sync_devices()
        counter = time.perf_counter_ns()
        return_value = codeflash_wrapped(*args, **kwargs)
        if _codeflash_should_sync_cuda:
            torch.cuda.synchronize()
        elif _codeflash_should_sync_mps:
            torch.mps.synchronize()
        if _codeflash_should_sync_jax:
            jax.block_until_ready(return_value)
        if _codeflash_should_sync_tf:
            tensorflow.test.experimental.sync_devices()
        codeflash_duration = time.perf_counter_ns() - counter
    except Exception as e:
        codeflash_duration = time.perf_counter_ns() - counter
        exception = e
    gc.enable()
    print(f'!######{test_stdout_tag}######!')
    pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
    codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
    codeflash_con.commit()
    if exception:
        raise exception
    return return_value
def test_my_function():
    codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
    codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
    codeflash_con = sqlite3.connect(f'{CODEFLASH_DB_PATH}')
    codeflash_cur = codeflash_con.cursor()
    codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)')
    _call__bound__arguments = inspect.signature(my_function).bind(1, 2)
    _call__bound__arguments.apply_defaults()
    result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs)
    assert result == 3
    codeflash_con.close()
"""
# ============================================================================
# Expected instrumented code for PERFORMANCE mode
# ============================================================================
# Expected PERFORMANCE-mode instrumentation for a test module that imports no GPU
# framework. TestInjectProfilingPerformanceMode.test_no_frameworks_performance_mode
# compares this literal for exact equality against the normalized output of
# inject_profiling_into_existing_test, so every character of it is significant.
# The wrapper times the wrapped call with time.perf_counter_ns() while gc is
# disabled, contains no device-synchronization calls, and prints the measured
# duration inside the closing stdout tag.
# NOTE(review): the bare timestamp lines and the missing indentation/blank lines
# inside this literal look like residue from a blame view rather than real
# instrumentation output -- confirm against the upstream file before relying on it.
EXPECTED_NO_FRAMEWORKS_PERFORMANCE = """import gc
import os
import time
from mymodule import my_function
def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, *args, **kwargs):
test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}'
if not hasattr(codeflash_wrap, 'index'):
codeflash_wrap.index = {}
if test_id in codeflash_wrap.index:
codeflash_wrap.index[test_id] += 1
else:
codeflash_wrap.index[test_id] = 0
codeflash_test_index = codeflash_wrap.index[test_id]
invocation_id = f'{codeflash_line_id}_{codeflash_test_index}'
2026-01-13 00:56:16 +00:00
test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}'
2026-01-13 00:36:14 +00:00
print(f'!$######{test_stdout_tag}######$!')
exception = None
gc.disable()
try:
counter = time.perf_counter_ns()
return_value = codeflash_wrapped(*args, **kwargs)
codeflash_duration = time.perf_counter_ns() - counter
except Exception as e:
codeflash_duration = time.perf_counter_ns() - counter
exception = e
gc.enable()
print(f'!######{test_stdout_tag}:{codeflash_duration}######!')
if exception:
raise exception
return return_value
def test_my_function():
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, 1, 2)
assert result == 3
"""
# Expected PERFORMANCE-mode instrumentation for a test module that does
# `import torch`. TestInjectProfilingPerformanceMode.test_torch_import_performance_mode
# compares this literal for exact equality against the normalized output of
# inject_profiling_into_existing_test, so every character of it is significant.
# The wrapper decides once whether to synchronize CUDA (available and initialized)
# or, failing that, MPS; it then synchronizes the chosen device both before the
# timer starts and after the wrapped call returns, before the duration is taken.
# NOTE(review): the bare timestamp lines and the missing indentation/blank lines
# inside this literal look like residue from a blame view rather than real
# instrumentation output -- confirm against the upstream file before relying on it.
EXPECTED_TORCH_PERFORMANCE = """import gc
import os
import time
import torch
from mymodule import my_function
def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, *args, **kwargs):
test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}'
if not hasattr(codeflash_wrap, 'index'):
codeflash_wrap.index = {}
if test_id in codeflash_wrap.index:
codeflash_wrap.index[test_id] += 1
else:
codeflash_wrap.index[test_id] = 0
codeflash_test_index = codeflash_wrap.index[test_id]
invocation_id = f'{codeflash_line_id}_{codeflash_test_index}'
2026-01-13 00:56:16 +00:00
test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}'
2026-01-13 00:36:14 +00:00
print(f'!$######{test_stdout_tag}######$!')
exception = None
_codeflash_should_sync_cuda = torch.cuda.is_available() and torch.cuda.is_initialized()
_codeflash_should_sync_mps = not _codeflash_should_sync_cuda and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and hasattr(torch.mps, 'synchronize')
gc.disable()
try:
if _codeflash_should_sync_cuda:
torch.cuda.synchronize()
elif _codeflash_should_sync_mps:
torch.mps.synchronize()
counter = time.perf_counter_ns()
return_value = codeflash_wrapped(*args, **kwargs)
if _codeflash_should_sync_cuda:
torch.cuda.synchronize()
elif _codeflash_should_sync_mps:
torch.mps.synchronize()
codeflash_duration = time.perf_counter_ns() - counter
except Exception as e:
codeflash_duration = time.perf_counter_ns() - counter
exception = e
gc.enable()
print(f'!######{test_stdout_tag}:{codeflash_duration}######!')
if exception:
raise exception
return return_value
def test_my_function():
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, 1, 2)
assert result == 3
"""
# Expected PERFORMANCE-mode instrumentation for a test module that does
# `import tensorflow`.
# TestInjectProfilingPerformanceMode.test_tensorflow_import_performance_mode
# compares this literal for exact equality against the normalized output of
# inject_profiling_into_existing_test, so every character of it is significant.
# The wrapper calls tensorflow.test.experimental.sync_devices() before the timer
# starts and again after the wrapped call, but only when that attribute exists.
# NOTE(review): the bare timestamp lines and the missing indentation/blank lines
# inside this literal look like residue from a blame view rather than real
# instrumentation output -- confirm against the upstream file before relying on it.
EXPECTED_TENSORFLOW_PERFORMANCE = """import gc
import os
import time
import tensorflow
from mymodule import my_function
def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, *args, **kwargs):
test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}'
if not hasattr(codeflash_wrap, 'index'):
codeflash_wrap.index = {}
if test_id in codeflash_wrap.index:
codeflash_wrap.index[test_id] += 1
else:
codeflash_wrap.index[test_id] = 0
codeflash_test_index = codeflash_wrap.index[test_id]
invocation_id = f'{codeflash_line_id}_{codeflash_test_index}'
2026-01-13 00:56:16 +00:00
test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}'
2026-01-13 00:36:14 +00:00
print(f'!$######{test_stdout_tag}######$!')
exception = None
_codeflash_should_sync_tf = hasattr(tensorflow.test.experimental, 'sync_devices')
gc.disable()
try:
if _codeflash_should_sync_tf:
tensorflow.test.experimental.sync_devices()
counter = time.perf_counter_ns()
return_value = codeflash_wrapped(*args, **kwargs)
if _codeflash_should_sync_tf:
tensorflow.test.experimental.sync_devices()
codeflash_duration = time.perf_counter_ns() - counter
except Exception as e:
codeflash_duration = time.perf_counter_ns() - counter
exception = e
gc.enable()
print(f'!######{test_stdout_tag}:{codeflash_duration}######!')
if exception:
raise exception
return return_value
def test_my_function():
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, 1, 2)
assert result == 3
"""
# Expected PERFORMANCE-mode instrumentation for a test module that does
# `import jax`. TestInjectProfilingPerformanceMode.test_jax_import_performance_mode
# compares this literal for exact equality against the normalized output of
# inject_profiling_into_existing_test, so every character of it is significant.
# The wrapper has no pre-call synchronization; after the wrapped call it invokes
# jax.block_until_ready(return_value) (when jax exposes that attribute) before the
# duration is taken.
# NOTE(review): the bare timestamp lines and the missing indentation/blank lines
# inside this literal look like residue from a blame view rather than real
# instrumentation output -- confirm against the upstream file before relying on it.
EXPECTED_JAX_PERFORMANCE = """import gc
import os
import time
import jax
from mymodule import my_function
def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, *args, **kwargs):
test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}'
if not hasattr(codeflash_wrap, 'index'):
codeflash_wrap.index = {}
if test_id in codeflash_wrap.index:
codeflash_wrap.index[test_id] += 1
else:
codeflash_wrap.index[test_id] = 0
codeflash_test_index = codeflash_wrap.index[test_id]
invocation_id = f'{codeflash_line_id}_{codeflash_test_index}'
2026-01-13 00:56:16 +00:00
test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}'
2026-01-13 00:36:14 +00:00
print(f'!$######{test_stdout_tag}######$!')
exception = None
_codeflash_should_sync_jax = hasattr(jax, 'block_until_ready')
gc.disable()
try:
counter = time.perf_counter_ns()
return_value = codeflash_wrapped(*args, **kwargs)
if _codeflash_should_sync_jax:
jax.block_until_ready(return_value)
codeflash_duration = time.perf_counter_ns() - counter
except Exception as e:
codeflash_duration = time.perf_counter_ns() - counter
exception = e
gc.enable()
print(f'!######{test_stdout_tag}:{codeflash_duration}######!')
if exception:
raise exception
return return_value
def test_my_function():
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, 1, 2)
assert result == 3
"""
# Expected PERFORMANCE-mode instrumentation for a test module that imports torch,
# tensorflow and jax together.
# TestInjectProfilingPerformanceMode.test_all_frameworks_performance_mode compares
# this literal for exact equality against the normalized output of
# inject_profiling_into_existing_test, so every character of it is significant.
# Before the timer starts the wrapper synchronizes torch (CUDA, else MPS) and then
# TensorFlow; after the wrapped call it synchronizes torch, then blocks on the JAX
# return value, then synchronizes TensorFlow, and only then takes the duration.
# NOTE(review): the bare timestamp lines and the missing indentation/blank lines
# inside this literal look like residue from a blame view rather than real
# instrumentation output -- confirm against the upstream file before relying on it.
EXPECTED_ALL_FRAMEWORKS_PERFORMANCE = """import gc
import os
import time
import jax
import tensorflow
import torch
from mymodule import my_function
def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, *args, **kwargs):
test_id = f'{codeflash_test_module_name}:{codeflash_test_class_name}:{codeflash_test_name}:{codeflash_line_id}:{codeflash_loop_index}'
if not hasattr(codeflash_wrap, 'index'):
codeflash_wrap.index = {}
if test_id in codeflash_wrap.index:
codeflash_wrap.index[test_id] += 1
else:
codeflash_wrap.index[test_id] = 0
codeflash_test_index = codeflash_wrap.index[test_id]
invocation_id = f'{codeflash_line_id}_{codeflash_test_index}'
2026-01-13 00:56:16 +00:00
test_stdout_tag = f'{codeflash_test_module_name}:{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}{codeflash_test_name}:{codeflash_function_name}:{codeflash_loop_index}:{invocation_id}'
2026-01-13 00:36:14 +00:00
print(f'!$######{test_stdout_tag}######$!')
exception = None
_codeflash_should_sync_cuda = torch.cuda.is_available() and torch.cuda.is_initialized()
_codeflash_should_sync_mps = not _codeflash_should_sync_cuda and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() and hasattr(torch.mps, 'synchronize')
_codeflash_should_sync_jax = hasattr(jax, 'block_until_ready')
_codeflash_should_sync_tf = hasattr(tensorflow.test.experimental, 'sync_devices')
gc.disable()
try:
if _codeflash_should_sync_cuda:
torch.cuda.synchronize()
elif _codeflash_should_sync_mps:
torch.mps.synchronize()
if _codeflash_should_sync_tf:
tensorflow.test.experimental.sync_devices()
counter = time.perf_counter_ns()
return_value = codeflash_wrapped(*args, **kwargs)
if _codeflash_should_sync_cuda:
torch.cuda.synchronize()
elif _codeflash_should_sync_mps:
torch.mps.synchronize()
if _codeflash_should_sync_jax:
jax.block_until_ready(return_value)
if _codeflash_should_sync_tf:
tensorflow.test.experimental.sync_devices()
codeflash_duration = time.perf_counter_ns() - counter
except Exception as e:
codeflash_duration = time.perf_counter_ns() - counter
exception = e
gc.enable()
print(f'!######{test_stdout_tag}:{codeflash_duration}######!')
if exception:
raise exception
return return_value
def test_my_function():
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
result = codeflash_wrap(my_function, 'test_example', None, 'test_my_function', 'my_function', '0', codeflash_loop_index, 1, 2)
assert result == 3
"""
# ============================================================================
# Tests for detect_frameworks_from_code
# ============================================================================
class TestDetectFrameworksFromCode:
    """Tests for the detect_frameworks_from_code helper function.

    Every sample is a syntactically valid module (the function body is
    indented), because the helper returns an empty dict for source it cannot
    parse -- an unindented body would make each positive case fail. The
    expected value maps a framework name to the name the instrumentation
    should use to reach it (the alias for ``import x as y``, otherwise the
    framework name itself).
    """

    def test_no_frameworks(self) -> None:
        """A module without GPU framework imports yields an empty mapping."""
        code = """import os
from mymodule import my_function
def test_something():
    pass
"""
        assert detect_frameworks_from_code(code) == {}

    def test_torch_import(self) -> None:
        """A plain ``import torch`` maps torch to its own name."""
        code = """import torch
from mymodule import my_function
def test_something():
    pass
"""
        assert detect_frameworks_from_code(code) == {"torch": "torch"}

    def test_torch_aliased_import(self) -> None:
        """``import torch as th`` maps torch to the alias."""
        code = """import torch as th
from mymodule import my_function
def test_something():
    pass
"""
        assert detect_frameworks_from_code(code) == {"torch": "th"}

    def test_torch_submodule_import(self) -> None:
        """``from torch import nn`` is reported under the top-level name."""
        code = """from torch import nn
from mymodule import my_function
def test_something():
    pass
"""
        assert detect_frameworks_from_code(code) == {"torch": "torch"}

    def test_torch_dotted_import(self) -> None:
        """A dotted import (``import torch.cuda``) is reported as torch."""
        code = """import torch.cuda
from mymodule import my_function
def test_something():
    pass
"""
        assert detect_frameworks_from_code(code) == {"torch": "torch"}

    def test_tensorflow_import(self) -> None:
        """A plain ``import tensorflow`` maps tensorflow to its own name."""
        code = """import tensorflow
from mymodule import my_function
def test_something():
    pass
"""
        assert detect_frameworks_from_code(code) == {"tensorflow": "tensorflow"}

    def test_tensorflow_aliased_import(self) -> None:
        """``import tensorflow as tf`` maps tensorflow to the alias."""
        code = """import tensorflow as tf
from mymodule import my_function
def test_something():
    pass
"""
        assert detect_frameworks_from_code(code) == {"tensorflow": "tf"}

    def test_tensorflow_submodule_import(self) -> None:
        """``from tensorflow import keras`` is reported under the top-level name."""
        code = """from tensorflow import keras
from mymodule import my_function
def test_something():
    pass
"""
        assert detect_frameworks_from_code(code) == {"tensorflow": "tensorflow"}

    def test_jax_import(self) -> None:
        """A plain ``import jax`` maps jax to its own name."""
        code = """import jax
from mymodule import my_function
def test_something():
    pass
"""
        assert detect_frameworks_from_code(code) == {"jax": "jax"}

    def test_jax_aliased_import(self) -> None:
        """``import jax as jnp`` maps jax to the alias."""
        code = """import jax as jnp
from mymodule import my_function
def test_something():
    pass
"""
        assert detect_frameworks_from_code(code) == {"jax": "jnp"}

    def test_jax_submodule_import(self) -> None:
        """``from jax import numpy as jnp`` is reported as jax, not as the member alias."""
        code = """from jax import numpy as jnp
from mymodule import my_function
def test_something():
    pass
"""
        assert detect_frameworks_from_code(code) == {"jax": "jax"}

    def test_multiple_frameworks(self) -> None:
        """All three frameworks are reported when all are imported."""
        code = """import torch
import tensorflow
import jax
from mymodule import my_function
def test_something():
    pass
"""
        assert detect_frameworks_from_code(code) == {
            "torch": "torch",
            "tensorflow": "tensorflow",
            "jax": "jax",
        }

    def test_multiple_frameworks_aliased(self) -> None:
        """Each framework keeps its own alias when several are aliased."""
        code = """import torch as th
import tensorflow as tf
import jax as jnp
from mymodule import my_function
def test_something():
    pass
"""
        assert detect_frameworks_from_code(code) == {
            "torch": "th",
            "tensorflow": "tf",
            "jax": "jnp",
        }

    def test_syntax_error_returns_empty(self) -> None:
        """Unparseable source yields an empty mapping instead of raising."""
        code = """this is not valid python code !!!"""
        assert detect_frameworks_from_code(code) == {}
# ============================================================================
# Tests for inject_profiling_into_existing_test - BEHAVIOR mode
# ============================================================================
class TestInjectProfilingBehaviorMode:
    """Tests for inject_profiling_into_existing_test in BEHAVIOR mode.

    Each test writes a tiny test module that calls ``my_function(1, 2)``,
    instruments it, and compares the normalized result with the matching
    ``EXPECTED_*_BEHAVIOR`` fixture.
    """

    @staticmethod
    def _instrument(tmp_path: Path, code: str, call_line: int) -> str:
        """Instrument *code* in BEHAVIOR mode and return the normalized output.

        Args:
            tmp_path: Directory (pytest's ``tmp_path``) that receives the test
                file and is passed as the tests project root.
            code: Source of the test module to instrument.
            call_line: Line of the ``my_function(1, 2)`` call inside *code*,
                as passed to ``CodePosition``.

        Returns:
            The instrumented source after ``normalize_instrumented_code``.
        """
        test_file = tmp_path / "test_example.py"
        test_file.write_text(code, encoding="utf-8")
        func = FunctionToOptimize(
            function_name="my_function",
            parents=[],
            file_path=Path("mymodule.py"),
        )
        success, instrumented_code = inject_profiling_into_existing_test(
            test_path=test_file,
            # Column 13 is where ``my_function`` starts in
            # ``    result = my_function(1, 2)`` (4-space indent).
            call_positions=[CodePosition(call_line, 13)],
            function_to_optimize=func,
            tests_project_root=tmp_path,
            mode=TestingMode.BEHAVIOR,
        )
        # Fail with a clear message here rather than with an opaque error
        # from the regex normalization when instrumentation did not succeed.
        assert success, "inject_profiling_into_existing_test reported failure"
        return normalize_instrumented_code(instrumented_code)

    def test_no_frameworks_behavior_mode(self, tmp_path: Path) -> None:
        """Test instrumentation with no GPU framework imports in BEHAVIOR mode."""
        code = """from mymodule import my_function

def test_my_function():
    result = my_function(1, 2)
    assert result == 3
"""
        assert self._instrument(tmp_path, code, 4) == EXPECTED_NO_FRAMEWORKS_BEHAVIOR

    def test_torch_import_behavior_mode(self, tmp_path: Path) -> None:
        """Test instrumentation with PyTorch import in BEHAVIOR mode."""
        code = """import torch
from mymodule import my_function

def test_my_function():
    result = my_function(1, 2)
    assert result == 3
"""
        assert self._instrument(tmp_path, code, 5) == EXPECTED_TORCH_BEHAVIOR

    def test_torch_aliased_import_behavior_mode(self, tmp_path: Path) -> None:
        """Test instrumentation with PyTorch imported as alias in BEHAVIOR mode."""
        code = """import torch as th
from mymodule import my_function

def test_my_function():
    result = my_function(1, 2)
    assert result == 3
"""
        assert self._instrument(tmp_path, code, 5) == EXPECTED_TORCH_ALIASED_BEHAVIOR

    def test_torch_submodule_import_behavior_mode(self, tmp_path: Path) -> None:
        """Test instrumentation with PyTorch submodule import in BEHAVIOR mode."""
        code = """from torch import nn
from mymodule import my_function

def test_my_function():
    result = my_function(1, 2)
    assert result == 3
"""
        assert self._instrument(tmp_path, code, 5) == EXPECTED_TORCH_SUBMODULE_BEHAVIOR

    def test_tensorflow_import_behavior_mode(self, tmp_path: Path) -> None:
        """Test instrumentation with TensorFlow import in BEHAVIOR mode."""
        code = """import tensorflow
from mymodule import my_function

def test_my_function():
    result = my_function(1, 2)
    assert result == 3
"""
        assert self._instrument(tmp_path, code, 5) == EXPECTED_TENSORFLOW_BEHAVIOR

    def test_tensorflow_aliased_import_behavior_mode(self, tmp_path: Path) -> None:
        """Test instrumentation with TensorFlow imported as alias in BEHAVIOR mode."""
        code = """import tensorflow as tf
from mymodule import my_function

def test_my_function():
    result = my_function(1, 2)
    assert result == 3
"""
        assert self._instrument(tmp_path, code, 5) == EXPECTED_TENSORFLOW_ALIASED_BEHAVIOR

    def test_jax_import_behavior_mode(self, tmp_path: Path) -> None:
        """Test instrumentation with JAX import in BEHAVIOR mode."""
        code = """import jax
from mymodule import my_function

def test_my_function():
    result = my_function(1, 2)
    assert result == 3
"""
        assert self._instrument(tmp_path, code, 5) == EXPECTED_JAX_BEHAVIOR

    def test_jax_aliased_import_behavior_mode(self, tmp_path: Path) -> None:
        """Test instrumentation with JAX imported as alias in BEHAVIOR mode."""
        code = """import jax as jnp
from mymodule import my_function

def test_my_function():
    result = my_function(1, 2)
    assert result == 3
"""
        assert self._instrument(tmp_path, code, 5) == EXPECTED_JAX_ALIASED_BEHAVIOR

    def test_torch_and_tensorflow_behavior_mode(self, tmp_path: Path) -> None:
        """Test instrumentation with both PyTorch and TensorFlow imports in BEHAVIOR mode."""
        code = """import torch
import tensorflow
from mymodule import my_function

def test_my_function():
    result = my_function(1, 2)
    assert result == 3
"""
        assert self._instrument(tmp_path, code, 6) == EXPECTED_TORCH_TENSORFLOW_BEHAVIOR

    def test_all_three_frameworks_behavior_mode(self, tmp_path: Path) -> None:
        """Test instrumentation with PyTorch, TensorFlow, and JAX imports in BEHAVIOR mode."""
        code = """import torch
import tensorflow
import jax
from mymodule import my_function

def test_my_function():
    result = my_function(1, 2)
    assert result == 3
"""
        assert self._instrument(tmp_path, code, 7) == EXPECTED_ALL_FRAMEWORKS_BEHAVIOR
# ============================================================================
# Tests for inject_profiling_into_existing_test - PERFORMANCE mode
# ============================================================================
class TestInjectProfilingPerformanceMode:
    """Tests for inject_profiling_into_existing_test in PERFORMANCE mode.

    Each test writes a tiny test module that calls ``my_function(1, 2)``,
    instruments it, and compares the normalized result with the matching
    ``EXPECTED_*_PERFORMANCE`` fixture.
    """

    @staticmethod
    def _instrument(tmp_path: Path, code: str, call_line: int) -> str:
        """Instrument *code* in PERFORMANCE mode and return the normalized output.

        Args:
            tmp_path: Directory (pytest's ``tmp_path``) that receives the test
                file and is passed as the tests project root.
            code: Source of the test module to instrument.
            call_line: Line of the ``my_function(1, 2)`` call inside *code*,
                as passed to ``CodePosition``.

        Returns:
            The instrumented source after ``normalize_instrumented_code``.
        """
        test_file = tmp_path / "test_example.py"
        test_file.write_text(code, encoding="utf-8")
        func = FunctionToOptimize(
            function_name="my_function",
            parents=[],
            file_path=Path("mymodule.py"),
        )
        success, instrumented_code = inject_profiling_into_existing_test(
            test_path=test_file,
            # Column 13 is where ``my_function`` starts in
            # ``    result = my_function(1, 2)`` (4-space indent).
            call_positions=[CodePosition(call_line, 13)],
            function_to_optimize=func,
            tests_project_root=tmp_path,
            mode=TestingMode.PERFORMANCE,
        )
        # Fail with a clear message here rather than with an opaque error
        # from the regex normalization when instrumentation did not succeed.
        assert success, "inject_profiling_into_existing_test reported failure"
        return normalize_instrumented_code(instrumented_code)

    def test_no_frameworks_performance_mode(self, tmp_path: Path) -> None:
        """Test instrumentation with no GPU framework imports in PERFORMANCE mode."""
        code = """from mymodule import my_function

def test_my_function():
    result = my_function(1, 2)
    assert result == 3
"""
        assert self._instrument(tmp_path, code, 4) == EXPECTED_NO_FRAMEWORKS_PERFORMANCE

    def test_torch_import_performance_mode(self, tmp_path: Path) -> None:
        """Test instrumentation with PyTorch import in PERFORMANCE mode."""
        code = """import torch
from mymodule import my_function

def test_my_function():
    result = my_function(1, 2)
    assert result == 3
"""
        assert self._instrument(tmp_path, code, 5) == EXPECTED_TORCH_PERFORMANCE

    def test_tensorflow_import_performance_mode(self, tmp_path: Path) -> None:
        """Test instrumentation with TensorFlow import in PERFORMANCE mode."""
        code = """import tensorflow
from mymodule import my_function

def test_my_function():
    result = my_function(1, 2)
    assert result == 3
"""
        assert self._instrument(tmp_path, code, 5) == EXPECTED_TENSORFLOW_PERFORMANCE

    def test_jax_import_performance_mode(self, tmp_path: Path) -> None:
        """Test instrumentation with JAX import in PERFORMANCE mode."""
        code = """import jax
from mymodule import my_function

def test_my_function():
    result = my_function(1, 2)
    assert result == 3
"""
        assert self._instrument(tmp_path, code, 5) == EXPECTED_JAX_PERFORMANCE

    def test_all_frameworks_performance_mode(self, tmp_path: Path) -> None:
        """Test instrumentation with PyTorch, TensorFlow, and JAX imports in PERFORMANCE mode."""
        code = """import torch
import tensorflow
import jax
from mymodule import my_function

def test_my_function():
    result = my_function(1, 2)
    assert result == 3
"""
        assert self._instrument(tmp_path, code, 7) == EXPECTED_ALL_FRAMEWORKS_PERFORMANCE