feat: rewrite async instrumentation to use SQLite-only data path and contextvars

Replace the fragile stdout tag protocol with a unified SQLite table
(async_results) for all 3 async test modes. The new runtime decorators
write behavior, performance, and concurrency results directly to the DB
with zero stdout output. Test-file instrumentation now injects
_codeflash_call_site.set() (contextvar) instead of os.environ assignments,
which is correct for async task isolation.

New modules:
- runtime/_codeflash_async_decorators.py: self-contained decorators
- testing/_async_data_parser.py: SQLite reader replacing stdout parsing

Both at 100% test coverage (42 new tests).
This commit is contained in:
Kevin Turcios 2026-04-24 03:44:06 -05:00
parent 24199efc63
commit 629d7f9f08
6 changed files with 1727 additions and 31 deletions

View file

@ -0,0 +1,359 @@
"""
Self-contained async instrumentation decorators for codeflash testing.
This module is copied to test project directories at runtime, so it
must have zero imports from the codeflash_python package.
"""
# ruff: noqa: BLE001
from __future__ import annotations
import asyncio
import atexit
import contextvars
import gc
import os
import sqlite3
import time
from enum import Enum
from functools import wraps
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Callable, TypeVar
import dill as pickle
class VerificationType(str, Enum):
    """Categorizes how a captured test datum should be verified."""

    # Direct invocation of the function under test.
    FUNCTION_CALL = "function_call"
    # Snapshot of __init__ state for the function-to-optimize.
    INIT_STATE_FTO = "init_state_fto"
    # Snapshot of __init__ state for a helper function.
    INIT_STATE_HELPER = "init_state_helper"


# Generic type variable preserving decorated-callable signatures.
F = TypeVar("F", bound=Callable[..., Any])
def get_run_tmp_file(file_path: Path) -> Path:
    """Resolve *file_path* inside a per-process temporary directory.

    The directory is created lazily on first use and cached as a function
    attribute, so every call within one run shares the same location.
    """
    tmpdir = getattr(get_run_tmp_file, "tmpdir", None)
    if tmpdir is None:
        tmpdir = TemporaryDirectory(prefix="codeflash_")
        get_run_tmp_file.tmpdir = tmpdir  # type: ignore[attr-defined]
    return Path(tmpdir.name) / file_path
def extract_test_context_from_env() -> tuple[str, str | None, str]:
"""
Read test module, class, and function names from env vars.
Raises RuntimeError when required variables are missing.
"""
test_module = os.environ["CODEFLASH_TEST_MODULE"]
test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
if test_module and test_function:
return (
test_module,
test_class or None,
test_function,
)
raise RuntimeError( # noqa: TRY003
"Test context environment variables not set" # noqa: EM101
)
# Identifier of the instrumented call site, set by generated test code via
# _codeflash_call_site.set(); a contextvar (rather than os.environ) so each
# asyncio task sees its own value.
_codeflash_call_site: contextvars.ContextVar[str] = contextvars.ContextVar(
    "codeflash_call_site", default=""
)
# Single schema shared by the "behavior", "performance", and "concurrency"
# modes; columns a mode does not use are stored as NULL.
_CREATE_TABLE_SQL = (
    "CREATE TABLE IF NOT EXISTS async_results ("
    "test_module_path TEXT NOT NULL, "
    "test_class_name TEXT, "
    "test_function_name TEXT NOT NULL, "
    "function_getting_tested TEXT NOT NULL, "
    "loop_index INTEGER NOT NULL, "
    "invocation_id TEXT NOT NULL, "
    "mode TEXT NOT NULL, "
    "wall_time_ns INTEGER NOT NULL, "
    "return_value BLOB, "
    "verification_type TEXT, "
    "sequential_time_ns INTEGER, "
    "concurrent_time_ns INTEGER, "
    "concurrency_factor INTEGER"
    ")"
)
# Open-connection cache keyed by str(db_path); populated by _get_async_db
# and flushed/closed by the atexit hook below.
_connections: dict[str, sqlite3.Connection] = {}
def _get_async_db(
    db_path: Path,
) -> tuple[sqlite3.Connection, sqlite3.Cursor]:
    """Return a (connection, fresh cursor) pair, reusing cached connections.

    The first request for a given path connects and ensures the
    async_results table exists; later requests reuse that connection.
    """
    key = str(db_path)
    try:
        conn = _connections[key]
    except KeyError:
        conn = sqlite3.connect(db_path)
        conn.execute(_CREATE_TABLE_SQL)
        _connections[key] = conn
    return conn, conn.cursor()
def _close_all_connections() -> None:
    """Commit and close every cached connection, ignoring close errors.

    Errors are swallowed deliberately: a half-closed connection at exit
    must not mask the test outcome.
    """
    while _connections:
        _, conn = _connections.popitem()
        try:
            conn.commit()
            conn.close()
        except Exception:  # noqa: PERF203, S110
            pass


# Flush buffered results to disk even when the test process exits abruptly.
atexit.register(_close_all_connections)
def codeflash_behavior_async(func: F) -> F:
    """Capture async return values (or exceptions) and timing for behavior tests.

    Each call of the wrapped coroutine function writes one mode="behavior"
    row to the async_results SQLite table: the pickled (args, kwargs,
    return_value) triple on success, or the pickled exception on failure.
    Results and exceptions are propagated to the caller unchanged.
    """

    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        loop = asyncio.get_running_loop()
        function_name = func.__name__
        call_site = _codeflash_call_site.get()
        loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
        (
            test_module_name,
            test_class_name,
            test_name,
        ) = extract_test_context_from_env()
        # Key identifying "this call site within this test run"; used to
        # number repeated invocations from the same site.
        test_id = (
            f"{test_module_name}:{test_class_name}"
            f":{test_name}:{call_site}:{loop_index}"
        )
        if not hasattr(wrapper, "index"):
            wrapper.index = {}  # type: ignore[attr-defined]
        if test_id in wrapper.index:  # type: ignore[attr-defined]
            wrapper.index[test_id] += 1  # type: ignore[attr-defined]
        else:
            wrapper.index[test_id] = 0  # type: ignore[attr-defined]
        call_index = wrapper.index[test_id]  # type: ignore[attr-defined]
        invocation_id = f"{call_site}_{call_index}"
        iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
        db_path = get_run_tmp_file(Path(f"async_results_{iteration}.sqlite"))
        conn, cur = _get_async_db(db_path)
        exception: Exception | None = None
        # Fallback start time in case func(...) raises before the coroutine
        # is ever awaited (synchronous failure building the coroutine).
        counter = loop.time()
        gc.disable()  # keep collector pauses out of the timed window
        try:
            ret = func(*args, **kwargs)
            counter = loop.time()  # restart: time only the awaited execution
            return_value = await ret
            wall_time = int((loop.time() - counter) * 1_000_000_000)
        except Exception as e:
            wall_time = int((loop.time() - counter) * 1_000_000_000)
            exception = e
        finally:
            gc.enable()
        # Compare with `is not None`, not truthiness: exception classes can
        # override __bool__, and a falsy exception must still be recorded
        # and re-raised.
        pickled = (
            pickle.dumps(exception)
            if exception is not None
            else pickle.dumps((args, kwargs, return_value))
        )
        cur.execute(
            "INSERT INTO async_results VALUES "
            "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                test_module_name,
                test_class_name,
                test_name,
                function_name,
                loop_index,
                invocation_id,
                "behavior",
                wall_time,
                pickled,
                VerificationType.FUNCTION_CALL.value,
                None,
                None,
                None,
            ),
        )
        conn.commit()
        if exception is not None:
            raise exception
        return return_value

    return wrapper  # type: ignore[return-value]
def codeflash_performance_async(func: F) -> F:
    """Measure async wall time for performance tests.

    Writes one mode="performance" row per invocation to the async_results
    SQLite table; no return value is stored.  Results and exceptions are
    propagated to the caller unchanged.
    """

    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        loop = asyncio.get_running_loop()
        call_site = _codeflash_call_site.get()
        loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
        (
            test_module_name,
            test_class_name,
            test_name,
        ) = extract_test_context_from_env()
        # Key identifying this call site within this test run; numbers
        # repeated invocations from the same site.
        test_id = (
            f"{test_module_name}:{test_class_name}"
            f":{test_name}:{call_site}:{loop_index}"
        )
        if not hasattr(wrapper, "index"):
            wrapper.index = {}  # type: ignore[attr-defined]
        if test_id in wrapper.index:  # type: ignore[attr-defined]
            wrapper.index[test_id] += 1  # type: ignore[attr-defined]
        else:
            wrapper.index[test_id] = 0  # type: ignore[attr-defined]
        call_index = wrapper.index[test_id]  # type: ignore[attr-defined]
        invocation_id = f"{call_site}_{call_index}"
        iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
        db_path = get_run_tmp_file(Path(f"async_results_{iteration}.sqlite"))
        conn, cur = _get_async_db(db_path)
        exception: Exception | None = None
        # Fallback start time in case func(...) raises synchronously.
        counter = loop.time()
        gc.disable()  # keep collector pauses out of the timed window
        try:
            ret = func(*args, **kwargs)
            counter = loop.time()  # restart: time only the awaited execution
            return_value = await ret
            wall_time = int((loop.time() - counter) * 1_000_000_000)
        except Exception as e:
            wall_time = int((loop.time() - counter) * 1_000_000_000)
            exception = e
        finally:
            gc.enable()
        cur.execute(
            "INSERT INTO async_results VALUES "
            "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                test_module_name,
                test_class_name,
                test_name,
                func.__name__,
                loop_index,
                invocation_id,
                "performance",
                wall_time,
                None,
                None,
                None,
                None,
                None,
            ),
        )
        conn.commit()
        # `is not None`, not truthiness: exception classes overriding
        # __bool__ must still be re-raised.
        if exception is not None:
            raise exception
        return return_value

    return wrapper  # type: ignore[return-value]
def codeflash_concurrency_async(func: F) -> F:
    """Measure sequential vs. concurrent execution of an async function.

    Runs the wrapped coroutine CODEFLASH_CONCURRENCY_FACTOR times serially,
    then the same number of times via asyncio.gather, and writes a single
    mode="concurrency" row with both aggregate timings.  Returns the result
    of the last sequential call (None when the factor is zero).
    """

    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        function_name = func.__name__
        concurrency_factor = int(
            os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10")
        )
        (
            test_module_name,
            test_class_name,
            test_name,
        ) = extract_test_context_from_env()
        loop_index = int(os.environ.get("CODEFLASH_LOOP_INDEX", "0"))
        iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
        db_path = get_run_tmp_file(Path(f"async_results_{iteration}.sqlite"))
        conn, cur = _get_async_db(db_path)
        # Fix: pre-bind so a zero (or negative) concurrency factor cannot
        # leave `result` unbound at the return statement below.
        result = None
        gc.disable()  # keep collector pauses out of the timed windows
        try:
            seq_start = time.perf_counter_ns()
            for _ in range(concurrency_factor):
                result = await func(*args, **kwargs)
            sequential_time = time.perf_counter_ns() - seq_start
        finally:
            gc.enable()
        gc.disable()
        try:
            conc_start = time.perf_counter_ns()
            tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
            await asyncio.gather(*tasks)
            concurrent_time = time.perf_counter_ns() - conc_start
        finally:
            gc.enable()
        cur.execute(
            "INSERT INTO async_results VALUES "
            "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                test_module_name,
                test_class_name,
                test_name,
                function_name,
                loop_index,
                "",
                "concurrency",
                0,
                None,
                None,
                sequential_time,
                concurrent_time,
                concurrency_factor,
            ),
        )
        conn.commit()
        return result

    return wrapper  # type: ignore[return-value]
# Public surface of this module; the underscore-prefixed names are exported
# deliberately because instrumented test files import them.
__all__ = [
    "VerificationType",
    "_codeflash_call_site",
    "codeflash_behavior_async",
    "codeflash_concurrency_async",
    "codeflash_performance_async",
    "extract_test_context_from_env",
    "get_run_tmp_file",
]

View file

@ -0,0 +1,257 @@
"""Async test result parsing from SQLite databases."""
from __future__ import annotations
import logging
import sqlite3
from typing import TYPE_CHECKING
from .._model import VerificationType
from ..test_discovery.models import TestType
from ._path_resolution import file_path_from_module_name
from .models import FunctionTestInvocation, InvocationId, TestResults
if TYPE_CHECKING:
from pathlib import Path
from ..benchmarking.models import ConcurrencyMetrics
from .models import TestConfig, TestFiles
# Module-level logger for this parser.
log = logging.getLogger(__name__)
# Column order here fixes the val[i] indices used by
# _process_behavior_row_inner — keep the two in sync.
_BEHAVIOR_QUERY = (
    "SELECT test_module_path, test_class_name,"
    " test_function_name, function_getting_tested,"
    " loop_index, invocation_id, wall_time_ns,"
    " return_value, verification_type"
    " FROM async_results"
    " WHERE mode = 'behavior'"
)
# One performance row is written per completed invocation, so COUNT(*)
# doubles as a throughput measure.
_THROUGHPUT_QUERY = (
    "SELECT COUNT(*) FROM async_results"
    " WHERE function_getting_tested = ?"
    " AND mode = 'performance'"
)
# Timing columns for concurrency-mode rows; caller averages over all rows.
_CONCURRENCY_QUERY = (
    "SELECT sequential_time_ns, concurrent_time_ns,"
    " concurrency_factor"
    " FROM async_results"
    " WHERE function_getting_tested = ?"
    " AND mode = 'concurrency'"
)
def parse_async_behavior_results(
    sqlite_file_path: Path,
    test_files: TestFiles,
    test_config: TestConfig,
) -> TestResults:
    """Parse behavior-mode rows from the async results database.

    Returns an empty TestResults when the file is missing or unreadable.
    Individual malformed rows are skipped by _process_behavior_row rather
    than aborting the whole parse.
    """
    test_results = TestResults()
    if not sqlite_file_path.exists():
        log.warning(
            "No async test results for %s found.",
            sqlite_file_path,
        )
        return test_results
    db: sqlite3.Connection | None = None
    try:
        db = sqlite3.connect(sqlite_file_path)
        cur = db.cursor()
        data = cur.execute(_BEHAVIOR_QUERY).fetchall()
    except (sqlite3.Error, OSError):
        log.warning(
            "Failed to parse async test results from %s.",
            sqlite_file_path,
            exc_info=True,
        )
        return test_results
    finally:
        # Single cleanup point for both the success and error paths; the
        # except branch no longer needs (and had a redundant) close() call.
        if db is not None:
            db.close()
    for row in data:
        _process_behavior_row(row, test_files, test_config, test_results)
    return test_results
def _process_behavior_row(
    val: tuple[object, ...],
    test_files: TestFiles,
    test_config: TestConfig,
    test_results: TestResults,
) -> None:
    """Process a single behavior row from the async results table.

    Any failure is logged and swallowed so that one malformed row cannot
    abort parsing of the remaining results.
    """
    try:
        _process_behavior_row_inner(val, test_files, test_config, test_results)
    except Exception:
        log.exception("Failed to parse async behavior result")
def _process_behavior_row_inner(
    val: tuple[object, ...],
    test_files: TestFiles,
    test_config: TestConfig,
    test_results: TestResults,
) -> None:
    """Inner processing for a single async behavior row.

    Column order of ``val`` follows ``_BEHAVIOR_QUERY``: (module path,
    class name, function name, function under test, loop index,
    invocation id, wall time ns, pickled return value, verification type).
    Rows whose test type cannot be resolved or whose return value fails to
    deserialize are skipped (logged at debug level).
    """
    test_module_path = val[0]
    # Empty strings from the DB are normalized to None.
    test_class_name = val[1] or None
    test_function_name = val[2] or None
    function_getting_tested = val[3]
    loop_index = val[4]
    invocation_id = val[5]
    wall_time_ns = val[6]
    verification_type = val[8]
    # Map the dotted test module name back to a file under the tests root.
    test_file_path = file_path_from_module_name(
        test_module_path,  # type: ignore[arg-type]
        test_config.tests_project_rootdir,
    )
    # NOTE(review): verification_type is a raw string from SQLite; this
    # membership test assumes VerificationType is a str-valued enum so the
    # string compares equal to the members — confirm against .._model.
    if verification_type in {
        VerificationType.INIT_STATE_FTO,
        VerificationType.INIT_STATE_HELPER,
    }:
        test_type: TestType = TestType.INIT_STATE_TEST
    else:
        # Prefer the original test-file path; fall back to the instrumented
        # copy's path before giving up on the row.
        found = test_files.get_test_type_by_original_file_path(
            test_file_path,
        )
        if found is None:
            found = test_files.get_test_type_by_instrumented_file_path(
                test_file_path,
            )
        if found is None:
            log.debug(
                "Skipping async result for %s: could not determine test type",
                test_function_name,
            )
            return
        test_type = found
    ret_val = None
    # Only the first loop iteration's payload is deserialized; later loops
    # repeat the call for timing, not behavior comparison.
    if loop_index == 1 and val[7]:
        import dill as pickle  # noqa: PLC0415
        try:
            ret_val = (pickle.loads(val[7]),)  # noqa: S301
        except Exception:  # noqa: BLE001
            log.debug(
                "Failed to deserialize return value for %s",
                test_function_name,
                exc_info=True,
            )
            # Drop the whole row rather than record a bogus invocation.
            return
    test_results.add(
        FunctionTestInvocation(
            loop_index=loop_index,  # type: ignore[arg-type]
            id=InvocationId(
                test_module_path=test_module_path,  # type: ignore[arg-type]
                test_class_name=test_class_name,  # type: ignore[arg-type]
                test_function_name=test_function_name,  # type: ignore[arg-type]
                function_getting_tested=function_getting_tested,  # type: ignore[arg-type]
                iteration_id=invocation_id,  # type: ignore[arg-type]
            ),
            file_name=test_file_path,
            did_pass=True,
            runtime=wall_time_ns,  # type: ignore[arg-type]
            test_framework=test_config.test_framework,
            test_type=test_type,
            return_value=ret_val,
            cpu_runtime=0,
            timed_out=False,
            verification_type=(
                VerificationType(verification_type)
                if verification_type
                else None
            ),
        ),
    )
def calculate_async_throughput(
    sqlite_file_path: Path,
    function_name: str,
) -> int:
    """Count completed async executions of *function_name*.

    Returns 0 when the database is missing, unreadable, or has no matching
    performance-mode rows.
    """
    if not sqlite_file_path.exists():
        return 0
    db: sqlite3.Connection | None = None
    try:
        db = sqlite3.connect(sqlite_file_path)
        row = db.cursor().execute(
            _THROUGHPUT_QUERY, (function_name,)
        ).fetchone()
        return int(row[0]) if row else 0
    except (sqlite3.Error, OSError):
        log.warning(
            "Failed to read async throughput from %s.",
            sqlite_file_path,
            exc_info=True,
        )
        return 0
    finally:
        if db is not None:
            db.close()
def parse_async_concurrency_metrics(
    sqlite_file_path: Path,
    function_name: str,
) -> ConcurrencyMetrics | None:
    """Parse concurrency benchmark metrics for *function_name*.

    Averages sequential and concurrent timings across all matching rows.
    Returns None when the database is missing, unreadable, or contains no
    concurrency-mode rows for the function.
    """
    if not sqlite_file_path.exists():
        return None
    db: sqlite3.Connection | None = None
    try:
        db = sqlite3.connect(sqlite_file_path)
        cur = db.cursor()
        rows = cur.execute(_CONCURRENCY_QUERY, (function_name,)).fetchall()
    except (sqlite3.Error, OSError):
        log.warning(
            "Failed to read async concurrency metrics from %s.",
            sqlite_file_path,
            exc_info=True,
        )
        return None
    finally:
        # Single cleanup point for both paths; the except branch no longer
        # needs (and had a redundant) close() call.
        if db is not None:
            db.close()
    if not rows:
        return None
    count = len(rows)
    avg_seq = sum(row[0] for row in rows) / count
    avg_conc = sum(row[1] for row in rows) / count
    # All rows for one function share the concurrency factor; use the first.
    factor = rows[0][2]
    # Guard divide-by-zero on degenerate timings; 1.0 means "no speedup".
    ratio = 1.0 if avg_conc == 0 else avg_seq / avg_conc
    from ..benchmarking.models import (  # noqa: PLC0415
        ConcurrencyMetrics,
    )
    return ConcurrencyMetrics(
        sequential_time_ns=int(avg_seq),
        concurrent_time_ns=int(avg_conc),
        concurrency_factor=factor,
        concurrency_ratio=ratio,
    )

View file

@ -1,9 +1,9 @@
"""Async-specific instrumentation: AST transformers, decorators, and helpers.
Provides ``AsyncCallInstrumenter`` for injecting ``CODEFLASH_CURRENT_LINE_ID``
assignments before ``await`` calls, ``AsyncDecoratorAdder`` for adding
async performance/behavior decorators via libcst, and high-level functions
for instrumenting async test and source files.
Provides ``AsyncCallInstrumenter`` for injecting ``_codeflash_call_site``
contextvar assignments before ``await`` calls, ``AsyncDecoratorAdder`` for
adding async performance/behavior decorators via libcst, and high-level
functions for instrumenting async test and source files.
"""
from __future__ import annotations
@ -71,8 +71,7 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
self,
node: ast.AsyncFunctionDef | ast.FunctionDef,
) -> ast.AsyncFunctionDef | ast.FunctionDef:
"""Add CODEFLASH_CURRENT_LINE_ID assignments before target await calls."""
# Initialize counter for this test function
"""Add _codeflash_call_site.set() calls before target await calls."""
if node.name not in self.async_call_counter:
self.async_call_counter[node.name] = 0
@ -85,24 +84,26 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
current_call_index = self.async_call_counter[node.name]
self.async_call_counter[node.name] += 1
env_assignment = ast.Assign(
targets=[
ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="os", ctx=ast.Load()),
attr="environ",
call_site_set = ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(
id="_codeflash_call_site",
ctx=ast.Load(),
),
slice=ast.Constant(
value="CODEFLASH_CURRENT_LINE_ID"
attr="set",
ctx=ast.Load(),
),
args=[
ast.Constant(
value=f"{current_call_index}",
),
ctx=ast.Store(),
)
],
value=ast.Constant(value=f"{current_call_index}"),
],
keywords=[],
),
lineno=stmt.lineno,
)
new_body.append(env_assignment)
new_body.append(call_site_set)
self.did_instrument = True
new_body.append(stmt)
@ -236,7 +237,7 @@ ASYNC_HELPER_FILENAME = "codeflash_async_wrapper.py"
_RUNTIME_DECORATOR_PATH = (
Path(__file__).resolve().parent.parent
/ "runtime"
/ "_codeflash_wrap_decorator.py"
/ "_codeflash_async_decorators.py"
)
@ -289,7 +290,13 @@ def inject_async_profiling_into_existing_test(
if not async_instrumenter.did_instrument:
return False, None
new_imports = [ast.Import(names=[ast.alias(name="os")])]
new_imports = [
ast.ImportFrom(
module="codeflash_async_wrapper",
names=[ast.alias(name="_codeflash_call_site")],
level=0,
),
]
tree.body = [*new_imports, *tree.body]
return True, sort_imports(ast.unparse(tree), float_to_top=True)

View file

@ -0,0 +1,562 @@
"""Tests for the async SQLite data parser."""
from __future__ import annotations
import sqlite3
from pathlib import Path
from unittest.mock import MagicMock
import dill as pickle
import pytest
from codeflash_python.runtime._codeflash_async_decorators import (
_CREATE_TABLE_SQL,
VerificationType,
)
from codeflash_python.testing._async_data_parser import (
calculate_async_throughput,
parse_async_behavior_results,
parse_async_concurrency_metrics,
)
@pytest.fixture(name="test_config")
def _test_config(tmp_path):
    """Provide a minimal mocked TestConfig rooted at tmp_path."""
    config = MagicMock()
    config.tests_project_rootdir = tmp_path
    config.test_framework = "pytest"
    return config
@pytest.fixture(name="test_files")
def _test_files():
    """Provide a minimal mocked TestFiles whose original-path lookup succeeds.

    The instrumented-path fallback returns None so tests exercise the
    primary lookup path by default.  (Fix: dropped the unused ``tmp_path``
    fixture parameter.)
    """
    from codeflash_python.test_discovery.models import TestType

    mock_files = MagicMock()
    mock_files.get_test_type_by_original_file_path.return_value = (
        TestType.EXISTING_UNIT_TEST
    )
    mock_files.get_test_type_by_instrumented_file_path.return_value = None
    return mock_files
def _create_async_db(
    db_path: Path,
    rows: list[tuple],
) -> None:
    """Create an async_results SQLite DB pre-populated with *rows*."""
    connection = sqlite3.connect(db_path)
    connection.execute(_CREATE_TABLE_SQL)
    connection.executemany(
        "INSERT INTO async_results VALUES "
        "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
        rows,
    )
    connection.commit()
    connection.close()
class TestParseAsyncBehaviorResults:
    """parse_async_behavior_results reads behavior rows from SQLite."""

    def test_empty_when_file_missing(
        self, tmp_path, test_files, test_config
    ) -> None:
        """Returns empty TestResults when file does not exist."""
        missing = tmp_path / "nonexistent.sqlite"
        parsed = parse_async_behavior_results(missing, test_files, test_config)
        assert not list(parsed)

    def test_reads_behavior_rows(
        self, tmp_path, test_files, test_config
    ) -> None:
        """Parses behavior rows into FunctionTestInvocation objects."""
        # The empty module file lets the parser resolve the module path.
        (tmp_path / "test_module.py").write_text("", encoding="utf-8")
        database = tmp_path / "async_results.sqlite"
        payload = pickle.dumps(((1, 2), {}, 3))
        row = (
            "test_module",
            None,
            "test_fn",
            "target_func",
            1,
            "0_0",
            "behavior",
            1000,
            payload,
            VerificationType.FUNCTION_CALL.value,
            None,
            None,
            None,
        )
        _create_async_db(database, [row])
        parsed = list(
            parse_async_behavior_results(database, test_files, test_config)
        )
        assert len(parsed) == 1
        invocation = parsed[0]
        assert invocation.id.function_getting_tested == "target_func"
        assert invocation.runtime == 1000
        assert invocation.return_value is not None

    def test_skips_non_behavior_rows(
        self, tmp_path, test_files, test_config
    ) -> None:
        """Ignores performance and concurrency rows."""
        (tmp_path / "test_module.py").write_text("", encoding="utf-8")
        database = tmp_path / "async_results.sqlite"
        performance_row = (
            "test_module",
            None,
            "test_fn",
            "target",
            1,
            "0_0",
            "performance",
            1000,
            None,
            None,
            None,
            None,
            None,
        )
        _create_async_db(database, [performance_row])
        parsed = parse_async_behavior_results(
            database, test_files, test_config
        )
        assert not list(parsed)

    def test_handles_corrupt_db(
        self, tmp_path, test_files, test_config
    ) -> None:
        """Returns empty results for a corrupted database file."""
        bogus = tmp_path / "corrupt.sqlite"
        bogus.write_text("not a sqlite file", encoding="utf-8")
        parsed = parse_async_behavior_results(bogus, test_files, test_config)
        assert not list(parsed)
class TestCalculateAsyncThroughput:
    """calculate_async_throughput counts performance rows."""

    @staticmethod
    def _row(function, mode, invocation, wall_time):
        """Build a full async_results row with only the varying fields set."""
        return (
            "mod",
            None,
            "test_fn",
            function,
            1,
            invocation,
            mode,
            wall_time,
            None,
            None,
            None,
            None,
            None,
        )

    def test_returns_zero_when_file_missing(self, tmp_path) -> None:
        """Returns 0 when SQLite file does not exist."""
        missing = tmp_path / "nonexistent.sqlite"
        assert calculate_async_throughput(missing, "func") == 0

    def test_counts_performance_rows(self, tmp_path) -> None:
        """Counts rows matching function name and performance mode."""
        database = tmp_path / "async_results.sqlite"
        _create_async_db(
            database,
            [
                self._row("target", "performance", "0_0", 1000),
                self._row("target", "performance", "0_1", 2000),
                self._row("other_func", "performance", "0_0", 3000),
                self._row("target", "behavior", "0_0", 4000),
            ],
        )
        assert calculate_async_throughput(database, "target") == 2

    def test_returns_zero_for_no_matches(self, tmp_path) -> None:
        """Returns 0 when no rows match the function name."""
        database = tmp_path / "async_results.sqlite"
        _create_async_db(database, [])
        assert calculate_async_throughput(database, "nonexistent") == 0
class TestParseAsyncConcurrencyMetrics:
    """parse_async_concurrency_metrics reads concurrency rows."""

    @staticmethod
    def _concurrency_row(function, seq_ns, conc_ns, factor):
        """Build a full concurrency-mode async_results row."""
        return (
            "mod",
            None,
            "test_fn",
            function,
            1,
            "",
            "concurrency",
            0,
            None,
            None,
            seq_ns,
            conc_ns,
            factor,
        )

    def test_returns_none_when_file_missing(self, tmp_path) -> None:
        """Returns None when SQLite file does not exist."""
        missing = tmp_path / "nonexistent.sqlite"
        assert parse_async_concurrency_metrics(missing, "func") is None

    def test_parses_concurrency_metrics(self, tmp_path) -> None:
        """Computes averages from multiple concurrency rows."""
        database = tmp_path / "async_results.sqlite"
        _create_async_db(
            database,
            [
                self._concurrency_row("target", 100_000, 50_000, 10),
                self._concurrency_row("target", 200_000, 100_000, 10),
            ],
        )
        metrics = parse_async_concurrency_metrics(database, "target")
        assert metrics is not None
        assert metrics.sequential_time_ns == 150_000
        assert metrics.concurrent_time_ns == 75_000
        assert metrics.concurrency_factor == 10
        assert metrics.concurrency_ratio == 2.0

    def test_returns_none_for_wrong_function(self, tmp_path) -> None:
        """Returns None when no rows match the function name."""
        database = tmp_path / "async_results.sqlite"
        _create_async_db(
            database,
            [self._concurrency_row("other_func", 100_000, 50_000, 10)],
        )
        assert parse_async_concurrency_metrics(database, "target") is None

    def test_handles_zero_concurrent_time(self, tmp_path) -> None:
        """Returns ratio 1.0 when concurrent time is zero."""
        database = tmp_path / "async_results.sqlite"
        _create_async_db(
            database,
            [self._concurrency_row("target", 100_000, 0, 5)],
        )
        metrics = parse_async_concurrency_metrics(database, "target")
        assert metrics is not None
        assert metrics.concurrency_ratio == 1.0

    def test_corrupt_db_returns_none(self, tmp_path) -> None:
        """Returns None for a corrupted database file."""
        bogus = tmp_path / "corrupt.sqlite"
        bogus.write_text("not a sqlite file", encoding="utf-8")
        assert parse_async_concurrency_metrics(bogus, "target") is None
class TestCalculateAsyncThroughputErrorHandling:
    """Error handling in calculate_async_throughput."""

    def test_corrupt_db_returns_zero(self, tmp_path) -> None:
        """Returns 0 for a corrupted database file."""
        bogus = tmp_path / "corrupt.sqlite"
        bogus.write_text("not a sqlite file", encoding="utf-8")
        assert calculate_async_throughput(bogus, "func") == 0
class TestParseAsyncBehaviorEdgeCases:
    """Edge cases for parse_async_behavior_results."""

    @staticmethod
    def _prepare_paths(tmp_path):
        """Create the module stub the path resolver needs; return the DB path."""
        (tmp_path / "test_module.py").write_text("", encoding="utf-8")
        return tmp_path / "async_results.sqlite"

    @staticmethod
    def _behavior_row(payload, verification):
        """Build a full behavior-mode async_results row."""
        return (
            "test_module",
            None,
            "test_fn",
            "target",
            1,
            "0_0",
            "behavior",
            1000,
            payload,
            verification,
            None,
            None,
            None,
        )

    def test_init_state_verification_type(
        self, tmp_path, test_files, test_config
    ) -> None:
        """Rows with INIT_STATE_FTO get INIT_STATE_TEST type."""
        database = self._prepare_paths(tmp_path)
        _create_async_db(
            database,
            [
                self._behavior_row(
                    pickle.dumps(((1,), {}, 2)),
                    VerificationType.INIT_STATE_FTO.value,
                )
            ],
        )
        parsed = parse_async_behavior_results(
            database, test_files, test_config
        )
        assert len(list(parsed)) == 1

    def test_test_type_fallback_to_instrumented(
        self, tmp_path, test_config
    ) -> None:
        """Falls back to instrumented file path for test type lookup."""
        from codeflash_python.test_discovery.models import TestType

        database = self._prepare_paths(tmp_path)
        mock_files = MagicMock()
        mock_files.get_test_type_by_original_file_path.return_value = None
        mock_files.get_test_type_by_instrumented_file_path.return_value = (
            TestType.EXISTING_UNIT_TEST
        )
        _create_async_db(
            database,
            [
                self._behavior_row(
                    pickle.dumps(((1,), {}, 2)),
                    VerificationType.FUNCTION_CALL.value,
                )
            ],
        )
        parsed = parse_async_behavior_results(
            database, mock_files, test_config
        )
        assert len(list(parsed)) == 1

    def test_test_type_not_found_skips_row(
        self, tmp_path, test_config
    ) -> None:
        """Skips rows when test type cannot be determined."""
        database = self._prepare_paths(tmp_path)
        mock_files = MagicMock()
        mock_files.get_test_type_by_original_file_path.return_value = None
        mock_files.get_test_type_by_instrumented_file_path.return_value = None
        _create_async_db(
            database,
            [
                self._behavior_row(
                    pickle.dumps(((1,), {}, 2)),
                    VerificationType.FUNCTION_CALL.value,
                )
            ],
        )
        parsed = parse_async_behavior_results(
            database, mock_files, test_config
        )
        assert not list(parsed)

    def test_bad_pickle_data_skips_row(
        self, tmp_path, test_files, test_config
    ) -> None:
        """Skips rows with unpicklable return value data."""
        database = self._prepare_paths(tmp_path)
        _create_async_db(
            database,
            [
                self._behavior_row(
                    b"definitely not valid pickle",
                    VerificationType.FUNCTION_CALL.value,
                )
            ],
        )
        parsed = parse_async_behavior_results(
            database, test_files, test_config
        )
        assert not list(parsed)

    def test_outer_exception_handler(self, tmp_path, test_config) -> None:
        """Outer exception handler catches errors from inner processing."""
        database = self._prepare_paths(tmp_path)
        mock_files = MagicMock()
        mock_files.get_test_type_by_original_file_path.side_effect = (
            RuntimeError("boom")
        )
        _create_async_db(
            database,
            [
                self._behavior_row(
                    None,
                    VerificationType.FUNCTION_CALL.value,
                )
            ],
        )
        parsed = parse_async_behavior_results(
            database, mock_files, test_config
        )
        assert not list(parsed)

View file

@ -0,0 +1,508 @@
"""Tests for the new self-contained async instrumentation decorators."""
from __future__ import annotations
import asyncio
import contextvars
import os
import sqlite3
import sys
from pathlib import Path
from unittest.mock import patch
import dill as pickle
import pytest
from codeflash_python.runtime._codeflash_async_decorators import (
VerificationType,
_close_all_connections,
_codeflash_call_site,
_connections,
_get_async_db,
codeflash_behavior_async,
codeflash_concurrency_async,
codeflash_performance_async,
extract_test_context_from_env,
get_run_tmp_file,
)
@pytest.fixture(name="env_setup")
def _env_setup(request):
    """Set CODEFLASH_* env vars for a test and restore originals afterwards.

    Also closes any cached SQLite connections on teardown so each test sees
    a fresh connection cache.  (Fix: dropped the unused ``tmp_path`` fixture
    parameter.)
    """
    test_env = {
        "CODEFLASH_LOOP_INDEX": "1",
        "CODEFLASH_TEST_ITERATION": "0",
        "CODEFLASH_TEST_MODULE": __name__,
        "CODEFLASH_TEST_CLASS": "",
        "CODEFLASH_TEST_FUNCTION": request.node.name,
    }
    original_env = {key: os.environ.get(key) for key in test_env}
    os.environ.update(test_env)
    yield test_env
    for key, original_value in original_env.items():
        if original_value is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = original_value
    _close_all_connections()
@pytest.fixture(name="async_db_path")
def _async_db_path(env_setup):
    """Yield the async results DB path; remove the file on teardown."""
    iteration = env_setup["CODEFLASH_TEST_ITERATION"]
    path = get_run_tmp_file(Path(f"async_results_{iteration}.sqlite"))
    yield path
    path.unlink(missing_ok=True)
class TestExtractTestContextFromEnv:
    """extract_test_context_from_env reads CODEFLASH_TEST_* env vars."""

    @staticmethod
    def _patched_env(module="mod", cls="Cls", fn="test_fn"):
        """Return a patch.dict context setting the three CODEFLASH vars."""
        return patch.dict(
            os.environ,
            {
                "CODEFLASH_TEST_MODULE": module,
                "CODEFLASH_TEST_CLASS": cls,
                "CODEFLASH_TEST_FUNCTION": fn,
            },
        )

    def test_returns_tuple(self) -> None:
        """Returns (module, class_or_none, function) from env."""
        with self._patched_env():
            extracted = extract_test_context_from_env()
        assert extracted == ("mod", "Cls", "test_fn")

    def test_class_none_when_empty(self) -> None:
        """Returns None for test class when env var is empty."""
        with self._patched_env(cls=""):
            _, extracted_class, _ = extract_test_context_from_env()
        assert extracted_class is None

    def test_raises_when_module_missing(self) -> None:
        """Raises KeyError when CODEFLASH_TEST_MODULE is unset."""
        env = {
            key: value
            for key, value in os.environ.items()
            if key != "CODEFLASH_TEST_MODULE"
        }
        with patch.dict(os.environ, env, clear=True), pytest.raises(KeyError):
            extract_test_context_from_env()
class TestCodeflashCallSite:
    """_codeflash_call_site contextvar behavior."""

    def test_default_value(self) -> None:
        """Default value is empty string."""
        fresh_context = contextvars.copy_context()
        assert fresh_context.run(_codeflash_call_site.get) == ""

    def test_set_and_get(self) -> None:
        """Can set and retrieve a value."""
        token = _codeflash_call_site.set("42")
        assert _codeflash_call_site.get() == "42"
        _codeflash_call_site.reset(token)
class TestGetAsyncDb:
    """_get_async_db connection caching."""

    def test_creates_table(self, tmp_path) -> None:
        """Creates the async_results table on first connect."""
        database = tmp_path / "test.sqlite"
        connection, cursor = _get_async_db(database)
        cursor.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type='table' AND name='async_results'"
        )
        assert cursor.fetchone() is not None
        connection.close()
        _connections.pop(str(database), None)

    def test_caches_connection(self, tmp_path) -> None:
        """Returns the same connection on repeated calls."""
        database = tmp_path / "test.sqlite"
        first_conn, _ = _get_async_db(database)
        second_conn, _ = _get_async_db(database)
        assert first_conn is second_conn
        first_conn.close()
        _connections.pop(str(database), None)

    def test_close_all_connections(self, tmp_path) -> None:
        """_close_all_connections empties the cache."""
        _get_async_db(tmp_path / "test.sqlite")
        assert len(_connections) > 0
        _close_all_connections()
        assert len(_connections) == 0
@pytest.mark.skipif(
    sys.platform == "win32",
    reason="pending support for asyncio on windows",
)
class TestBehaviorAsync:
    """codeflash_behavior_async decorator."""

    @pytest.mark.asyncio
    async def test_returns_correct_value(
        self, env_setup, async_db_path
    ) -> None:
        """Decorated function returns the original return value."""

        @codeflash_behavior_async
        async def add(a: int, b: int) -> int:
            return a + b

        _codeflash_call_site.set("0")
        assert await add(3, 4) == 7

    @pytest.mark.asyncio
    async def test_writes_to_sqlite(self, env_setup, async_db_path) -> None:
        """Writes behavior result to async_results table."""

        @codeflash_behavior_async
        async def multiply(a: int, b: int) -> int:
            return a * b

        _codeflash_call_site.set("0")
        await multiply(5, 6)
        _close_all_connections()
        assert async_db_path.exists()
        connection = sqlite3.connect(async_db_path)
        rows = connection.cursor().execute(
            "SELECT * FROM async_results"
        ).fetchall()
        connection.close()
        assert len(rows) == 1
        (row,) = rows
        assert row[6] == "behavior"
        assert row[7] > 0
        stored_args, stored_kwargs, stored_ret = pickle.loads(row[8])
        assert stored_args == (5, 6)
        assert stored_kwargs == {}
        assert stored_ret == 30
        assert row[9] == VerificationType.FUNCTION_CALL.value

    @pytest.mark.asyncio
    async def test_exception_handling(self, env_setup, async_db_path) -> None:
        """Re-raises exceptions and stores them pickled."""

        @codeflash_behavior_async
        async def fail() -> None:
            raise ValueError("boom")

        _codeflash_call_site.set("0")
        with pytest.raises(ValueError, match="boom"):
            await fail()
        _close_all_connections()
        connection = sqlite3.connect(async_db_path)
        stored = connection.cursor().execute(
            "SELECT return_value FROM async_results"
        ).fetchone()
        connection.close()
        restored = pickle.loads(stored[0])
        assert isinstance(restored, ValueError)
        assert str(restored) == "boom"

    @pytest.mark.asyncio
    async def test_no_stdout_output(
        self, env_setup, async_db_path, capsys
    ) -> None:
        """Behavior decorator emits no stdout."""

        @codeflash_behavior_async
        async def noop() -> int:
            return 1

        _codeflash_call_site.set("0")
        await noop()
        assert capsys.readouterr().out == ""
@pytest.mark.skipif(
    sys.platform == "win32",
    reason="pending support for asyncio on windows",
)
class TestPerformanceAsync:
    """codeflash_performance_async decorator."""

    @pytest.mark.asyncio
    async def test_returns_correct_value(self, env_setup, async_db_path) -> None:
        """The wrapped coroutine still yields its original return value."""

        @codeflash_performance_async
        async def add(a: int, b: int) -> int:
            return a + b

        _codeflash_call_site.set("0")
        value = await add(3, 4)
        assert value == 7

    @pytest.mark.asyncio
    async def test_writes_to_sqlite(self, env_setup, async_db_path) -> None:
        """A performance row is stored with a NULL return_value."""

        @codeflash_performance_async
        async def work() -> int:
            return 42

        _codeflash_call_site.set("0")
        await work()
        _close_all_connections()
        db = sqlite3.connect(async_db_path)
        cursor = db.cursor()
        cursor.execute("SELECT * FROM async_results")
        records = cursor.fetchall()
        assert len(records) == 1
        record = records[0]
        assert record[6] == "performance"  # mode column
        assert record[7] > 0  # wall_time_ns
        assert record[8] is None  # return_value not captured in perf mode
        assert record[9] is None  # verification_type unused in perf mode
        db.close()

    @pytest.mark.asyncio
    async def test_no_stdout_output(self, env_setup, async_db_path, capsys) -> None:
        """The performance decorator writes nothing to stdout."""

        @codeflash_performance_async
        async def noop() -> int:
            return 1

        _codeflash_call_site.set("0")
        await noop()
        assert capsys.readouterr().out == ""
@pytest.mark.skipif(
    sys.platform == "win32",
    reason="pending support for asyncio on windows",
)
class TestConcurrencyAsync:
    """codeflash_concurrency_async decorator.

    Fix: CODEFLASH_CONCURRENCY_FACTOR was set via os.environ and only popped
    after the assertions, so any assertion failure leaked the variable into
    every subsequent test. Cleanup now runs in try/finally.
    """

    @pytest.mark.asyncio
    async def test_returns_correct_value(
        self, env_setup, async_db_path
    ) -> None:
        """Returns the result from sequential execution."""
        @codeflash_concurrency_async
        async def add(a: int, b: int) -> int:
            return a + b
        result = await add(3, 4)
        assert result == 7

    @pytest.mark.asyncio
    async def test_writes_concurrency_metrics(
        self, env_setup, async_db_path
    ) -> None:
        """Writes sequential/concurrent timing to async_results."""
        os.environ["CODEFLASH_CONCURRENCY_FACTOR"] = "3"
        try:
            @codeflash_concurrency_async
            async def work() -> int:
                await asyncio.sleep(0.001)
                return 42
            await work()
            _close_all_connections()
            con = sqlite3.connect(async_db_path)
            try:
                cur = con.cursor()
                cur.execute("SELECT * FROM async_results")
                rows = cur.fetchall()
                assert len(rows) == 1
                row = rows[0]
                assert row[6] == "concurrency"  # mode column
                assert row[7] == 0  # wall_time_ns unused in concurrency mode
                assert row[10] > 0  # sequential_time_ns
                assert row[11] > 0  # concurrent_time_ns
                assert row[12] == 3  # concurrency_factor echoed from env
            finally:
                # Close even on assertion failure so the file isn't held open.
                con.close()
        finally:
            # Always remove the env var so it cannot leak into later tests.
            os.environ.pop("CODEFLASH_CONCURRENCY_FACTOR", None)

    @pytest.mark.asyncio
    async def test_no_stdout_output(
        self, env_setup, async_db_path, capsys
    ) -> None:
        """Concurrency decorator emits no stdout."""
        os.environ["CODEFLASH_CONCURRENCY_FACTOR"] = "2"
        try:
            @codeflash_concurrency_async
            async def noop() -> int:
                return 1
            await noop()
            captured = capsys.readouterr()
            assert captured.out == ""
        finally:
            # Always remove the env var so it cannot leak into later tests.
            os.environ.pop("CODEFLASH_CONCURRENCY_FACTOR", None)
@pytest.mark.skipif(
    sys.platform == "win32",
    reason="pending support for asyncio on windows",
)
class TestBehaviorAsyncEdgeCases:
    """Edge cases for codeflash_behavior_async."""

    @pytest.mark.asyncio
    async def test_multiple_calls_increment_index(
        self, env_setup, async_db_path
    ) -> None:
        """Repeated calls through one decorator bump the invocation counter."""

        @codeflash_behavior_async
        async def inc(x: int) -> int:
            return x + 1

        _codeflash_call_site.set("0")
        await inc(1)
        await inc(2)
        _close_all_connections()
        db = sqlite3.connect(async_db_path)
        cursor = db.cursor()
        cursor.execute(
            "SELECT invocation_id FROM async_results ORDER BY invocation_id"
        )
        invocation_ids = [r[0] for r in cursor.fetchall()]
        assert len(invocation_ids) == 2
        assert invocation_ids == ["0_0", "0_1"]
        db.close()
@pytest.mark.skipif(
    sys.platform == "win32",
    reason="pending support for asyncio on windows",
)
class TestPerformanceAsyncEdgeCases:
    """Edge cases for codeflash_performance_async."""

    @pytest.mark.asyncio
    async def test_exception_handling(self, env_setup, async_db_path) -> None:
        """Exceptions raised by the wrapped coroutine propagate."""

        @codeflash_performance_async
        async def fail() -> None:
            raise ValueError("perf boom")

        _codeflash_call_site.set("0")
        with pytest.raises(ValueError, match="perf boom"):
            await fail()

    @pytest.mark.asyncio
    async def test_multiple_calls_increment_index(
        self, env_setup, async_db_path
    ) -> None:
        """Repeated calls through one decorator bump the invocation counter."""

        @codeflash_performance_async
        async def work() -> int:
            return 1

        _codeflash_call_site.set("0")
        await work()
        await work()
        _close_all_connections()
        db = sqlite3.connect(async_db_path)
        cursor = db.cursor()
        cursor.execute(
            "SELECT invocation_id FROM async_results ORDER BY invocation_id"
        )
        invocation_ids = [r[0] for r in cursor.fetchall()]
        assert invocation_ids[0] == "0_0"
        assert invocation_ids[1] == "0_1"
        db.close()
class TestCloseAllConnectionsErrorHandling:
    """_close_all_connections tolerates broken connections."""

    def test_handles_already_closed_connection(self, tmp_path) -> None:
        """A connection closed ahead of cleanup does not raise."""
        connection, _ = _get_async_db(tmp_path / "test.sqlite")
        connection.close()
        _close_all_connections()
        assert len(_connections) == 0
class TestExtractTestContextEdgeCases:
    """Edge cases for extract_test_context_from_env."""

    def test_raises_when_module_empty(self) -> None:
        """An empty-string module name raises RuntimeError."""
        overrides = {
            "CODEFLASH_TEST_MODULE": "",
            "CODEFLASH_TEST_CLASS": "",
            "CODEFLASH_TEST_FUNCTION": "test_fn",
        }
        with patch.dict(os.environ, overrides):
            with pytest.raises(RuntimeError):
                extract_test_context_from_env()
class TestSchemaValidation:
    """Validate the async_results SQLite schema."""

    def test_table_columns(self, tmp_path) -> None:
        """The async_results table exposes exactly the 13 expected columns."""
        db_file = tmp_path / "schema_test.sqlite"
        connection, cursor = _get_async_db(db_file)
        cursor.execute("PRAGMA table_info(async_results)")
        # PRAGMA table_info rows are (cid, name, type, notnull, dflt, pk);
        # index 1 is the column name.
        actual_names = [column[1] for column in cursor.fetchall()]
        expected_names = [
            "test_module_path",
            "test_class_name",
            "test_function_name",
            "function_getting_tested",
            "loop_index",
            "invocation_id",
            "mode",
            "wall_time_ns",
            "return_value",
            "verification_type",
            "sequential_time_ns",
            "concurrent_time_ns",
            "concurrency_factor",
        ]
        assert actual_names == expected_names
        connection.close()
        _connections.pop(str(db_file), None)

View file

@ -499,7 +499,7 @@ class TestAsyncCallInstrumenter:
return transformer, tree
def test_instruments_await_call(self) -> None:
"""Adds env var assignment before an awaited target call."""
"""Adds call-site contextvar set before an awaited target call."""
code = textwrap.dedent("""\
async def test_it():
result = await target_func(1, 2)
@ -507,7 +507,7 @@ class TestAsyncCallInstrumenter:
transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree)
source = ast.unparse(new_tree)
assert "CODEFLASH_CURRENT_LINE_ID" in source
assert "_codeflash_call_site.set(" in source
assert transformer.did_instrument is True
def test_skips_non_test_async_functions(self) -> None:
@ -519,7 +519,7 @@ class TestAsyncCallInstrumenter:
transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree)
source = ast.unparse(new_tree)
assert "CODEFLASH_CURRENT_LINE_ID" not in source
assert "_codeflash_call_site" not in source
assert transformer.did_instrument is False
def test_skips_non_test_sync_functions(self) -> None:
@ -531,7 +531,7 @@ class TestAsyncCallInstrumenter:
transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree)
source = ast.unparse(new_tree)
assert "CODEFLASH_CURRENT_LINE_ID" not in source
assert "_codeflash_call_site" not in source
def test_instruments_sync_test_with_await(self) -> None:
"""Instruments sync test_ functions that contain awaited calls."""
@ -542,7 +542,7 @@ class TestAsyncCallInstrumenter:
transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree)
source = ast.unparse(new_tree)
assert "CODEFLASH_CURRENT_LINE_ID" in source
assert "_codeflash_call_site.set(" in source
assert transformer.did_instrument is True
def test_multiple_awaits_get_incrementing_ids(self) -> None:
@ -569,7 +569,7 @@ class TestAsyncCallInstrumenter:
transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree)
source = ast.unparse(new_tree)
assert "CODEFLASH_CURRENT_LINE_ID" in source
assert "_codeflash_call_site.set(" in source
assert transformer.did_instrument is True
def test_recurses_into_class_body(self) -> None:
@ -582,7 +582,7 @@ class TestAsyncCallInstrumenter:
transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree)
source = ast.unparse(new_tree)
assert "CODEFLASH_CURRENT_LINE_ID" in source
assert "_codeflash_call_site.set(" in source
assert transformer.did_instrument is True
def test_no_match_when_position_wrong(self) -> None:
@ -608,7 +608,7 @@ class TestAsyncCallInstrumenter:
transformer, tree = self._make_transformer(code)
new_tree = transformer.visit(tree)
source = ast.unparse(new_tree)
assert "CODEFLASH_CURRENT_LINE_ID" in source
assert "_codeflash_call_site.set(" in source
def test_ignores_non_target_awaits(self) -> None:
"""Does not instrument awaits of unrelated functions."""
@ -1083,7 +1083,10 @@ class TestInjectAsyncProfilingIntoExistingTest:
)
assert ok is True
assert source is not None
assert "import os" in source
assert (
"from codeflash_async_wrapper import _codeflash_call_site"
in source
)
def test_no_instrumentation(self, tmp_path: Path) -> None:
"""Returns (False, None) when test does not call target."""
@ -1148,7 +1151,7 @@ class TestInjectAsyncProfilingIntoExistingTest:
)
assert ok is True
assert source is not None
assert source.count("CODEFLASH_CURRENT_LINE_ID") == 2
assert source.count("_codeflash_call_site.set(") == 2
class TestAddAsyncDecoratorToFunction: