refactor: eliminate inline async decorator duplication and fix 10-column test gaps

Replace 218-line ASYNC_HELPER_INLINE_CODE string with shutil.copy2 of the
runtime decorator file. Update remaining test files for 10-column SQLite
schema (cpu_runtime). Add cpu_runtime assertions to async E2E tests.
This commit is contained in:
Kevin Turcios 2026-04-24 02:31:40 -05:00
parent eb6a0be717
commit 2fd9d06e28
6 changed files with 33 additions and 254 deletions

View file

@ -2,14 +2,16 @@
Provides ``AsyncCallInstrumenter`` for injecting ``CODEFLASH_CURRENT_LINE_ID`` Provides ``AsyncCallInstrumenter`` for injecting ``CODEFLASH_CURRENT_LINE_ID``
assignments before ``await`` calls, ``AsyncDecoratorAdder`` for adding assignments before ``await`` calls, ``AsyncDecoratorAdder`` for adding
async performance/behavior decorators via libcst, the inline async helper async performance/behavior decorators via libcst, and high-level functions
code, and high-level functions for instrumenting async test and source files. for instrumenting async test and source files.
""" """
from __future__ import annotations from __future__ import annotations
import ast import ast
import logging import logging
import shutil
from pathlib import Path
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import libcst as cst import libcst as cst
@ -26,8 +28,6 @@ from ._instrument_core import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from pathlib import Path
from ..test_discovery.models import CodePosition from ..test_discovery.models import CodePosition
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -126,25 +126,6 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
node.body = new_body node.body = new_body
return node return node
def _instrument_statement(
self, stmt: ast.stmt, _node_name: str
) -> tuple[ast.stmt, bool]:
"""Check whether a statement contains an awaited target call."""
for node in ast.walk(stmt):
if (
isinstance(node, ast.Await)
and isinstance(node.value, ast.Call)
and self._is_target_call(node.value)
and self._call_in_positions(node.value)
):
# Check if this call is in one of our target positions
return (
stmt,
True,
) # Return original statement but signal we added env var
return stmt, False
def _is_target_call(self, call_node: ast.Call) -> bool: def _is_target_call(self, call_node: ast.Call) -> bool:
"""Check if this call node is calling our target async function.""" """Check if this call node is calling our target async function."""
if isinstance(call_node.func, ast.Name): if isinstance(call_node.func, ast.Name):
@ -303,228 +284,14 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
return False return False
ASYNC_HELPER_INLINE_CODE = """import asyncio
import gc
import os
import sqlite3
import time
from functools import wraps
from pathlib import Path
from tempfile import TemporaryDirectory
import dill as pickle
def get_run_tmp_file(file_path):
if not hasattr(get_run_tmp_file, "tmpdir"):
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
return Path(get_run_tmp_file.tmpdir.name) / file_path
def extract_test_context_from_env():
test_module = os.environ["CODEFLASH_TEST_MODULE"]
test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
if test_module and test_function:
return (test_module, test_class if test_class else None, test_function)
raise RuntimeError(
"Test context environment variables not set"
" - ensure tests are run through"
" codeflash test runner"
)
def codeflash_behavior_async(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
loop = asyncio.get_running_loop()
function_name = func.__name__
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
(test_module_name, test_class_name,
test_name) = extract_test_context_from_env()
test_id = (
f"{test_module_name}:{test_class_name}"
f":{test_name}:{line_id}:{loop_index}"
)
if not hasattr(async_wrapper, "index"):
async_wrapper.index = {}
if test_id in async_wrapper.index:
async_wrapper.index[test_id] += 1
else:
async_wrapper.index[test_id] = 0
codeflash_test_index = async_wrapper.index[test_id]
invocation_id = f"{line_id}_{codeflash_test_index}"
class_prefix = (
(test_class_name + ".") if test_class_name else ""
)
test_stdout_tag = (
f"{test_module_name}:{class_prefix}"
f"{test_name}:{function_name}"
f":{loop_index}:{invocation_id}"
)
print(f"!$######{test_stdout_tag}######$!")
iteration = os.environ.get(
"CODEFLASH_TEST_ITERATION", "0"
)
db_path = get_run_tmp_file(
Path(f"test_return_values_{iteration}.sqlite")
)
codeflash_con = sqlite3.connect(db_path)
codeflash_cur = codeflash_con.cursor()
codeflash_cur.execute(
"CREATE TABLE IF NOT EXISTS test_results"
" (test_module_path TEXT,"
" test_class_name TEXT,"
" test_function_name TEXT,"
" function_getting_tested TEXT,"
" loop_index INTEGER,"
" iteration_id TEXT,"
" runtime INTEGER,"
" return_value BLOB,"
" verification_type TEXT,"
" cpu_runtime INTEGER)"
)
exception = None
counter = loop.time()
gc.disable()
try:
ret = func(*args, **kwargs)
counter = loop.time()
return_value = await ret
codeflash_duration = int(
(loop.time() - counter) * 1_000_000_000
)
except Exception as e:
codeflash_duration = int(
(loop.time() - counter) * 1_000_000_000
)
exception = e
finally:
gc.enable()
print(f"!######{test_stdout_tag}######!")
pickled_return_value = (
pickle.dumps(exception) if exception
else pickle.dumps(
(args, kwargs, return_value)
)
)
codeflash_cur.execute(
"INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
test_module_name,
test_class_name,
test_name,
function_name,
loop_index,
invocation_id,
codeflash_duration,
pickled_return_value,
"function_call",
0,
),
)
codeflash_con.commit()
codeflash_con.close()
if exception:
raise exception
return return_value
return async_wrapper
def codeflash_performance_async(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
loop = asyncio.get_running_loop()
function_name = func.__name__
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
(test_module_name, test_class_name,
test_name) = extract_test_context_from_env()
test_id = (
f"{test_module_name}:{test_class_name}"
f":{test_name}:{line_id}:{loop_index}"
)
if not hasattr(async_wrapper, "index"):
async_wrapper.index = {}
if test_id in async_wrapper.index:
async_wrapper.index[test_id] += 1
else:
async_wrapper.index[test_id] = 0
codeflash_test_index = async_wrapper.index[test_id]
invocation_id = f"{line_id}_{codeflash_test_index}"
class_prefix = (
(test_class_name + ".") if test_class_name else ""
)
test_stdout_tag = (
f"{test_module_name}:{class_prefix}"
f"{test_name}:{function_name}"
f":{loop_index}:{invocation_id}"
)
print(f"!$######{test_stdout_tag}######$!")
exception = None
counter = loop.time()
gc.disable()
try:
ret = func(*args, **kwargs)
counter = loop.time()
return_value = await ret
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
except Exception as e:
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
exception = e
finally:
gc.enable()
print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
if exception:
raise exception
return return_value
return async_wrapper
def codeflash_concurrency_async(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
function_name = func.__name__
concurrency_factor = int(os.environ.get(
"CODEFLASH_CONCURRENCY_FACTOR", "10"
))
test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "")
test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "")
loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0")
gc.disable()
try:
seq_start = time.perf_counter_ns()
for _ in range(concurrency_factor):
result = await func(*args, **kwargs)
sequential_time = time.perf_counter_ns() - seq_start
finally:
gc.enable()
gc.disable()
try:
conc_start = time.perf_counter_ns()
tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
await asyncio.gather(*tasks)
concurrent_time = time.perf_counter_ns() - conc_start
finally:
gc.enable()
tag = (
f"{test_module_name}:{test_class_name}"
f":{test_function}:{function_name}"
f":{loop_index}"
)
print(
f"!@######CONC:{tag}"
f":{sequential_time}:{concurrent_time}"
f":{concurrency_factor}######@!"
)
return result
return async_wrapper
"""
ASYNC_HELPER_FILENAME = "codeflash_async_wrapper.py" ASYNC_HELPER_FILENAME = "codeflash_async_wrapper.py"
_RUNTIME_DECORATOR_PATH = (
Path(__file__).resolve().parent.parent
/ "runtime"
/ "_codeflash_wrap_decorator.py"
)
def get_decorator_name_for_mode( def get_decorator_name_for_mode(
mode: TestingMode, mode: TestingMode,
@ -540,10 +307,10 @@ def get_decorator_name_for_mode(
def write_async_helper_file( def write_async_helper_file(
target_dir: Path, target_dir: Path,
) -> Path: ) -> Path:
"""Write the async decorator helper file to the target directory.""" """Copy the runtime async decorator module to the target directory."""
helper_path = target_dir / ASYNC_HELPER_FILENAME helper_path = target_dir / ASYNC_HELPER_FILENAME
if not helper_path.exists(): if not helper_path.exists():
helper_path.write_text(ASYNC_HELPER_INLINE_CODE, "utf-8") shutil.copy2(_RUNTIME_DECORATOR_PATH, helper_path)
return helper_path return helper_path

View file

@ -28,9 +28,6 @@ from ..runtime._codeflash_wrap_decorator import (
from ._instrument_async import ( from ._instrument_async import (
ASYNC_HELPER_FILENAME as ASYNC_HELPER_FILENAME, # noqa: PLC0414 ASYNC_HELPER_FILENAME as ASYNC_HELPER_FILENAME, # noqa: PLC0414
) )
from ._instrument_async import (
ASYNC_HELPER_INLINE_CODE as ASYNC_HELPER_INLINE_CODE, # noqa: PLC0414
)
from ._instrument_async import ( from ._instrument_async import (
AsyncCallInstrumenter as AsyncCallInstrumenter, # noqa: PLC0414 AsyncCallInstrumenter as AsyncCallInstrumenter, # noqa: PLC0414
) )

View file

@ -155,6 +155,7 @@ async def test_async_sort():
assert results_list[0].id.test_function_name == "test_async_sort" assert results_list[0].id.test_function_name == "test_async_sort"
assert results_list[0].did_pass assert results_list[0].did_pass
assert results_list[0].runtime is None or results_list[0].runtime >= 0 assert results_list[0].runtime is None or results_list[0].runtime >= 0
assert isinstance(results_list[0].cpu_runtime, int)
expected_stdout = "codeflash stdout: Async sorting list\nresult: [0, 1, 2, 3, 4, 5]\n" expected_stdout = "codeflash stdout: Async sorting list\nresult: [0, 1, 2, 3, 4, 5]\n"
assert expected_stdout == results_list[0].stdout assert expected_stdout == results_list[0].stdout
@ -162,6 +163,7 @@ async def test_async_sort():
assert results_list[1].id.function_getting_tested == "async_sorter" assert results_list[1].id.function_getting_tested == "async_sorter"
assert results_list[1].id.test_function_name == "test_async_sort" assert results_list[1].id.test_function_name == "test_async_sort"
assert results_list[1].did_pass assert results_list[1].did_pass
assert isinstance(results_list[1].cpu_runtime, int)
expected_stdout2 = "codeflash stdout: Async sorting list\nresult: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]\n" expected_stdout2 = "codeflash stdout: Async sorting list\nresult: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]\n"
assert expected_stdout2 == results_list[1].stdout assert expected_stdout2 == results_list[1].stdout
@ -300,6 +302,7 @@ async def test_async_class_sort():
assert sorter_result.id.test_function_name == "test_async_class_sort" assert sorter_result.id.test_function_name == "test_async_class_sort"
assert sorter_result.did_pass assert sorter_result.did_pass
assert sorter_result.runtime is None or sorter_result.runtime >= 0 assert sorter_result.runtime is None or sorter_result.runtime >= 0
assert isinstance(sorter_result.cpu_runtime, int)
expected_stdout = ( expected_stdout = (
"codeflash stdout: AsyncBubbleSorter.sorter() called\n" "codeflash stdout: AsyncBubbleSorter.sorter() called\n"
@ -308,6 +311,7 @@ async def test_async_class_sort():
assert ".__init__" in init_result.id.function_getting_tested assert ".__init__" in init_result.id.function_getting_tested
assert init_result.did_pass assert init_result.did_pass
assert isinstance(init_result.cpu_runtime, int)
finally: finally:
fto_path.write_text(original_code, "utf-8") fto_path.write_text(original_code, "utf-8")

View file

@ -98,6 +98,7 @@ class TestAsyncWrapperSQLiteValidation:
runtime, runtime,
return_value_blob, return_value_blob,
verification_type, verification_type,
cpu_runtime,
) = row ) = row
assert test_module_path == __name__ assert test_module_path == __name__
@ -111,6 +112,7 @@ class TestAsyncWrapperSQLiteValidation:
) and iteration_id.endswith("_0") ) and iteration_id.endswith("_0")
assert runtime > 0 assert runtime > 0
assert verification_type == VerificationType.FUNCTION_CALL.value assert verification_type == VerificationType.FUNCTION_CALL.value
assert isinstance(cpu_runtime, int)
unpickled_data = pickle.loads(return_value_blob) unpickled_data = pickle.loads(return_value_blob)
args, kwargs, return_val = unpickled_data args, kwargs, return_val = unpickled_data
@ -324,6 +326,7 @@ class TestAsyncWrapperSQLiteValidation:
(6, "runtime", "INTEGER", 0, None, 0), (6, "runtime", "INTEGER", 0, None, 0),
(7, "return_value", "BLOB", 0, None, 0), (7, "return_value", "BLOB", 0, None, 0),
(8, "verification_type", "TEXT", 0, None, 0), (8, "verification_type", "TEXT", 0, None, 0),
(9, "cpu_runtime", "INTEGER", 0, None, 0),
] ]
assert columns == expected_columns assert columns == expected_columns

