mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
feat: rewrite async instrumentation to use SQLite-only data path and contextvars
Replace the fragile stdout tag protocol with a unified SQLite table (async_results) for all 3 async test modes. The new runtime decorators write behavior, performance, and concurrency results directly to the DB with zero stdout output. Test-file instrumentation now injects _codeflash_call_site.set() (contextvar) instead of os.environ assignments, which is correct for async task isolation. New modules: - runtime/_codeflash_async_decorators.py: self-contained decorators - testing/_async_data_parser.py: SQLite reader replacing stdout parsing Both at 100% test coverage (42 new tests).
This commit is contained in:
parent
24199efc63
commit
629d7f9f08
6 changed files with 1727 additions and 31 deletions
|
|
@ -0,0 +1,359 @@
|
|||
"""
|
||||
Self-contained async instrumentation decorators for codeflash testing.
|
||||
|
||||
This module is copied to test project directories at runtime, so it
|
||||
must have zero imports from the codeflash_python package.
|
||||
"""
|
||||
|
||||
# ruff: noqa: BLE001
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import contextvars
|
||||
import gc
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any, Callable, TypeVar
|
||||
|
||||
import dill as pickle
|
||||
|
||||
|
||||
class VerificationType(str, Enum):
    """Type of correctness verification for captured test data.

    Subclassing ``str`` lets members compare equal to their raw string
    values, which is how they are stored in the SQLite ``verification_type``
    column and compared on read-back.
    """

    # Ordinary captured call of the function under test.
    FUNCTION_CALL = "function_call"
    # Captured initial state of the function-to-optimize's instance.
    INIT_STATE_FTO = "init_state_fto"
    # Captured initial state of a helper's instance.
    INIT_STATE_HELPER = "init_state_helper"
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def get_run_tmp_file(file_path: Path) -> Path:
    """Return *file_path* rooted in a per-run temporary directory.

    The TemporaryDirectory is created lazily on first call and cached as a
    function attribute, so every call within one process shares the same
    directory; it is cleaned up when the object is finalized.
    """
    tmpdir = getattr(get_run_tmp_file, "tmpdir", None)
    if tmpdir is None:
        tmpdir = TemporaryDirectory(prefix="codeflash_")
        get_run_tmp_file.tmpdir = tmpdir  # type: ignore[attr-defined]
    return Path(tmpdir.name) / file_path
|
||||
|
||||
|
||||
def extract_test_context_from_env() -> tuple[str, str | None, str]:
|
||||
"""
|
||||
Read test module, class, and function names from env vars.
|
||||
|
||||
Raises RuntimeError when required variables are missing.
|
||||
"""
|
||||
test_module = os.environ["CODEFLASH_TEST_MODULE"]
|
||||
test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
|
||||
test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
|
||||
if test_module and test_function:
|
||||
return (
|
||||
test_module,
|
||||
test_class or None,
|
||||
test_function,
|
||||
)
|
||||
raise RuntimeError( # noqa: TRY003
|
||||
"Test context environment variables not set" # noqa: EM101
|
||||
)
|
||||
|
||||
|
||||
# Call-site identifier set by instrumented test code before each awaited
# call; a ContextVar (not an env var) so concurrent asyncio tasks each see
# their own value.
_codeflash_call_site: contextvars.ContextVar[str] = contextvars.ContextVar(
    "codeflash_call_site", default=""
)

# Unified 13-column schema shared by all three decorator modes
# ("behavior", "performance", "concurrency"); columns a mode does not use
# are written as NULL.
_CREATE_TABLE_SQL = (
    "CREATE TABLE IF NOT EXISTS async_results ("
    "test_module_path TEXT NOT NULL, "
    "test_class_name TEXT, "
    "test_function_name TEXT NOT NULL, "
    "function_getting_tested TEXT NOT NULL, "
    "loop_index INTEGER NOT NULL, "
    "invocation_id TEXT NOT NULL, "
    "mode TEXT NOT NULL, "
    "wall_time_ns INTEGER NOT NULL, "
    "return_value BLOB, "
    "verification_type TEXT, "
    "sequential_time_ns INTEGER, "
    "concurrent_time_ns INTEGER, "
    "concurrency_factor INTEGER"
    ")"
)

# Open-connection cache keyed by database path string; flushed and closed
# by the atexit hook below.
_connections: dict[str, sqlite3.Connection] = {}
|
||||
|
||||
|
||||
def _get_async_db(
    db_path: Path,
) -> tuple[sqlite3.Connection, sqlite3.Cursor]:
    """Return a cached connection for *db_path* plus a fresh cursor.

    The results table is created (if absent) the first time a given
    database file is opened in this process.
    """
    cache_key = str(db_path)
    connection = _connections.get(cache_key)
    if connection is None:
        connection = sqlite3.connect(db_path)
        connection.execute(_CREATE_TABLE_SQL)
        _connections[cache_key] = connection
    return connection, connection.cursor()
|
||||
|
||||
|
||||
def _close_all_connections() -> None:
    """Flush and close every cached connection, then empty the cache.

    Errors are swallowed deliberately: this runs at interpreter exit, where
    a failed close must not mask the test process's own outcome.
    """
    for connection in _connections.values():
        try:
            connection.commit()
            connection.close()
        except Exception:  # noqa: PERF203, S110
            pass  # best-effort cleanup only
    _connections.clear()
|
||||
|
||||
|
||||
atexit.register(_close_all_connections)
|
||||
|
||||
|
||||
def codeflash_behavior_async(func: F) -> F:
    """
    Capture async return values and timing for behavioral tests.

    Results are written to the async_results SQLite table.
    """

    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        loop = asyncio.get_running_loop()
        function_name = func.__name__
        # Call-site id set by the instrumented test via the contextvar.
        call_site = _codeflash_call_site.get()
        loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
        (
            test_module_name,
            test_class_name,
            test_name,
        ) = extract_test_context_from_env()

        # Key identifying one call site within one test iteration; used to
        # number repeated invocations from the same site.
        test_id = (
            f"{test_module_name}:{test_class_name}"
            f":{test_name}:{call_site}:{loop_index}"
        )

        # Per-wrapper invocation counter, stored as a function attribute so
        # it survives across calls within the process.
        if not hasattr(wrapper, "index"):
            wrapper.index = {}  # type: ignore[attr-defined]
        if test_id in wrapper.index:  # type: ignore[attr-defined]
            wrapper.index[test_id] += 1  # type: ignore[attr-defined]
        else:
            wrapper.index[test_id] = 0  # type: ignore[attr-defined]

        call_index = wrapper.index[test_id]  # type: ignore[attr-defined]
        invocation_id = f"{call_site}_{call_index}"

        iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
        db_path = get_run_tmp_file(Path(f"async_results_{iteration}.sqlite"))
        conn, cur = _get_async_db(db_path)

        exception = None
        counter = loop.time()
        # GC disabled so a collection pause does not inflate the timing.
        gc.disable()
        try:
            ret = func(*args, **kwargs)
            # Restart the clock after coroutine creation so only the awaited
            # execution is measured.
            counter = loop.time()
            return_value = await ret
            wall_time = int((loop.time() - counter) * 1_000_000_000)
        except Exception as e:
            # NOTE(review): BaseException subclasses such as
            # asyncio.CancelledError are not recorded and propagate directly.
            wall_time = int((loop.time() - counter) * 1_000_000_000)
            exception = e
        finally:
            gc.enable()

        # On failure the exception itself is pickled; on success the full
        # (args, kwargs, return_value) triple is stored for verification.
        pickled = (
            pickle.dumps(exception)
            if exception
            else pickle.dumps((args, kwargs, return_value))
        )
        cur.execute(
            "INSERT INTO async_results VALUES "
            "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                test_module_name,
                test_class_name,
                test_name,
                function_name,
                loop_index,
                invocation_id,
                "behavior",
                wall_time,
                pickled,
                VerificationType.FUNCTION_CALL.value,
                None,  # sequential_time_ns: concurrency-mode only
                None,  # concurrent_time_ns: concurrency-mode only
                None,  # concurrency_factor: concurrency-mode only
            ),
        )
        conn.commit()

        # Re-raise after recording so test semantics are unchanged.
        if exception:
            raise exception
        return return_value

    return wrapper  # type: ignore[return-value]
|
||||
|
||||
|
||||
def codeflash_performance_async(func: F) -> F:
    """
    Measure async execution time for performance tests.

    Results are written to the async_results SQLite table.
    Unlike the behavior decorator, no return value is pickled — only the
    wall time of the awaited execution is recorded.
    """

    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        loop = asyncio.get_running_loop()
        # Call-site id set by the instrumented test via the contextvar.
        call_site = _codeflash_call_site.get()
        loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
        (
            test_module_name,
            test_class_name,
            test_name,
        ) = extract_test_context_from_env()

        # Key identifying one call site within one test iteration; used to
        # number repeated invocations from the same site.
        test_id = (
            f"{test_module_name}:{test_class_name}"
            f":{test_name}:{call_site}:{loop_index}"
        )

        # Per-wrapper invocation counter, stored as a function attribute.
        if not hasattr(wrapper, "index"):
            wrapper.index = {}  # type: ignore[attr-defined]
        if test_id in wrapper.index:  # type: ignore[attr-defined]
            wrapper.index[test_id] += 1  # type: ignore[attr-defined]
        else:
            wrapper.index[test_id] = 0  # type: ignore[attr-defined]

        call_index = wrapper.index[test_id]  # type: ignore[attr-defined]
        invocation_id = f"{call_site}_{call_index}"

        iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
        db_path = get_run_tmp_file(Path(f"async_results_{iteration}.sqlite"))
        conn, cur = _get_async_db(db_path)

        exception = None
        counter = loop.time()
        # GC disabled so a collection pause does not inflate the timing.
        gc.disable()
        try:
            ret = func(*args, **kwargs)
            # Restart the clock after coroutine creation so only the awaited
            # execution is measured.
            counter = loop.time()
            return_value = await ret
            wall_time = int((loop.time() - counter) * 1_000_000_000)
        except Exception as e:
            wall_time = int((loop.time() - counter) * 1_000_000_000)
            exception = e
        finally:
            gc.enable()

        cur.execute(
            "INSERT INTO async_results VALUES "
            "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                test_module_name,
                test_class_name,
                test_name,
                func.__name__,
                loop_index,
                invocation_id,
                "performance",
                wall_time,
                None,  # return_value: behavior-mode only
                None,  # verification_type: behavior-mode only
                None,  # sequential_time_ns: concurrency-mode only
                None,  # concurrent_time_ns: concurrency-mode only
                None,  # concurrency_factor: concurrency-mode only
            ),
        )
        conn.commit()

        # Re-raise after recording so test semantics are unchanged.
        if exception:
            raise exception
        return return_value

    return wrapper  # type: ignore[return-value]
|
||||
|
||||
|
||||
def codeflash_concurrency_async(func: F) -> F:
    """
    Measure concurrent vs sequential execution for async functions.

    Runs the wrapped coroutine function ``concurrency_factor`` times
    sequentially, then the same number of times at once via
    ``asyncio.gather``, and writes both aggregate durations to the
    async_results SQLite table. Returns the result of the last sequential
    invocation.
    """

    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        function_name = func.__name__
        # Clamp to at least 1 so `result` below is always bound; the
        # original raised UnboundLocalError when the env var was "0" or
        # negative.
        concurrency_factor = max(
            1, int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10"))
        )
        (
            test_module_name,
            test_class_name,
            test_name,
        ) = extract_test_context_from_env()
        loop_index = int(os.environ.get("CODEFLASH_LOOP_INDEX", "0"))

        iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
        db_path = get_run_tmp_file(Path(f"async_results_{iteration}.sqlite"))
        conn, cur = _get_async_db(db_path)

        # Sequential phase: back-to-back awaits with GC paused so a
        # collection does not skew the measurement.
        gc.disable()
        try:
            seq_start = time.perf_counter_ns()
            for _ in range(concurrency_factor):
                result = await func(*args, **kwargs)
            sequential_time = time.perf_counter_ns() - seq_start
        finally:
            gc.enable()

        # Concurrent phase: the same number of calls launched together.
        gc.disable()
        try:
            conc_start = time.perf_counter_ns()
            tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
            await asyncio.gather(*tasks)
            concurrent_time = time.perf_counter_ns() - conc_start
        finally:
            gc.enable()

        cur.execute(
            "INSERT INTO async_results VALUES "
            "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                test_module_name,
                test_class_name,
                test_name,
                function_name,
                loop_index,
                "",  # concurrency rows carry no per-call invocation id
                "concurrency",
                0,  # wall_time_ns is unused for concurrency rows
                None,  # return_value: behavior-mode only
                None,  # verification_type: behavior-mode only
                sequential_time,
                concurrent_time,
                concurrency_factor,
            ),
        )
        conn.commit()

        return result

    return wrapper  # type: ignore[return-value]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"VerificationType",
|
||||
"_codeflash_call_site",
|
||||
"codeflash_behavior_async",
|
||||
"codeflash_concurrency_async",
|
||||
"codeflash_performance_async",
|
||||
"extract_test_context_from_env",
|
||||
"get_run_tmp_file",
|
||||
]
|
||||
|
|
@ -0,0 +1,257 @@
|
|||
"""Async test result parsing from SQLite databases."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .._model import VerificationType
|
||||
from ..test_discovery.models import TestType
|
||||
from ._path_resolution import file_path_from_module_name
|
||||
from .models import FunctionTestInvocation, InvocationId, TestResults
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from ..benchmarking.models import ConcurrencyMetrics
|
||||
from .models import TestConfig, TestFiles
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_BEHAVIOR_QUERY = (
|
||||
"SELECT test_module_path, test_class_name,"
|
||||
" test_function_name, function_getting_tested,"
|
||||
" loop_index, invocation_id, wall_time_ns,"
|
||||
" return_value, verification_type"
|
||||
" FROM async_results"
|
||||
" WHERE mode = 'behavior'"
|
||||
)
|
||||
|
||||
_THROUGHPUT_QUERY = (
|
||||
"SELECT COUNT(*) FROM async_results"
|
||||
" WHERE function_getting_tested = ?"
|
||||
" AND mode = 'performance'"
|
||||
)
|
||||
|
||||
_CONCURRENCY_QUERY = (
|
||||
"SELECT sequential_time_ns, concurrent_time_ns,"
|
||||
" concurrency_factor"
|
||||
" FROM async_results"
|
||||
" WHERE function_getting_tested = ?"
|
||||
" AND mode = 'concurrency'"
|
||||
)
|
||||
|
||||
|
||||
def parse_async_behavior_results(
    sqlite_file_path: Path,
    test_files: TestFiles,
    test_config: TestConfig,
) -> TestResults:
    """Parse behavior test results from the async results database.

    Args:
        sqlite_file_path: Path to the async_results SQLite database.
        test_files: Known test files, used to resolve each row's test type.
        test_config: Active test configuration (project root, framework).

    Returns:
        TestResults with one invocation per behavior row; empty when the
        database file is missing or unreadable.
    """
    test_results = TestResults()
    if not sqlite_file_path.exists():
        log.warning(
            "No async test results for %s found.",
            sqlite_file_path,
        )
        return test_results

    db: sqlite3.Connection | None = None
    try:
        db = sqlite3.connect(sqlite_file_path)
        cur = db.cursor()
        data = cur.execute(_BEHAVIOR_QUERY).fetchall()
    except (sqlite3.Error, OSError):
        log.warning(
            "Failed to parse async test results from %s.",
            sqlite_file_path,
            exc_info=True,
        )
        # No explicit close here: the finally block below runs on this
        # return path too (the original closed the connection twice).
        return test_results
    finally:
        if db is not None:
            db.close()

    # Row-level failures are logged, not raised, so one bad row does not
    # discard the rest of the results.
    for row in data:
        _process_behavior_row(row, test_files, test_config, test_results)

    return test_results
|
||||
|
||||
|
||||
def _process_behavior_row(
    val: tuple[object, ...],
    test_files: TestFiles,
    test_config: TestConfig,
    test_results: TestResults,
) -> None:
    """Process one behavior row, logging (never raising) on failure.

    A malformed row must not abort parsing of the remaining rows, so any
    exception from the inner processor is caught and logged here.
    """
    try:
        _process_behavior_row_inner(
            val, test_files, test_config, test_results
        )
    except Exception:
        log.exception("Failed to parse async behavior result")
|
||||
|
||||
|
||||
def _process_behavior_row_inner(
    val: tuple[object, ...],
    test_files: TestFiles,
    test_config: TestConfig,
    test_results: TestResults,
) -> None:
    """Inner processing for a single async behavior row.

    ``val`` is one row of the behavior query, in column order:
    (test_module_path, test_class_name, test_function_name,
    function_getting_tested, loop_index, invocation_id, wall_time_ns,
    return_value, verification_type).
    """
    test_module_path = val[0]
    test_class_name = val[1] or None  # empty string from SQLite -> None
    test_function_name = val[2] or None
    function_getting_tested = val[3]
    loop_index = val[4]
    invocation_id = val[5]
    wall_time_ns = val[6]
    verification_type = val[8]  # val[7] (pickled return value) used below

    test_file_path = file_path_from_module_name(
        test_module_path,  # type: ignore[arg-type]
        test_config.tests_project_rootdir,
    )

    # Init-state rows map directly to INIT_STATE_TEST; everything else is
    # resolved from the known test files. NOTE(review): this membership
    # test relies on the stored string comparing equal to the enum members
    # (str-based enum) — confirm VerificationType subclasses str.
    if verification_type in {
        VerificationType.INIT_STATE_FTO,
        VerificationType.INIT_STATE_HELPER,
    }:
        test_type: TestType = TestType.INIT_STATE_TEST
    else:
        found = test_files.get_test_type_by_original_file_path(
            test_file_path,
        )
        if found is None:
            # The row may reference the instrumented copy of the test file.
            found = test_files.get_test_type_by_instrumented_file_path(
                test_file_path,
            )
        if found is None:
            log.debug(
                "Skipping async result for %s: could not determine test type",
                test_function_name,
            )
            return
        test_type = found

    # Only loop index 1 carries a deserialized return value — presumably
    # later loops are timing-only repetitions; TODO confirm with writer.
    ret_val = None
    if loop_index == 1 and val[7]:
        import dill as pickle  # noqa: PLC0415

        try:
            # Wrapped in a 1-tuple so a legitimate None return is
            # distinguishable from "no captured value".
            ret_val = (pickle.loads(val[7]),)  # noqa: S301
        except Exception:  # noqa: BLE001
            log.debug(
                "Failed to deserialize return value for %s",
                test_function_name,
                exc_info=True,
            )
            # Undeserializable payload: drop the whole row.
            return

    test_results.add(
        FunctionTestInvocation(
            loop_index=loop_index,  # type: ignore[arg-type]
            id=InvocationId(
                test_module_path=test_module_path,  # type: ignore[arg-type]
                test_class_name=test_class_name,  # type: ignore[arg-type]
                test_function_name=test_function_name,  # type: ignore[arg-type]
                function_getting_tested=function_getting_tested,  # type: ignore[arg-type]
                iteration_id=invocation_id,  # type: ignore[arg-type]
            ),
            file_name=test_file_path,
            did_pass=True,
            runtime=wall_time_ns,  # type: ignore[arg-type]
            test_framework=test_config.test_framework,
            test_type=test_type,
            return_value=ret_val,
            cpu_runtime=0,  # async path records wall time only
            timed_out=False,
            verification_type=(
                VerificationType(verification_type)
                if verification_type
                else None
            ),
        ),
    )
|
||||
|
||||
|
||||
def calculate_async_throughput(
    sqlite_file_path: Path,
    function_name: str,
) -> int:
    """Count completed async executions of *function_name*.

    Reads the performance-mode rows from the results database; returns 0
    when the file is missing or cannot be read.
    """
    if not sqlite_file_path.exists():
        return 0

    db: sqlite3.Connection | None = None
    try:
        db = sqlite3.connect(sqlite_file_path)
        cursor = db.cursor()
        count_row = cursor.execute(
            _THROUGHPUT_QUERY, (function_name,)
        ).fetchone()
    except (sqlite3.Error, OSError):
        log.warning(
            "Failed to read async throughput from %s.",
            sqlite_file_path,
            exc_info=True,
        )
        return 0
    finally:
        if db is not None:
            db.close()

    return int(count_row[0]) if count_row else 0
|
||||
|
||||
|
||||
def parse_async_concurrency_metrics(
    sqlite_file_path: Path,
    function_name: str,
) -> ConcurrencyMetrics | None:
    """Parse concurrency benchmark metrics from the results database.

    Averages the sequential and concurrent timings over every concurrency
    row recorded for *function_name*.

    Returns:
        ConcurrencyMetrics, or None when the database is missing,
        unreadable, or has no matching rows.
    """
    if not sqlite_file_path.exists():
        return None

    db: sqlite3.Connection | None = None
    try:
        db = sqlite3.connect(sqlite_file_path)
        cur = db.cursor()
        rows = cur.execute(_CONCURRENCY_QUERY, (function_name,)).fetchall()
    except (sqlite3.Error, OSError):
        log.warning(
            "Failed to read async concurrency metrics from %s.",
            sqlite_file_path,
            exc_info=True,
        )
        # No explicit close here: the finally block below runs on this
        # return path too (the original closed the connection twice).
        return None
    finally:
        if db is not None:
            db.close()

    if not rows:
        return None

    count = len(rows)
    avg_seq = sum(row[0] for row in rows) / count
    avg_conc = sum(row[1] for row in rows) / count
    # The factor is constant for a run; take it from the first row.
    factor = rows[0][2]

    # Guard against division by zero for degenerate zero-time runs.
    ratio = 1.0 if avg_conc == 0 else avg_seq / avg_conc

    # Imported locally to avoid a module-level import cycle.
    from ..benchmarking.models import (  # noqa: PLC0415
        ConcurrencyMetrics,
    )

    return ConcurrencyMetrics(
        sequential_time_ns=int(avg_seq),
        concurrent_time_ns=int(avg_conc),
        concurrency_factor=factor,
        concurrency_ratio=ratio,
    )
|
||||
|
|
@ -1,9 +1,9 @@
|
|||
"""Async-specific instrumentation: AST transformers, decorators, and helpers.
|
||||
|
||||
Provides ``AsyncCallInstrumenter`` for injecting ``CODEFLASH_CURRENT_LINE_ID``
|
||||
assignments before ``await`` calls, ``AsyncDecoratorAdder`` for adding
|
||||
async performance/behavior decorators via libcst, and high-level functions
|
||||
for instrumenting async test and source files.
|
||||
Provides ``AsyncCallInstrumenter`` for injecting ``_codeflash_call_site``
|
||||
contextvar assignments before ``await`` calls, ``AsyncDecoratorAdder`` for
|
||||
adding async performance/behavior decorators via libcst, and high-level
|
||||
functions for instrumenting async test and source files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -71,8 +71,7 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
|
|||
self,
|
||||
node: ast.AsyncFunctionDef | ast.FunctionDef,
|
||||
) -> ast.AsyncFunctionDef | ast.FunctionDef:
|
||||
"""Add CODEFLASH_CURRENT_LINE_ID assignments before target await calls."""
|
||||
# Initialize counter for this test function
|
||||
"""Add _codeflash_call_site.set() calls before target await calls."""
|
||||
if node.name not in self.async_call_counter:
|
||||
self.async_call_counter[node.name] = 0
|
||||
|
||||
|
|
@ -85,24 +84,26 @@ class AsyncCallInstrumenter(ast.NodeTransformer):
|
|||
current_call_index = self.async_call_counter[node.name]
|
||||
self.async_call_counter[node.name] += 1
|
||||
|
||||
env_assignment = ast.Assign(
|
||||
targets=[
|
||||
ast.Subscript(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="os", ctx=ast.Load()),
|
||||
attr="environ",
|
||||
call_site_set = ast.Expr(
|
||||
value=ast.Call(
|
||||
func=ast.Attribute(
|
||||
value=ast.Name(
|
||||
id="_codeflash_call_site",
|
||||
ctx=ast.Load(),
|
||||
),
|
||||
slice=ast.Constant(
|
||||
value="CODEFLASH_CURRENT_LINE_ID"
|
||||
attr="set",
|
||||
ctx=ast.Load(),
|
||||
),
|
||||
args=[
|
||||
ast.Constant(
|
||||
value=f"{current_call_index}",
|
||||
),
|
||||
ctx=ast.Store(),
|
||||
)
|
||||
],
|
||||
value=ast.Constant(value=f"{current_call_index}"),
|
||||
],
|
||||
keywords=[],
|
||||
),
|
||||
lineno=stmt.lineno,
|
||||
)
|
||||
new_body.append(env_assignment)
|
||||
new_body.append(call_site_set)
|
||||
self.did_instrument = True
|
||||
|
||||
new_body.append(stmt)
|
||||
|
|
@ -236,7 +237,7 @@ ASYNC_HELPER_FILENAME = "codeflash_async_wrapper.py"
|
|||
_RUNTIME_DECORATOR_PATH = (
|
||||
Path(__file__).resolve().parent.parent
|
||||
/ "runtime"
|
||||
/ "_codeflash_wrap_decorator.py"
|
||||
/ "_codeflash_async_decorators.py"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -289,7 +290,13 @@ def inject_async_profiling_into_existing_test(
|
|||
if not async_instrumenter.did_instrument:
|
||||
return False, None
|
||||
|
||||
new_imports = [ast.Import(names=[ast.alias(name="os")])]
|
||||
new_imports = [
|
||||
ast.ImportFrom(
|
||||
module="codeflash_async_wrapper",
|
||||
names=[ast.alias(name="_codeflash_call_site")],
|
||||
level=0,
|
||||
),
|
||||
]
|
||||
tree.body = [*new_imports, *tree.body]
|
||||
return True, sort_imports(ast.unparse(tree), float_to_top=True)
|
||||
|
||||
|
|
|
|||
562
packages/codeflash-python/tests/test_async_data_parser.py
Normal file
562
packages/codeflash-python/tests/test_async_data_parser.py
Normal file
|
|
@ -0,0 +1,562 @@
|
|||
"""Tests for the async SQLite data parser."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import dill as pickle
|
||||
import pytest
|
||||
|
||||
from codeflash_python.runtime._codeflash_async_decorators import (
|
||||
_CREATE_TABLE_SQL,
|
||||
VerificationType,
|
||||
)
|
||||
from codeflash_python.testing._async_data_parser import (
|
||||
calculate_async_throughput,
|
||||
parse_async_behavior_results,
|
||||
parse_async_concurrency_metrics,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="test_config")
def _test_config(tmp_path):
    """Minimal TestConfig stand-in with only the attributes the parser reads."""
    config = MagicMock()
    config.tests_project_rootdir = tmp_path
    config.test_framework = "pytest"
    return config
|
||||
|
||||
|
||||
@pytest.fixture(name="test_files")
def _test_files(tmp_path):
    """TestFiles stand-in that resolves every original path to EXISTING_UNIT_TEST."""
    from codeflash_python.test_discovery.models import TestType

    files = MagicMock()
    files.get_test_type_by_original_file_path.return_value = (
        TestType.EXISTING_UNIT_TEST
    )
    # Instrumented-path lookup is the fallback; keep it unset here.
    files.get_test_type_by_instrumented_file_path.return_value = None
    return files
|
||||
|
||||
|
||||
def _create_async_db(
    db_path: Path,
    rows: list[tuple],
) -> None:
    """Create an async_results SQLite DB containing the given rows.

    Each row must match the 13-column async_results schema.
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.execute(_CREATE_TABLE_SQL)
        # executemany replaces the per-row execute loop with one call.
        conn.executemany(
            "INSERT INTO async_results VALUES "
            "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            rows,
        )
        conn.commit()
    finally:
        # Close even when an insert fails, so the file is not left locked.
        conn.close()
|
||||
|
||||
|
||||
class TestParseAsyncBehaviorResults:
    """parse_async_behavior_results reads behavior rows from SQLite."""

    def test_empty_when_file_missing(
        self, tmp_path, test_files, test_config
    ) -> None:
        """Returns empty TestResults when file does not exist."""
        results = parse_async_behavior_results(
            tmp_path / "nonexistent.sqlite", test_files, test_config
        )
        assert len(list(results)) == 0

    def test_reads_behavior_rows(
        self, tmp_path, test_files, test_config
    ) -> None:
        """Parses behavior rows into FunctionTestInvocation objects."""
        # The module path in the row must resolve to a real file.
        (tmp_path / "test_module.py").write_text("", encoding="utf-8")
        db_path = tmp_path / "async_results.sqlite"
        payload = pickle.dumps(((1, 2), {}, 3))
        row = (
            "test_module", None, "test_fn", "target_func", 1, "0_0",
            "behavior", 1000, payload,
            VerificationType.FUNCTION_CALL.value, None, None, None,
        )
        _create_async_db(db_path, [row])

        invocations = list(
            parse_async_behavior_results(db_path, test_files, test_config)
        )
        assert len(invocations) == 1
        inv = invocations[0]
        assert inv.id.function_getting_tested == "target_func"
        assert inv.runtime == 1000
        assert inv.return_value is not None

    def test_skips_non_behavior_rows(
        self, tmp_path, test_files, test_config
    ) -> None:
        """Ignores performance and concurrency rows."""
        (tmp_path / "test_module.py").write_text("", encoding="utf-8")
        db_path = tmp_path / "async_results.sqlite"
        row = (
            "test_module", None, "test_fn", "target", 1, "0_0",
            "performance", 1000, None, None, None, None, None,
        )
        _create_async_db(db_path, [row])

        results = parse_async_behavior_results(
            db_path, test_files, test_config
        )
        assert len(list(results)) == 0

    def test_handles_corrupt_db(
        self, tmp_path, test_files, test_config
    ) -> None:
        """Returns empty results for a corrupted database file."""
        bad_path = tmp_path / "corrupt.sqlite"
        bad_path.write_text("not a sqlite file", encoding="utf-8")
        results = parse_async_behavior_results(
            bad_path, test_files, test_config
        )
        assert len(list(results)) == 0
|
||||
|
||||
|
||||
class TestCalculateAsyncThroughput:
    """calculate_async_throughput counts performance rows."""

    def test_returns_zero_when_file_missing(self, tmp_path) -> None:
        """Returns 0 when SQLite file does not exist."""
        missing = tmp_path / "nonexistent.sqlite"
        assert calculate_async_throughput(missing, "func") == 0

    def test_counts_performance_rows(self, tmp_path) -> None:
        """Counts rows matching function name and performance mode."""
        db_path = tmp_path / "async_results.sqlite"
        rows = [
            # Two matching performance rows for "target" ...
            ("mod", None, "test_fn", "target", 1, "0_0", "performance",
             1000, None, None, None, None, None),
            ("mod", None, "test_fn", "target", 1, "0_1", "performance",
             2000, None, None, None, None, None),
            # ... one for a different function and one in behavior mode,
            # neither of which should be counted.
            ("mod", None, "test_fn", "other_func", 1, "0_0", "performance",
             3000, None, None, None, None, None),
            ("mod", None, "test_fn", "target", 1, "0_0", "behavior",
             4000, None, None, None, None, None),
        ]
        _create_async_db(db_path, rows)
        assert calculate_async_throughput(db_path, "target") == 2

    def test_returns_zero_for_no_matches(self, tmp_path) -> None:
        """Returns 0 when no rows match the function name."""
        db_path = tmp_path / "async_results.sqlite"
        _create_async_db(db_path, [])
        assert calculate_async_throughput(db_path, "nonexistent") == 0
|
||||
|
||||
|
||||
class TestParseAsyncConcurrencyMetrics:
    """parse_async_concurrency_metrics reads concurrency rows."""

    def test_returns_none_when_file_missing(self, tmp_path) -> None:
        """Returns None when SQLite file does not exist."""
        missing = tmp_path / "nonexistent.sqlite"
        assert parse_async_concurrency_metrics(missing, "func") is None

    def test_parses_concurrency_metrics(self, tmp_path) -> None:
        """Computes averages from multiple concurrency rows."""
        db_path = tmp_path / "async_results.sqlite"
        rows = [
            ("mod", None, "test_fn", "target", 1, "", "concurrency", 0,
             None, None, 100_000, 50_000, 10),
            ("mod", None, "test_fn", "target", 1, "", "concurrency", 0,
             None, None, 200_000, 100_000, 10),
        ]
        _create_async_db(db_path, rows)

        metrics = parse_async_concurrency_metrics(db_path, "target")
        assert metrics is not None
        # Averages of (100k, 200k) and (50k, 100k); ratio = 150k / 75k.
        assert metrics.sequential_time_ns == 150_000
        assert metrics.concurrent_time_ns == 75_000
        assert metrics.concurrency_factor == 10
        assert metrics.concurrency_ratio == 2.0

    def test_returns_none_for_wrong_function(self, tmp_path) -> None:
        """Returns None when no rows match the function name."""
        db_path = tmp_path / "async_results.sqlite"
        _create_async_db(
            db_path,
            [("mod", None, "test_fn", "other_func", 1, "", "concurrency",
              0, None, None, 100_000, 50_000, 10)],
        )
        assert parse_async_concurrency_metrics(db_path, "target") is None

    def test_handles_zero_concurrent_time(self, tmp_path) -> None:
        """Returns ratio 1.0 when concurrent time is zero."""
        db_path = tmp_path / "async_results.sqlite"
        _create_async_db(
            db_path,
            [("mod", None, "test_fn", "target", 1, "", "concurrency",
              0, None, None, 100_000, 0, 5)],
        )
        metrics = parse_async_concurrency_metrics(db_path, "target")
        assert metrics is not None
        assert metrics.concurrency_ratio == 1.0

    def test_corrupt_db_returns_none(self, tmp_path) -> None:
        """Returns None for a corrupted database file."""
        bad_path = tmp_path / "corrupt.sqlite"
        bad_path.write_text("not a sqlite file", encoding="utf-8")
        assert parse_async_concurrency_metrics(bad_path, "target") is None
|
||||
|
||||
|
||||
class TestCalculateAsyncThroughputErrorHandling:
    """Error handling in calculate_async_throughput."""

    def test_corrupt_db_returns_zero(self, tmp_path) -> None:
        """Returns 0 for a corrupted database file."""
        bad_path = tmp_path / "corrupt.sqlite"
        bad_path.write_text("not a sqlite file", encoding="utf-8")
        assert calculate_async_throughput(bad_path, "func") == 0
|
||||
|
||||
|
||||
class TestParseAsyncBehaviorEdgeCases:
    """Edge cases for parse_async_behavior_results."""

    @staticmethod
    def _make_db(tmp_path, return_blob, verification):
        """Create the dummy test-module file plus a one-row async_results DB.

        Every test in this class uses the same row shape and only varies the
        pickled return value and the verification type, so the row-building
        boilerplate lives here.
        """
        # The parser resolves "test_module" against an on-disk file, so an
        # empty placeholder module must exist next to the DB.
        (tmp_path / "test_module.py").write_text("", encoding="utf-8")
        db_path = tmp_path / "async_results.sqlite"
        row = (
            "test_module", None, "test_fn", "target", 1, "0_0",
            "behavior", 1000, return_blob, verification, None, None, None,
        )
        _create_async_db(db_path, [row])
        return db_path

    def test_init_state_verification_type(
        self, tmp_path, test_files, test_config
    ) -> None:
        """Rows with INIT_STATE_FTO get INIT_STATE_TEST type."""
        db_path = self._make_db(
            tmp_path,
            pickle.dumps(((1,), {}, 2)),
            VerificationType.INIT_STATE_FTO.value,
        )
        invocations = list(
            parse_async_behavior_results(db_path, test_files, test_config)
        )
        assert len(invocations) == 1

    def test_test_type_fallback_to_instrumented(
        self, tmp_path, test_config
    ) -> None:
        """Falls back to instrumented file path for test type lookup."""
        from codeflash_python.test_discovery.models import TestType

        # Primary lookup misses; only the instrumented-path lookup resolves.
        tf = MagicMock()
        tf.get_test_type_by_original_file_path.return_value = None
        tf.get_test_type_by_instrumented_file_path.return_value = (
            TestType.EXISTING_UNIT_TEST
        )

        db_path = self._make_db(
            tmp_path,
            pickle.dumps(((1,), {}, 2)),
            VerificationType.FUNCTION_CALL.value,
        )
        assert len(list(parse_async_behavior_results(db_path, tf, test_config))) == 1

    def test_test_type_not_found_skips_row(self, tmp_path, test_config) -> None:
        """Skips rows when test type cannot be determined."""
        # Both lookups miss, so the row cannot be attributed to any test type.
        tf = MagicMock()
        tf.get_test_type_by_original_file_path.return_value = None
        tf.get_test_type_by_instrumented_file_path.return_value = None

        db_path = self._make_db(
            tmp_path,
            pickle.dumps(((1,), {}, 2)),
            VerificationType.FUNCTION_CALL.value,
        )
        assert len(list(parse_async_behavior_results(db_path, tf, test_config))) == 0

    def test_bad_pickle_data_skips_row(
        self, tmp_path, test_files, test_config
    ) -> None:
        """Skips rows with unpicklable return value data."""
        db_path = self._make_db(
            tmp_path,
            b"definitely not valid pickle",
            VerificationType.FUNCTION_CALL.value,
        )
        assert (
            len(list(parse_async_behavior_results(db_path, test_files, test_config)))
            == 0
        )

    def test_outer_exception_handler(self, tmp_path, test_config) -> None:
        """Outer exception handler catches errors from inner processing."""
        # The lookup itself raises; the parser must swallow it and yield nothing.
        tf = MagicMock()
        tf.get_test_type_by_original_file_path.side_effect = RuntimeError("boom")

        db_path = self._make_db(
            tmp_path, None, VerificationType.FUNCTION_CALL.value
        )
        assert len(list(parse_async_behavior_results(db_path, tf, test_config))) == 0
|
||||
508
packages/codeflash-python/tests/test_async_decorators.py
Normal file
508
packages/codeflash-python/tests/test_async_decorators.py
Normal file
|
|
@ -0,0 +1,508 @@
|
|||
"""Tests for the new self-contained async instrumentation decorators."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import dill as pickle
|
||||
import pytest
|
||||
|
||||
from codeflash_python.runtime._codeflash_async_decorators import (
|
||||
VerificationType,
|
||||
_close_all_connections,
|
||||
_codeflash_call_site,
|
||||
_connections,
|
||||
_get_async_db,
|
||||
codeflash_behavior_async,
|
||||
codeflash_concurrency_async,
|
||||
codeflash_performance_async,
|
||||
extract_test_context_from_env,
|
||||
get_run_tmp_file,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="env_setup")
def _env_setup(request, tmp_path):
    """Install the CODEFLASH_* env vars for one test and restore them after."""
    wanted = {
        "CODEFLASH_LOOP_INDEX": "1",
        "CODEFLASH_TEST_ITERATION": "0",
        "CODEFLASH_TEST_MODULE": __name__,
        "CODEFLASH_TEST_CLASS": "",
        "CODEFLASH_TEST_FUNCTION": request.node.name,
    }
    # Snapshot the pre-test values so teardown can restore them exactly.
    saved = {name: os.environ.get(name) for name in wanted}
    os.environ.update(wanted)

    yield wanted

    for name, previous in saved.items():
        if previous is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = previous
    # Drop any cached SQLite connections opened by the decorators under test.
    _close_all_connections()
|
||||
|
||||
|
||||
@pytest.fixture(name="async_db_path")
def _async_db_path(env_setup):
    """Yield the expected async-results DB path; remove the file on teardown."""
    iteration = env_setup["CODEFLASH_TEST_ITERATION"]
    path = get_run_tmp_file(Path(f"async_results_{iteration}.sqlite"))
    yield path
    if path.exists():
        path.unlink()
|
||||
|
||||
|
||||
class TestExtractTestContextFromEnv:
    """extract_test_context_from_env reads CODEFLASH_TEST_* env vars."""

    @staticmethod
    def _patched_env(module, klass, function):
        """Return a patch.dict context setting the three CODEFLASH_TEST_* vars."""
        return patch.dict(
            os.environ,
            {
                "CODEFLASH_TEST_MODULE": module,
                "CODEFLASH_TEST_CLASS": klass,
                "CODEFLASH_TEST_FUNCTION": function,
            },
        )

    def test_returns_tuple(self) -> None:
        """Returns (module, class_or_none, function) from env."""
        with self._patched_env("mod", "Cls", "test_fn"):
            assert extract_test_context_from_env() == ("mod", "Cls", "test_fn")

    def test_class_none_when_empty(self) -> None:
        """Returns None for test class when env var is empty."""
        with self._patched_env("mod", "", "test_fn"):
            _, klass, _ = extract_test_context_from_env()
            assert klass is None

    def test_raises_when_module_missing(self) -> None:
        """Raises KeyError when CODEFLASH_TEST_MODULE is unset."""
        stripped = {
            k: v for k, v in os.environ.items() if k != "CODEFLASH_TEST_MODULE"
        }
        with patch.dict(os.environ, stripped, clear=True), pytest.raises(KeyError):
            extract_test_context_from_env()
|
||||
|
||||
|
||||
class TestCodeflashCallSite:
    """_codeflash_call_site contextvar behavior."""

    def test_default_value(self) -> None:
        """Default value is empty string."""
        # Read through a copied context so any value set by a previous test
        # in this context does not leak into the observation.
        fresh = contextvars.copy_context()
        assert fresh.run(_codeflash_call_site.get) == ""

    def test_set_and_get(self) -> None:
        """Can set and retrieve a value."""
        marker = _codeflash_call_site.set("42")
        assert _codeflash_call_site.get() == "42"
        _codeflash_call_site.reset(marker)
|
||||
|
||||
|
||||
class TestGetAsyncDb:
    """_get_async_db connection caching."""

    def test_creates_table(self, tmp_path) -> None:
        """Creates the async_results table on first connect."""
        db_path = tmp_path / "test.sqlite"
        conn, cursor = _get_async_db(db_path)
        cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='async_results'"
        )
        assert cursor.fetchone() is not None
        conn.close()
        # Evict the closed connection so later tests get a fresh one.
        _connections.pop(str(db_path), None)

    def test_caches_connection(self, tmp_path) -> None:
        """Returns the same connection on repeated calls."""
        db_path = tmp_path / "test.sqlite"
        first, _ = _get_async_db(db_path)
        second, _ = _get_async_db(db_path)
        assert second is first
        first.close()
        _connections.pop(str(db_path), None)

    def test_close_all_connections(self, tmp_path) -> None:
        """_close_all_connections empties the cache."""
        _get_async_db(tmp_path / "test.sqlite")
        assert _connections
        _close_all_connections()
        assert not _connections
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    sys.platform == "win32",
    reason="pending support for asyncio on windows",
)
class TestBehaviorAsync:
    """codeflash_behavior_async decorator."""

    @pytest.mark.asyncio
    async def test_returns_correct_value(
        self, env_setup, async_db_path
    ) -> None:
        """Decorated function returns the original return value."""

        @codeflash_behavior_async
        async def add(a: int, b: int) -> int:
            return a + b

        _codeflash_call_site.set("0")
        assert await add(3, 4) == 7

    @pytest.mark.asyncio
    async def test_writes_to_sqlite(self, env_setup, async_db_path) -> None:
        """Writes behavior result to async_results table."""

        @codeflash_behavior_async
        async def multiply(a: int, b: int) -> int:
            return a * b

        _codeflash_call_site.set("0")
        await multiply(5, 6)
        # Flush the decorator's cached connection so the file is readable.
        _close_all_connections()

        assert async_db_path.exists()
        conn = sqlite3.connect(async_db_path)
        rows = conn.execute("SELECT * FROM async_results").fetchall()
        assert len(rows) == 1
        (row,) = rows
        assert row[6] == "behavior"
        assert row[7] > 0  # wall_time_ns must be recorded

        # The blob stores (args, kwargs, return_value) as one pickled tuple.
        args, kwargs, ret = pickle.loads(row[8])
        assert args == (5, 6)
        assert kwargs == {}
        assert ret == 30
        assert row[9] == VerificationType.FUNCTION_CALL.value
        conn.close()

    @pytest.mark.asyncio
    async def test_exception_handling(self, env_setup, async_db_path) -> None:
        """Re-raises exceptions and stores them pickled."""

        @codeflash_behavior_async
        async def fail() -> None:
            raise ValueError("boom")

        _codeflash_call_site.set("0")
        with pytest.raises(ValueError, match="boom"):
            await fail()

        _close_all_connections()
        conn = sqlite3.connect(async_db_path)
        stored = conn.execute("SELECT return_value FROM async_results").fetchone()
        exc = pickle.loads(stored[0])
        assert isinstance(exc, ValueError)
        assert str(exc) == "boom"
        conn.close()

    @pytest.mark.asyncio
    async def test_no_stdout_output(
        self, env_setup, async_db_path, capsys
    ) -> None:
        """Behavior decorator emits no stdout."""

        @codeflash_behavior_async
        async def noop() -> int:
            return 1

        _codeflash_call_site.set("0")
        await noop()
        assert capsys.readouterr().out == ""
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    sys.platform == "win32",
    reason="pending support for asyncio on windows",
)
class TestPerformanceAsync:
    """codeflash_performance_async decorator."""

    @pytest.mark.asyncio
    async def test_returns_correct_value(
        self, env_setup, async_db_path
    ) -> None:
        """Returns the original return value."""

        @codeflash_performance_async
        async def add(a: int, b: int) -> int:
            return a + b

        _codeflash_call_site.set("0")
        assert await add(3, 4) == 7

    @pytest.mark.asyncio
    async def test_writes_to_sqlite(self, env_setup, async_db_path) -> None:
        """Writes performance result with null return_value."""

        @codeflash_performance_async
        async def work() -> int:
            return 42

        _codeflash_call_site.set("0")
        await work()
        _close_all_connections()

        conn = sqlite3.connect(async_db_path)
        rows = conn.execute("SELECT * FROM async_results").fetchall()
        assert len(rows) == 1
        (row,) = rows
        assert row[6] == "performance"
        assert row[7] > 0  # timing is captured
        # Performance mode records timing only — no behavior payload.
        assert row[8] is None
        assert row[9] is None
        conn.close()

    @pytest.mark.asyncio
    async def test_no_stdout_output(
        self, env_setup, async_db_path, capsys
    ) -> None:
        """Performance decorator emits no stdout."""

        @codeflash_performance_async
        async def noop() -> int:
            return 1

        _codeflash_call_site.set("0")
        await noop()
        assert capsys.readouterr().out == ""
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    sys.platform == "win32",
    reason="pending support for asyncio on windows",
)
class TestConcurrencyAsync:
    """codeflash_concurrency_async decorator.

    Env-var handling note: the concurrency factor is supplied through
    CODEFLASH_CONCURRENCY_FACTOR. Setting it with a bare assignment and
    popping it at the end of the test leaks the variable into later tests
    whenever an assertion in between fails, so these tests use
    patch.dict(os.environ, ...) which restores the environment on every
    exit path (patch is already imported at module scope).
    """

    @pytest.mark.asyncio
    async def test_returns_correct_value(
        self, env_setup, async_db_path
    ) -> None:
        """Returns the result from sequential execution."""

        @codeflash_concurrency_async
        async def add(a: int, b: int) -> int:
            return a + b

        result = await add(3, 4)
        assert result == 7

    @pytest.mark.asyncio
    async def test_writes_concurrency_metrics(
        self, env_setup, async_db_path
    ) -> None:
        """Writes sequential/concurrent timing to async_results."""
        with patch.dict(os.environ, {"CODEFLASH_CONCURRENCY_FACTOR": "3"}):

            @codeflash_concurrency_async
            async def work() -> int:
                await asyncio.sleep(0.001)
                return 42

            await work()
            _close_all_connections()

            con = sqlite3.connect(async_db_path)
            cur = con.cursor()
            cur.execute("SELECT * FROM async_results")
            rows = cur.fetchall()
            assert len(rows) == 1
            row = rows[0]
            assert row[6] == "concurrency"
            assert row[7] == 0  # wall_time_ns unused in concurrency mode
            assert row[10] > 0  # sequential_time_ns
            assert row[11] > 0  # concurrent_time_ns
            assert row[12] == 3  # concurrency_factor echoed from the env var
            con.close()

    @pytest.mark.asyncio
    async def test_no_stdout_output(
        self, env_setup, async_db_path, capsys
    ) -> None:
        """Concurrency decorator emits no stdout."""
        with patch.dict(os.environ, {"CODEFLASH_CONCURRENCY_FACTOR": "2"}):

            @codeflash_concurrency_async
            async def noop() -> int:
                return 1

            await noop()
            captured = capsys.readouterr()
            assert captured.out == ""
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    sys.platform == "win32",
    reason="pending support for asyncio on windows",
)
class TestBehaviorAsyncEdgeCases:
    """Edge cases for codeflash_behavior_async."""

    @pytest.mark.asyncio
    async def test_multiple_calls_increment_index(
        self, env_setup, async_db_path
    ) -> None:
        """Second call to same decorator increments the invocation counter."""

        @codeflash_behavior_async
        async def inc(x: int) -> int:
            return x + 1

        _codeflash_call_site.set("0")
        await inc(1)
        await inc(2)
        _close_all_connections()

        conn = sqlite3.connect(async_db_path)
        cursor = conn.execute(
            "SELECT invocation_id FROM async_results ORDER BY invocation_id"
        )
        # IDs are "<call_site>_<counter>", so the counter part must advance.
        invocation_ids = [record[0] for record in cursor.fetchall()]
        assert invocation_ids == ["0_0", "0_1"]
        conn.close()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    sys.platform == "win32",
    reason="pending support for asyncio on windows",
)
class TestPerformanceAsyncEdgeCases:
    """Edge cases for codeflash_performance_async."""

    @pytest.mark.asyncio
    async def test_exception_handling(self, env_setup, async_db_path) -> None:
        """Re-raises exceptions from the wrapped function."""

        @codeflash_performance_async
        async def fail() -> None:
            raise ValueError("perf boom")

        _codeflash_call_site.set("0")
        with pytest.raises(ValueError, match="perf boom"):
            await fail()

    @pytest.mark.asyncio
    async def test_multiple_calls_increment_index(
        self, env_setup, async_db_path
    ) -> None:
        """Second call increments the invocation counter."""

        @codeflash_performance_async
        async def work() -> int:
            return 1

        _codeflash_call_site.set("0")
        await work()
        await work()
        _close_all_connections()

        conn = sqlite3.connect(async_db_path)
        cursor = conn.execute(
            "SELECT invocation_id FROM async_results ORDER BY invocation_id"
        )
        invocation_ids = [record[0] for record in cursor.fetchall()]
        assert invocation_ids[0] == "0_0"
        assert invocation_ids[1] == "0_1"
        conn.close()
|
||||
|
||||
|
||||
class TestCloseAllConnectionsErrorHandling:
    """_close_all_connections handles exceptions gracefully."""

    def test_handles_already_closed_connection(self, tmp_path) -> None:
        """Does not raise when a connection is already closed."""
        conn, _ = _get_async_db(tmp_path / "test.sqlite")
        conn.close()  # closing twice would raise if unguarded
        _close_all_connections()
        assert not _connections
|
||||
|
||||
|
||||
class TestExtractTestContextEdgeCases:
    """Edge cases for extract_test_context_from_env."""

    def test_raises_when_module_empty(self) -> None:
        """Raises RuntimeError when module is empty string."""
        empty_module_env = {
            "CODEFLASH_TEST_MODULE": "",
            "CODEFLASH_TEST_CLASS": "",
            "CODEFLASH_TEST_FUNCTION": "test_fn",
        }
        with (
            patch.dict(os.environ, empty_module_env),
            pytest.raises(RuntimeError),
        ):
            extract_test_context_from_env()
|
||||
|
||||
|
||||
class TestSchemaValidation:
    """Validate the async_results SQLite schema."""

    def test_table_columns(self, tmp_path) -> None:
        """async_results table has exactly 13 columns."""
        db_path = tmp_path / "schema_test.sqlite"
        conn, cursor = _get_async_db(db_path)
        cursor.execute("PRAGMA table_info(async_results)")
        # PRAGMA table_info rows are (cid, name, type, notnull, dflt, pk);
        # index 1 is the column name.
        actual_names = [column[1] for column in cursor.fetchall()]
        assert actual_names == [
            "test_module_path",
            "test_class_name",
            "test_function_name",
            "function_getting_tested",
            "loop_index",
            "invocation_id",
            "mode",
            "wall_time_ns",
            "return_value",
            "verification_type",
            "sequential_time_ns",
            "concurrent_time_ns",
            "concurrency_factor",
        ]
        conn.close()
        _connections.pop(str(db_path), None)
|
||||
|
|
@ -499,7 +499,7 @@ class TestAsyncCallInstrumenter:
|
|||
return transformer, tree
|
||||
|
||||
def test_instruments_await_call(self) -> None:
|
||||
"""Adds env var assignment before an awaited target call."""
|
||||
"""Adds call-site contextvar set before an awaited target call."""
|
||||
code = textwrap.dedent("""\
|
||||
async def test_it():
|
||||
result = await target_func(1, 2)
|
||||
|
|
@ -507,7 +507,7 @@ class TestAsyncCallInstrumenter:
|
|||
transformer, tree = self._make_transformer(code)
|
||||
new_tree = transformer.visit(tree)
|
||||
source = ast.unparse(new_tree)
|
||||
assert "CODEFLASH_CURRENT_LINE_ID" in source
|
||||
assert "_codeflash_call_site.set(" in source
|
||||
assert transformer.did_instrument is True
|
||||
|
||||
def test_skips_non_test_async_functions(self) -> None:
|
||||
|
|
@ -519,7 +519,7 @@ class TestAsyncCallInstrumenter:
|
|||
transformer, tree = self._make_transformer(code)
|
||||
new_tree = transformer.visit(tree)
|
||||
source = ast.unparse(new_tree)
|
||||
assert "CODEFLASH_CURRENT_LINE_ID" not in source
|
||||
assert "_codeflash_call_site" not in source
|
||||
assert transformer.did_instrument is False
|
||||
|
||||
def test_skips_non_test_sync_functions(self) -> None:
|
||||
|
|
@ -531,7 +531,7 @@ class TestAsyncCallInstrumenter:
|
|||
transformer, tree = self._make_transformer(code)
|
||||
new_tree = transformer.visit(tree)
|
||||
source = ast.unparse(new_tree)
|
||||
assert "CODEFLASH_CURRENT_LINE_ID" not in source
|
||||
assert "_codeflash_call_site" not in source
|
||||
|
||||
def test_instruments_sync_test_with_await(self) -> None:
|
||||
"""Instruments sync test_ functions that contain awaited calls."""
|
||||
|
|
@ -542,7 +542,7 @@ class TestAsyncCallInstrumenter:
|
|||
transformer, tree = self._make_transformer(code)
|
||||
new_tree = transformer.visit(tree)
|
||||
source = ast.unparse(new_tree)
|
||||
assert "CODEFLASH_CURRENT_LINE_ID" in source
|
||||
assert "_codeflash_call_site.set(" in source
|
||||
assert transformer.did_instrument is True
|
||||
|
||||
def test_multiple_awaits_get_incrementing_ids(self) -> None:
|
||||
|
|
@ -569,7 +569,7 @@ class TestAsyncCallInstrumenter:
|
|||
transformer, tree = self._make_transformer(code)
|
||||
new_tree = transformer.visit(tree)
|
||||
source = ast.unparse(new_tree)
|
||||
assert "CODEFLASH_CURRENT_LINE_ID" in source
|
||||
assert "_codeflash_call_site.set(" in source
|
||||
assert transformer.did_instrument is True
|
||||
|
||||
def test_recurses_into_class_body(self) -> None:
|
||||
|
|
@ -582,7 +582,7 @@ class TestAsyncCallInstrumenter:
|
|||
transformer, tree = self._make_transformer(code)
|
||||
new_tree = transformer.visit(tree)
|
||||
source = ast.unparse(new_tree)
|
||||
assert "CODEFLASH_CURRENT_LINE_ID" in source
|
||||
assert "_codeflash_call_site.set(" in source
|
||||
assert transformer.did_instrument is True
|
||||
|
||||
def test_no_match_when_position_wrong(self) -> None:
|
||||
|
|
@ -608,7 +608,7 @@ class TestAsyncCallInstrumenter:
|
|||
transformer, tree = self._make_transformer(code)
|
||||
new_tree = transformer.visit(tree)
|
||||
source = ast.unparse(new_tree)
|
||||
assert "CODEFLASH_CURRENT_LINE_ID" in source
|
||||
assert "_codeflash_call_site.set(" in source
|
||||
|
||||
def test_ignores_non_target_awaits(self) -> None:
|
||||
"""Does not instrument awaits of unrelated functions."""
|
||||
|
|
@ -1083,7 +1083,10 @@ class TestInjectAsyncProfilingIntoExistingTest:
|
|||
)
|
||||
assert ok is True
|
||||
assert source is not None
|
||||
assert "import os" in source
|
||||
assert (
|
||||
"from codeflash_async_wrapper import _codeflash_call_site"
|
||||
in source
|
||||
)
|
||||
|
||||
def test_no_instrumentation(self, tmp_path: Path) -> None:
|
||||
"""Returns (False, None) when test does not call target."""
|
||||
|
|
@ -1148,7 +1151,7 @@ class TestInjectAsyncProfilingIntoExistingTest:
|
|||
)
|
||||
assert ok is True
|
||||
assert source is not None
|
||||
assert source.count("CODEFLASH_CURRENT_LINE_ID") == 2
|
||||
assert source.count("_codeflash_call_site.set(") == 2
|
||||
|
||||
|
||||
class TestAddAsyncDecoratorToFunction:
|
||||
|
|
|
|||
Loading…
Reference in a new issue