View file

@ -845,6 +845,7 @@ def test_async_behavior_decorator_return_values_and_test_ids():
runtime, runtime,
return_value_blob, return_value_blob,
verification_type, verification_type,
cpu_runtime,
) = row ) = row
assert test_module == "test_module", ( assert test_module == "test_module", (

View file

@ -17,7 +17,6 @@ from codeflash_python._model import (
from codeflash_python.test_discovery.models import CodePosition from codeflash_python.test_discovery.models import CodePosition
from codeflash_python.testing._instrumentation import ( from codeflash_python.testing._instrumentation import (
ASYNC_HELPER_FILENAME, ASYNC_HELPER_FILENAME,
ASYNC_HELPER_INLINE_CODE,
AsyncCallInstrumenter, AsyncCallInstrumenter,
AsyncDecoratorAdder, AsyncDecoratorAdder,
FunctionCallNodeArguments, FunctionCallNodeArguments,
@ -587,16 +586,24 @@ class TestWriteAsyncHelperFile:
class TestAsyncHelperConstants: class TestAsyncHelperConstants:
"""ASYNC_HELPER_FILENAME and ASYNC_HELPER_INLINE_CODE constants.""" """ASYNC_HELPER_FILENAME and runtime decorator source."""
def test_filename_value(self) -> None: def test_filename_value(self) -> None:
"""ASYNC_HELPER_FILENAME has the expected value.""" """ASYNC_HELPER_FILENAME has the expected value."""
assert "codeflash_async_wrapper.py" == ASYNC_HELPER_FILENAME assert "codeflash_async_wrapper.py" == ASYNC_HELPER_FILENAME
def test_inline_code_nonempty(self) -> None: def test_runtime_decorator_is_self_contained(self) -> None:
"""ASYNC_HELPER_INLINE_CODE is a non-empty string.""" """Runtime decorator file has no internal codeflash imports."""
assert isinstance(ASYNC_HELPER_INLINE_CODE, str) from codeflash_python.testing._instrument_async import (
assert len(ASYNC_HELPER_INLINE_CODE) > 0 _RUNTIME_DECORATOR_PATH,
)
source = _RUNTIME_DECORATOR_PATH.read_text("utf-8")
assert _RUNTIME_DECORATOR_PATH.exists()
for line in source.splitlines():
stripped = line.strip()
if stripped.startswith(("import ", "from ")):
assert "codeflash_python" not in stripped
class TestSortImports: class TestSortImports: