fix: capture stdout in async decorator and fix result merger

The async behavior decorator now captures stdout per invocation via
io.StringIO into a new `stdout` column in the async_results SQLite
table. The result merger prefers data-sourced stdout over XML stdout,
fixing the root cause of empty stdout in merged async results.

Also fixes a duplicate async parse block in _parse_results.py and
CODEFLASH_RUN_TMPDIR propagation to subprocesses, and removes dead
async code from _stdout_parsers.py and _codeflash_wrap_decorator.py.
Kevin Turcios 2026-04-24 04:35:02 -05:00
parent 629d7f9f08
commit c9f65aba6b
14 changed files with 172 additions and 1296 deletions
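For orientation, the per-invocation capture pattern the behavior decorator
now uses, as a minimal standalone sketch (illustrative helper, not the
decorator itself):

    import io
    import sys

    def call_with_stdout_capture(fn, *args, **kwargs):
        captured = io.StringIO()
        old_stdout = sys.stdout
        sys.stdout = captured           # redirect prints for this call only
        try:
            result = fn(*args, **kwargs)
        finally:
            sys.stdout = old_stdout     # always restore, even on exception
        return result, captured.getvalue()

    value, stdout_text = call_with_stdout_capture(print, "hello")
    assert stdout_text == "hello\n"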

View file

@ -7,6 +7,7 @@ running concurrency benchmarks, and evaluating async candidates.
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
import attrs
@ -21,8 +22,6 @@ from ..testing._parse_results import parse_test_results
from ..testing._test_runner import async_run_benchmarking_tests
if TYPE_CHECKING:
from pathlib import Path
from .._model import FunctionToOptimize
from ..benchmarking.models import ConcurrencyMetrics
from ..context.models import CodeOptimizationContext
@ -49,12 +48,16 @@ async def collect_baseline_async_metrics( # noqa: PLR0913
Returns an evolved baseline with the metrics attached.
"""
from ..testing._stdout_parsers import ( # noqa: PLC0415
calculate_function_throughput_from_test_results,
from ..runtime._codeflash_wrap_decorator import ( # noqa: PLC0415
get_run_tmp_file,
)
from ..testing._async_data_parser import ( # noqa: PLC0415
calculate_async_throughput,
)
async_throughput = calculate_function_throughput_from_test_results(
baseline.benchmarking_test_results,
async_db = get_run_tmp_file(Path("async_results_0.sqlite"))
async_throughput = calculate_async_throughput(
async_db,
func.function_name,
)
log.info(
@ -103,13 +106,16 @@ async def run_concurrency_benchmark(
return None
from .._model import TestingMode # noqa: PLC0415
from ..runtime._codeflash_wrap_decorator import ( # noqa: PLC0415
get_run_tmp_file,
)
from ..testing._async_data_parser import ( # noqa: PLC0415
parse_async_concurrency_metrics,
)
from ..testing._instrumentation import ( # noqa: PLC0415
add_async_decorator_to_function,
revert_instrumented_files,
)
from ..testing._stdout_parsers import ( # noqa: PLC0415
parse_concurrency_metrics,
)
originals: dict[Path, str] = {}
try:
@ -138,7 +144,7 @@ async def run_concurrency_benchmark(
max_loops=3,
target_duration_seconds=5.0,
)
bench_results = parse_test_results(
parse_test_results(
test_xml_path=bench_xml,
test_files=test_files,
test_config=ctx.test_cfg,
@ -155,8 +161,12 @@ async def run_concurrency_benchmark(
if originals:
revert_instrumented_files(originals)
return parse_concurrency_metrics(
bench_results,
iteration = 0
async_db = get_run_tmp_file(
Path(f"async_results_{iteration}.sqlite"),
)
return parse_async_concurrency_metrics(
async_db,
func.function_name,
)
@ -176,8 +186,11 @@ async def evaluate_async_candidate( # noqa: PLR0913
Returns *(speedup, acceptance_reason)*. *speedup* is ``None``
when the candidate is rejected.
"""
from ..testing._stdout_parsers import ( # noqa: PLC0415
calculate_function_throughput_from_test_results,
from ..runtime._codeflash_wrap_decorator import ( # noqa: PLC0415
get_run_tmp_file,
)
from ..testing._async_data_parser import ( # noqa: PLC0415
calculate_async_throughput,
)
from ..verification._critic import ( # noqa: PLC0415
get_acceptance_reason,
@ -189,8 +202,12 @@ async def evaluate_async_candidate( # noqa: PLR0913
from ._test_orchestrator import build_test_env # noqa: PLC0415
func = fn_input.function
candidate_throughput = calculate_function_throughput_from_test_results(
bench_results,
iteration = 0
async_db = get_run_tmp_file(
Path(f"async_results_{iteration}.sqlite"),
)
candidate_throughput = calculate_async_throughput(
async_db,
func.function_name,
)
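The throughput numbers above now come from the SQLite table rather than
stdout markers. As a rough sketch of the kind of query involved (not the
actual calculate_async_throughput implementation; column names follow the
schema in the decorator module):

    import sqlite3
    from pathlib import Path

    def count_invocations(db: Path, function_name: str) -> int:
        # One row is inserted per completed invocation, so throughput
        # reduces to a row count filtered by function name.
        con = sqlite3.connect(db)
        try:
            (count,) = con.execute(
                "SELECT COUNT(*) FROM async_results"
                " WHERE function_getting_tested = ?",
                (function_name,),
            ).fetchone()
        finally:
            con.close()
        return count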

View file

@ -12,8 +12,10 @@ import asyncio
import atexit
import contextvars
import gc
import io
import os
import sqlite3
import sys
import time
from enum import Enum
from functools import wraps
@ -36,12 +38,22 @@ F = TypeVar("F", bound=Callable[..., Any])
def get_run_tmp_file(file_path: Path) -> Path:
"""Return a path inside a persistent per-run temporary directory."""
"""Return a path inside a persistent per-run temporary directory.
Uses ``CODEFLASH_RUN_TMPDIR`` if set (subprocess case), otherwise
creates a new tmpdir and exports the env var so child processes
share the same directory.
"""
if not hasattr(get_run_tmp_file, "tmpdir"):
get_run_tmp_file.tmpdir = TemporaryDirectory( # type: ignore[attr-defined]
prefix="codeflash_"
)
return Path(get_run_tmp_file.tmpdir.name) / file_path # type: ignore[attr-defined]
env_dir = os.environ.get("CODEFLASH_RUN_TMPDIR")
if env_dir and Path(env_dir).is_dir():
get_run_tmp_file.tmpdir = env_dir # type: ignore[attr-defined]
else:
td = TemporaryDirectory(prefix="codeflash_")
get_run_tmp_file.tmpdir = td.name # type: ignore[attr-defined]
get_run_tmp_file.td_ref = td # type: ignore[attr-defined]
os.environ["CODEFLASH_RUN_TMPDIR"] = td.name
return Path(get_run_tmp_file.tmpdir) / file_path # type: ignore[attr-defined]
def extract_test_context_from_env() -> tuple[str, str | None, str]:
@ -82,7 +94,8 @@ _CREATE_TABLE_SQL = (
"verification_type TEXT, "
"sequential_time_ns INTEGER, "
"concurrent_time_ns INTEGER, "
"concurrency_factor INTEGER"
"concurrency_factor INTEGER, "
"stdout TEXT"
")"
)
@ -155,6 +168,9 @@ def codeflash_behavior_async(func: F) -> F:
conn, cur = _get_async_db(db_path)
exception = None
captured_stdout = io.StringIO()
old_stdout = sys.stdout
sys.stdout = captured_stdout
counter = loop.time()
gc.disable()
try:
@ -167,6 +183,9 @@ def codeflash_behavior_async(func: F) -> F:
exception = e
finally:
gc.enable()
sys.stdout = old_stdout
stdout_text = captured_stdout.getvalue()
pickled = (
pickle.dumps(exception)
@ -175,7 +194,7 @@ def codeflash_behavior_async(func: F) -> F:
)
cur.execute(
"INSERT INTO async_results VALUES "
"(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
"(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
test_module_name,
test_class_name,
@ -190,6 +209,7 @@ def codeflash_behavior_async(func: F) -> F:
None,
None,
None,
stdout_text,
),
)
conn.commit()
@ -254,7 +274,7 @@ def codeflash_performance_async(func: F) -> F:
cur.execute(
"INSERT INTO async_results VALUES "
"(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
"(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
test_module_name,
test_class_name,
@ -269,6 +289,7 @@ def codeflash_performance_async(func: F) -> F:
None,
None,
None,
None,
),
)
conn.commit()
@ -324,7 +345,7 @@ def codeflash_concurrency_async(func: F) -> F:
cur.execute(
"INSERT INTO async_results VALUES "
"(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
"(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
test_module_name,
test_class_name,
@ -339,6 +360,7 @@ def codeflash_concurrency_async(func: F) -> F:
sequential_time,
concurrent_time,
concurrency_factor,
None,
),
)
conn.commit()
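Usage-wise, the env-var handoff means parent and child processes resolve
the same run directory. A hedged illustration (comments describe the
cross-process case):

    from pathlib import Path

    # Parent process: first call creates the tmpdir and exports
    # CODEFLASH_RUN_TMPDIR for any children spawned afterwards.
    parent_db = get_run_tmp_file(Path("async_results_0.sqlite"))

    # Child process (inheriting the parent's environment): the env var
    # is set, so the same directory is reused instead of a fresh one.
    child_db = get_run_tmp_file(Path("async_results_0.sqlite"))
    # parent_db == child_db across the process boundary.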

View file

@ -1,42 +1,47 @@
"""Async wrapper decorators for behavior, performance, and concurrency testing."""
"""Shared runtime helpers used by sync instrumentation.
Async decorators have moved to ``_codeflash_async_decorators.py``.
This module retains ``VerificationType``, ``get_run_tmp_file``, and
``extract_test_context_from_env`` which are still used by the sync
capture path (``_codeflash_capture.py``) and multiple test/analysis
modules.
"""
# ruff: noqa: T201, BLE001
from __future__ import annotations
import asyncio
import gc
import os
import sqlite3
import time
from enum import Enum
from functools import wraps
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Callable, TypeVar
import dill as pickle
class VerificationType(
str, Enum
): # moved from codeflash/verification/codeflash_capture.py
class VerificationType(str, Enum):
"""Type of correctness verification for captured test data."""
FUNCTION_CALL = "function_call" # Correctness verification for a test function, checks input values and output values)
INIT_STATE_FTO = "init_state_fto" # Correctness verification for fto class instance attributes after init
INIT_STATE_HELPER = "init_state_helper" # Correctness verification for helper class instance attributes after init
F = TypeVar("F", bound=Callable[..., Any])
FUNCTION_CALL = "function_call"
INIT_STATE_FTO = "init_state_fto"
INIT_STATE_HELPER = "init_state_helper"
def get_run_tmp_file(
file_path: Path,
) -> Path: # moved from codeflash/code_utils/code_utils.py
"""Return a path inside a persistent per-run temporary directory."""
) -> Path:
"""Return a path inside a persistent per-run temporary directory.
Uses ``CODEFLASH_RUN_TMPDIR`` if set (subprocess case), otherwise
creates a new tmpdir and exports the env var so child processes
share the same directory.
"""
if not hasattr(get_run_tmp_file, "tmpdir"):
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_") # type: ignore[attr-defined]
return Path(get_run_tmp_file.tmpdir.name) / file_path # type: ignore[attr-defined]
env_dir = os.environ.get("CODEFLASH_RUN_TMPDIR")
if env_dir and Path(env_dir).is_dir():
get_run_tmp_file.tmpdir = env_dir # type: ignore[attr-defined]
else:
td = TemporaryDirectory(prefix="codeflash_")
get_run_tmp_file.tmpdir = td.name # type: ignore[attr-defined]
get_run_tmp_file.td_ref = td # type: ignore[attr-defined]
os.environ["CODEFLASH_RUN_TMPDIR"] = td.name
return Path(get_run_tmp_file.tmpdir) / file_path # type: ignore[attr-defined]
def extract_test_context_from_env() -> tuple[str, str | None, str]:
@ -51,191 +56,3 @@ def extract_test_context_from_env() -> tuple[str, str | None, str]:
raise RuntimeError( # noqa: TRY003
"Test context environment variables not set - ensure tests are run through codeflash test runner" # noqa: EM101
)
def codeflash_behavior_async(func: F) -> F:
"""Decorator capturing async function return values and timing for behavioral tests."""
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
"""Await the wrapped coroutine and record its result to SQLite."""
loop = asyncio.get_running_loop()
function_name = func.__name__
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
test_module_name, test_class_name, test_name = (
extract_test_context_from_env()
)
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
if not hasattr(async_wrapper, "index"):
async_wrapper.index = {} # type: ignore[attr-defined]
if test_id in async_wrapper.index: # type: ignore[attr-defined]
async_wrapper.index[test_id] += 1 # type: ignore[attr-defined]
else:
async_wrapper.index[test_id] = 0 # type: ignore[attr-defined]
codeflash_test_index = async_wrapper.index[test_id] # type: ignore[attr-defined]
invocation_id = f"{line_id}_{codeflash_test_index}"
test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
print(f"!$######{test_stdout_tag}######$!")
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
db_path = get_run_tmp_file(
Path(f"test_return_values_{iteration}.sqlite")
)
codeflash_con = sqlite3.connect(db_path)
codeflash_cur = codeflash_con.cursor()
codeflash_cur.execute(
"CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, "
"test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, "
"runtime INTEGER, return_value BLOB, verification_type TEXT, cpu_runtime INTEGER)"
)
exception = None
counter = loop.time()
gc.disable()
try:
ret = func(
*args, **kwargs
) # coroutine creation has some overhead, though it is very small
counter = loop.time()
return_value = (
await ret
) # let's measure the actual execution time of the code
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
except Exception as e:
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
exception = e
finally:
gc.enable()
print(f"!######{test_stdout_tag}######!")
pickled_return_value = (
pickle.dumps(exception)
if exception
else pickle.dumps((args, kwargs, return_value))
)
codeflash_cur.execute(
"INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
test_module_name,
test_class_name,
test_name,
function_name,
loop_index,
invocation_id,
codeflash_duration,
pickled_return_value,
VerificationType.FUNCTION_CALL.value,
0,
),
)
codeflash_con.commit()
codeflash_con.close()
if exception:
raise exception
return return_value
return async_wrapper # type: ignore[return-value]
def codeflash_performance_async(func: F) -> F:
"""Decorator measuring async function execution time for performance tests."""
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
"""Await the wrapped coroutine and emit its timing via stdout."""
loop = asyncio.get_running_loop()
function_name = func.__name__
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
test_module_name, test_class_name, test_name = (
extract_test_context_from_env()
)
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
if not hasattr(async_wrapper, "index"):
async_wrapper.index = {} # type: ignore[attr-defined]
if test_id in async_wrapper.index: # type: ignore[attr-defined]
async_wrapper.index[test_id] += 1 # type: ignore[attr-defined]
else:
async_wrapper.index[test_id] = 0 # type: ignore[attr-defined]
codeflash_test_index = async_wrapper.index[test_id] # type: ignore[attr-defined]
invocation_id = f"{line_id}_{codeflash_test_index}"
test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
print(f"!$######{test_stdout_tag}######$!")
exception = None
counter = loop.time()
gc.disable()
try:
ret = func(*args, **kwargs)
counter = loop.time()
return_value = await ret
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
except Exception as e:
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
exception = e
finally:
gc.enable()
print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
if exception:
raise exception
return return_value
return async_wrapper # type: ignore[return-value]
def codeflash_concurrency_async(func: F) -> F:
"""Measures concurrent vs sequential execution performance for async functions."""
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
"""Run sequential then concurrent executions and emit timing metrics."""
function_name = func.__name__
concurrency_factor = int(
os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10")
)
test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "")
test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "")
loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0")
# Phase 1: Sequential execution timing
gc.disable()
try:
seq_start = time.perf_counter_ns()
for _ in range(concurrency_factor):
result = await func(*args, **kwargs)
sequential_time = time.perf_counter_ns() - seq_start
finally:
gc.enable()
# Phase 2: Concurrent execution timing
gc.disable()
try:
conc_start = time.perf_counter_ns()
tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
await asyncio.gather(*tasks)
concurrent_time = time.perf_counter_ns() - conc_start
finally:
gc.enable()
# Output parseable metrics
tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}"
print(
f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!"
)
return result
return async_wrapper # type: ignore[return-value]

View file

@ -23,7 +23,7 @@ _BEHAVIOR_QUERY = (
"SELECT test_module_path, test_class_name,"
" test_function_name, function_getting_tested,"
" loop_index, invocation_id, wall_time_ns,"
" return_value, verification_type"
" return_value, verification_type, stdout"
" FROM async_results"
" WHERE mode = 'behavior'"
)
@ -109,6 +109,7 @@ def _process_behavior_row_inner(
invocation_id = val[5]
wall_time_ns = val[6]
verification_type = val[8]
stdout_text = val[9] if len(val) > 9 else None
test_file_path = file_path_from_module_name(
test_module_path, # type: ignore[arg-type]
@ -173,6 +174,7 @@ def _process_behavior_row_inner(
if verification_type
else None
),
stdout=stdout_text or None,
),
)

View file

@ -11,6 +11,7 @@ from pathlib import Path
from typing import TYPE_CHECKING
from ..runtime._codeflash_wrap_decorator import get_run_tmp_file
from ._async_data_parser import parse_async_behavior_results
from ._data_parsers import parse_sqlite_test_results
from ._result_merger import merge_test_results
from ._stdout_parsers import parse_test_failures_from_stdout
@ -50,6 +51,20 @@ def parse_test_results(
sql_file, test_files, test_config
)
# Parse async SQLite results
async_sql_file = get_run_tmp_file(
Path(f"async_results_{optimization_iteration}.sqlite"),
)
if async_sql_file.exists():
async_results = parse_async_behavior_results(
async_sql_file,
test_files,
test_config,
)
for inv in async_results:
data_results.test_results.append(inv)
async_sql_file.unlink(missing_ok=True)
# Clean up deprecated binary pickle file if present
bin_file = get_run_tmp_file(
Path(f"test_return_values_{optimization_iteration}.bin"),

View file

@ -92,7 +92,7 @@ def _merge_single_xml(
if data_result.verification_type
else None
),
stdout=xml_result.stdout,
stdout=data_result.stdout or xml_result.stdout,
),
)
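Note that the `or` here treats an empty capture like a missing one; a quick
standalone illustration of the fallback semantics:

    assert ("hello\n" or "xml out") == "hello\n"  # data stdout wins when non-empty
    assert ("" or "xml out") == "xml out"         # empty capture falls back to XML
    assert (None or "xml out") == "xml out"       # missing column falls back too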

View file

@ -1,14 +1,8 @@
"""Stdout-based parsing: test failures and performance/concurrency metrics."""
"""Stdout-based parsing: test failure extraction from pytest output."""
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from ..benchmarking.models import ConcurrencyMetrics
if TYPE_CHECKING:
from .models import TestResults
TEST_HEADER_RE = re.compile(r"_{3,}\s*(.*?)\s*_{3,}$")
@ -78,82 +72,3 @@ def _collect_failures(
failures[current_name] = "".join(current_lines)
return failures
# -- Performance and concurrency metrics --
_perf_start_pattern = re.compile(
r"!\$######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+)######\$!",
)
_perf_end_pattern = re.compile(
r"!######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+):([^:]+)######!",
)
_concurrency_pattern = re.compile(
r"!@######CONC:"
r"([^:]*):([^:]*):([^:]*):([^:]*):([^:]*)"
r":(\d+):(\d+):(\d+)######@!",
)
def calculate_function_throughput_from_test_results(
test_results: TestResults,
function_name: str,
) -> int:
"""Count completed function executions from performance stdout markers."""
start_matches = _perf_start_pattern.findall(
test_results.perf_stdout or "",
)
end_matches = _perf_end_pattern.findall(
test_results.perf_stdout or "",
)
end_matches_truncated = [m[:5] for m in end_matches]
end_matches_set = set(end_matches_truncated)
count = 0
expected_fn_idx = 2
for start_match in start_matches:
if (
start_match in end_matches_set
and len(start_match) > expected_fn_idx
and start_match[expected_fn_idx] == function_name
):
count += 1
return count
def parse_concurrency_metrics(
test_results: TestResults,
function_name: str,
) -> ConcurrencyMetrics | None:
"""Parse concurrency benchmark results from test output."""
if not test_results.perf_stdout:
return None
matches = _concurrency_pattern.findall(test_results.perf_stdout)
if not matches:
return None
expected_groups = 8
total_seq, total_conc, factor, count = 0, 0, 0, 0
for match in matches:
if len(match) >= expected_groups and match[3] == function_name:
total_seq += int(match[5])
total_conc += int(match[6])
factor = int(match[7])
count += 1
if count == 0:
return None
avg_seq = total_seq / count
avg_conc = total_conc / count
ratio = avg_seq / avg_conc if avg_conc > 0 else 1.0
return ConcurrencyMetrics(
sequential_time_ns=int(avg_seq),
concurrent_time_ns=int(avg_conc),
concurrency_factor=factor,
concurrency_ratio=ratio,
)

View file

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
import logging
import os
import shlex
import subprocess
import sys
@ -23,6 +24,13 @@ _PER_FILE_TIMEOUT = 60
_MAX_TIMEOUT = 600
def _propagate_tmpdir(env: dict[str, str]) -> None:
"""Ensure CODEFLASH_RUN_TMPDIR is in the subprocess env."""
tmpdir = os.environ.get("CODEFLASH_RUN_TMPDIR")
if tmpdir:
env["CODEFLASH_RUN_TMPDIR"] = tmpdir
def _base_pytest_args(rootdir: Path | None, cwd: Path) -> list[str]:
"""Common pytest args shared across all test runner functions."""
return [
@ -48,6 +56,8 @@ def execute_test_subprocess(
timeout: int = 600,
) -> subprocess.CompletedProcess[str]:
"""Execute a subprocess with the given command list."""
if env is not None:
_propagate_tmpdir(env)
log.debug(
"executing test run with command: %s",
" ".join(cmd_list),
@ -355,6 +365,8 @@ async def async_execute_test_subprocess(
timeout: int = 600,
) -> subprocess.CompletedProcess[str]:
"""Execute a subprocess asynchronously."""
if env is not None:
_propagate_tmpdir(env)
log.debug(
"executing async test run with command: %s",
" ".join(cmd_list),

View file

@ -1,343 +0,0 @@
from __future__ import annotations
import asyncio
import os
import sys
import time
import pytest
from codeflash_python.benchmarking.models import ConcurrencyMetrics
from codeflash_python.runtime._codeflash_wrap_decorator import (
codeflash_concurrency_async,
)
from codeflash_python.testing._stdout_parsers import parse_concurrency_metrics
from codeflash_python.testing.models import TestResults
@pytest.mark.skipif(
sys.platform == "win32", reason="pending support for asyncio on windows"
)
class TestConcurrencyAsyncDecorator:
"""Integration tests for codeflash_concurrency_async decorator."""
@pytest.fixture
def concurrency_env_setup(self, request):
"""Set up environment variables for concurrency testing."""
original_env = {}
test_env = {
"CODEFLASH_LOOP_INDEX": "1",
"CODEFLASH_TEST_MODULE": __name__,
"CODEFLASH_TEST_CLASS": "TestConcurrencyAsyncDecorator",
"CODEFLASH_TEST_FUNCTION": request.node.name,
"CODEFLASH_CONCURRENCY_FACTOR": "5", # Use smaller factor for faster tests
}
for key, value in test_env.items():
original_env[key] = os.environ.get(key)
os.environ[key] = value
yield test_env
for key, original_value in original_env.items():
if original_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = original_value
@pytest.mark.asyncio
async def test_concurrency_decorator_nonblocking_function(
self, concurrency_env_setup, capsys
):
"""Test that non-blocking async functions show high concurrency ratio."""
@codeflash_concurrency_async
async def nonblocking_sleep(duration: float) -> str:
await asyncio.sleep(duration)
return "done"
result = await nonblocking_sleep(0.01)
assert result == "done"
captured = capsys.readouterr()
output = captured.out
# Verify the output format
assert "!@######CONC:" in output
assert "######@!" in output
# Parse the output manually to verify format
lines = [
line
for line in output.strip().split("\n")
if "!@######CONC:" in line
]
assert len(lines) == 1
line = lines[0]
# Format: !@######CONC:{test_module}:{test_class}:{test_function}:{function_name}:{loop_index}:{seq_time}:{conc_time}:{factor}######@!
assert "nonblocking_sleep" in line
assert ":5######@!" in line # concurrency factor
# Extract timing values
parts = (
line.replace("!@######CONC:", "")
.replace("######@!", "")
.split(":")
)
# parts should be: [test_module, test_class, test_function, function_name, loop_index, seq_time, conc_time, factor]
assert len(parts) == 8
seq_time = int(parts[5])
conc_time = int(parts[6])
factor = int(parts[7])
assert seq_time > 0
assert conc_time > 0
assert factor == 5
# For non-blocking async, concurrent time should be much less than sequential
# Sequential runs 5 iterations of 10ms = ~50ms
# Concurrent runs 5 iterations in parallel = ~10ms
# So ratio should be around 5 (with some overhead tolerance)
ratio = seq_time / conc_time if conc_time > 0 else 1.0
assert ratio > 2.0, (
f"Non-blocking function should have ratio > 2.0, got {ratio}"
)
@pytest.mark.asyncio
async def test_concurrency_decorator_blocking_function(
self, concurrency_env_setup, capsys
):
"""Test that blocking functions show low concurrency ratio (~1.0)."""
@codeflash_concurrency_async
async def blocking_sleep(duration: float) -> str:
time.sleep(duration) # Blocking sleep
return "done"
result = await blocking_sleep(0.005) # 5ms blocking
assert result == "done"
captured = capsys.readouterr()
output = captured.out
assert "!@######CONC:" in output
lines = [
line
for line in output.strip().split("\n")
if "!@######CONC:" in line
]
assert len(lines) == 1
line = lines[0]
parts = (
line.replace("!@######CONC:", "")
.replace("######@!", "")
.split(":")
)
assert len(parts) == 8
seq_time = int(parts[5])
conc_time = int(parts[6])
# For blocking code, sequential and concurrent times should be similar
# Because time.sleep blocks the entire event loop
ratio = seq_time / conc_time if conc_time > 0 else 1.0
# Blocking code should have ratio close to 1.0 (within reasonable tolerance)
assert ratio < 2.0, (
f"Blocking function should have ratio < 2.0, got {ratio}"
)
@pytest.mark.asyncio
async def test_concurrency_decorator_with_computation(
self, concurrency_env_setup, capsys
):
"""Test concurrency with CPU-bound computation."""
@codeflash_concurrency_async
async def compute_intensive(n: int) -> int:
# CPU-bound work (blocked by GIL in concurrent execution)
total = 0
for i in range(n):
total += i * i
return total
result = await compute_intensive(10000)
assert result == sum(i * i for i in range(10000))
captured = capsys.readouterr()
output = captured.out
assert "!@######CONC:" in output
assert "compute_intensive" in output
@pytest.mark.skipif(
sys.platform == "win32", reason="pending support for asyncio on windows"
)
class TestParseConcurrencyMetrics:
"""Integration tests for parse_concurrency_metrics function."""
def test_parse_concurrency_metrics_from_real_output(self):
"""Test parsing concurrency metrics from simulated stdout."""
# Simulate stdout from codeflash_concurrency_async decorator
perf_stdout = """Some other output
!@######CONC:test_module:TestClass:test_func:my_async_func:1:50000000:10000000:5######@!
More output here
"""
test_results = TestResults(test_results=[], perf_stdout=perf_stdout)
metrics = parse_concurrency_metrics(test_results, "my_async_func")
assert metrics is not None
assert isinstance(metrics, ConcurrencyMetrics)
assert metrics.sequential_time_ns == 50000000
assert metrics.concurrent_time_ns == 10000000
assert metrics.concurrency_factor == 5
assert metrics.concurrency_ratio == 5.0 # 50M / 10M = 5.0
def test_parse_concurrency_metrics_multiple_entries(self):
"""Test parsing when multiple concurrency entries exist."""
perf_stdout = """!@######CONC:test_module:TestClass:test_func:target_func:1:40000000:10000000:5######@!
!@######CONC:test_module:TestClass:test_func:target_func:2:60000000:10000000:5######@!
!@######CONC:test_module:TestClass:test_func:other_func:1:30000000:15000000:5######@!
"""
test_results = TestResults(test_results=[], perf_stdout=perf_stdout)
metrics = parse_concurrency_metrics(test_results, "target_func")
assert metrics is not None
# Should average the two entries for target_func
# (40M + 60M) / 2 = 50M seq, (10M + 10M) / 2 = 10M conc
assert metrics.sequential_time_ns == 50000000
assert metrics.concurrent_time_ns == 10000000
assert metrics.concurrency_ratio == 5.0
def test_parse_concurrency_metrics_no_match(self):
"""Test parsing when function name doesn't match."""
perf_stdout = """!@######CONC:test_module:TestClass:test_func:other_func:1:50000000:10000000:5######@!
"""
test_results = TestResults(test_results=[], perf_stdout=perf_stdout)
metrics = parse_concurrency_metrics(test_results, "nonexistent_func")
assert metrics is None
def test_parse_concurrency_metrics_empty_stdout(self):
"""Test parsing with empty stdout."""
test_results = TestResults(test_results=[], perf_stdout="")
metrics = parse_concurrency_metrics(test_results, "any_func")
assert metrics is None
def test_parse_concurrency_metrics_none_stdout(self):
"""Test parsing with None stdout."""
test_results = TestResults(test_results=[], perf_stdout=None)
metrics = parse_concurrency_metrics(test_results, "any_func")
assert metrics is None
@pytest.mark.skipif(
sys.platform == "win32", reason="pending support for asyncio on windows"
)
class TestConcurrencyRatioComparison:
"""Test comparing blocking vs non-blocking concurrency ratios."""
@pytest.fixture
def comparison_env_setup(self, request):
"""Set up environment variables for comparison testing."""
original_env = {}
test_env = {
"CODEFLASH_LOOP_INDEX": "1",
"CODEFLASH_TEST_MODULE": __name__,
"CODEFLASH_TEST_CLASS": "TestConcurrencyRatioComparison",
"CODEFLASH_TEST_FUNCTION": request.node.name,
"CODEFLASH_CONCURRENCY_FACTOR": "10",
}
for key, value in test_env.items():
original_env[key] = os.environ.get(key)
os.environ[key] = value
yield test_env
for key, original_value in original_env.items():
if original_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = original_value
@pytest.mark.asyncio
async def test_blocking_vs_nonblocking_comparison(
self, comparison_env_setup, capsys
):
"""Compare concurrency ratios between blocking and non-blocking implementations."""
@codeflash_concurrency_async
async def blocking_impl() -> str:
time.sleep(0.002) # 2ms blocking
return "blocking"
@codeflash_concurrency_async
async def nonblocking_impl() -> str:
await asyncio.sleep(0.002) # 2ms non-blocking
return "nonblocking"
# Run blocking version
await blocking_impl()
blocking_output = capsys.readouterr().out
# Run non-blocking version
await nonblocking_impl()
nonblocking_output = capsys.readouterr().out
# Parse blocking metrics
blocking_line = [
l for l in blocking_output.split("\n") if "!@######CONC:" in l
][0]
blocking_parts = (
blocking_line.replace("!@######CONC:", "")
.replace("######@!", "")
.split(":")
)
blocking_seq = int(blocking_parts[5])
blocking_conc = int(blocking_parts[6])
blocking_ratio = (
blocking_seq / blocking_conc if blocking_conc > 0 else 1.0
)
# Parse non-blocking metrics
nonblocking_line = [
l for l in nonblocking_output.split("\n") if "!@######CONC:" in l
][0]
nonblocking_parts = (
nonblocking_line.replace("!@######CONC:", "")
.replace("######@!", "")
.split(":")
)
nonblocking_seq = int(nonblocking_parts[5])
nonblocking_conc = int(nonblocking_parts[6])
nonblocking_ratio = (
nonblocking_seq / nonblocking_conc if nonblocking_conc > 0 else 1.0
)
# Non-blocking should have significantly higher concurrency ratio
assert nonblocking_ratio > blocking_ratio, (
f"Non-blocking ratio ({nonblocking_ratio:.2f}) should be greater than blocking ratio ({blocking_ratio:.2f})"
)
# The difference should be substantial (non-blocking should be at least 2x better)
ratio_improvement = (
nonblocking_ratio / blocking_ratio if blocking_ratio > 0 else 0
)
assert ratio_improvement > 2.0, (
f"Non-blocking should show >2x improvement in concurrency ratio, got {ratio_improvement:.2f}x"
)

View file

@ -52,7 +52,7 @@ def _create_async_db(
for row in rows:
conn.execute(
"INSERT INTO async_results VALUES "
"(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
"(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
row,
)
conn.commit()
@ -100,6 +100,7 @@ class TestParseAsyncBehaviorResults:
None,
None,
None,
None,
),
],
)
@ -139,6 +140,7 @@ class TestParseAsyncBehaviorResults:
None,
None,
None,
None,
),
],
)
@ -191,6 +193,7 @@ class TestCalculateAsyncThroughput:
None,
None,
None,
None,
),
(
"mod",
@ -206,6 +209,7 @@ class TestCalculateAsyncThroughput:
None,
None,
None,
None,
),
(
"mod",
@ -221,6 +225,7 @@ class TestCalculateAsyncThroughput:
None,
None,
None,
None,
),
(
"mod",
@ -236,6 +241,7 @@ class TestCalculateAsyncThroughput:
None,
None,
None,
None,
),
],
)
@ -281,6 +287,7 @@ class TestParseAsyncConcurrencyMetrics:
100_000,
50_000,
10,
None,
),
(
"mod",
@ -296,6 +303,7 @@ class TestParseAsyncConcurrencyMetrics:
200_000,
100_000,
10,
None,
),
],
)
@ -326,6 +334,7 @@ class TestParseAsyncConcurrencyMetrics:
100_000,
50_000,
10,
None,
),
],
)
@ -352,6 +361,7 @@ class TestParseAsyncConcurrencyMetrics:
100_000,
0,
5,
None,
),
],
)
@ -406,6 +416,7 @@ class TestParseAsyncBehaviorEdgeCases:
None,
None,
None,
None,
),
],
)
@ -449,6 +460,7 @@ class TestParseAsyncBehaviorEdgeCases:
None,
None,
None,
None,
),
],
)
@ -485,6 +497,7 @@ class TestParseAsyncBehaviorEdgeCases:
None,
None,
None,
None,
),
],
)
@ -517,6 +530,7 @@ class TestParseAsyncBehaviorEdgeCases:
None,
None,
None,
None,
),
],
)
@ -554,6 +568,7 @@ class TestParseAsyncBehaviorEdgeCases:
None,
None,
None,
None,
),
],
)

View file

@ -227,7 +227,7 @@ class TestBehaviorAsync:
async def test_no_stdout_output(
self, env_setup, async_db_path, capsys
) -> None:
"""Behavior decorator emits no stdout."""
"""Behavior decorator does not leak stdout to outer scope."""
@codeflash_behavior_async
async def noop() -> int:
@ -238,6 +238,28 @@ class TestBehaviorAsync:
captured = capsys.readouterr()
assert "" == captured.out
@pytest.mark.asyncio
async def test_captures_stdout_in_sqlite(
self, env_setup, async_db_path
) -> None:
"""Behavior decorator captures print output into the stdout column."""
@codeflash_behavior_async
async def greeter(name: str) -> str:
print(f"hello {name}")
return f"hi {name}"
_codeflash_call_site.set("0")
await greeter("world")
_close_all_connections()
con = sqlite3.connect(async_db_path)
cur = con.cursor()
cur.execute("SELECT stdout FROM async_results")
row = cur.fetchone()
assert "hello world\n" == row[0]
con.close()
@pytest.mark.skipif(
sys.platform == "win32",
@ -482,7 +504,7 @@ class TestSchemaValidation:
"""Validate the async_results SQLite schema."""
def test_table_columns(self, tmp_path) -> None:
"""async_results table has exactly 13 columns."""
"""async_results table has exactly 14 columns."""
db_path = tmp_path / "schema_test.sqlite"
conn, cur = _get_async_db(db_path)
cur.execute("PRAGMA table_info(async_results)")
@ -501,6 +523,7 @@ class TestSchemaValidation:
"sequential_time_ns",
"concurrent_time_ns",
"concurrency_factor",
"stdout",
]
actual_names = [col[1] for col in columns]
assert expected_names == actual_names

View file

@ -1,333 +0,0 @@
from __future__ import annotations
import asyncio
import os
import sqlite3
import sys
from pathlib import Path
import dill as pickle
import pytest
from codeflash_python.runtime._codeflash_capture import VerificationType
from codeflash_python.runtime._codeflash_wrap_decorator import (
codeflash_behavior_async,
codeflash_performance_async,
)
@pytest.mark.skipif(
sys.platform == "win32", reason="pending support for asyncio on windows"
)
class TestAsyncWrapperSQLiteValidation:
@pytest.fixture
def test_env_setup(self, request):
original_env = {}
test_env = {
"CODEFLASH_LOOP_INDEX": "1",
"CODEFLASH_TEST_ITERATION": "0",
"CODEFLASH_TEST_MODULE": __name__,
"CODEFLASH_TEST_CLASS": "TestAsyncWrapperSQLiteValidation",
"CODEFLASH_TEST_FUNCTION": request.node.name,
"CODEFLASH_CURRENT_LINE_ID": "test_unit",
}
for key, value in test_env.items():
original_env[key] = os.environ.get(key)
os.environ[key] = value
yield test_env
for key, original_value in original_env.items():
if original_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = original_value
@pytest.fixture
def temp_db_path(self, test_env_setup):
iteration = test_env_setup["CODEFLASH_TEST_ITERATION"]
from codeflash_python.testing._instrumentation import get_run_tmp_file
db_path = get_run_tmp_file(
Path(f"test_return_values_{iteration}.sqlite")
)
yield db_path
if db_path.exists():
db_path.unlink()
@pytest.mark.asyncio
async def test_behavior_async_basic_function(
self, test_env_setup, temp_db_path
):
@codeflash_behavior_async
async def simple_async_add(a: int, b: int) -> int:
await asyncio.sleep(0.001)
return a + b
os.environ["CODEFLASH_CURRENT_LINE_ID"] = "simple_async_add_59"
result = await simple_async_add(5, 3)
assert result == 8
assert temp_db_path.exists()
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='test_results'"
)
assert cur.fetchone() is not None
cur.execute("SELECT * FROM test_results")
rows = cur.fetchall()
assert len(rows) == 1
row = rows[0]
(
test_module_path,
test_class_name,
test_function_name,
function_getting_tested,
loop_index,
iteration_id,
runtime,
return_value_blob,
verification_type,
cpu_runtime,
) = row
assert test_module_path == __name__
assert test_class_name == "TestAsyncWrapperSQLiteValidation"
assert test_function_name == "test_behavior_async_basic_function"
assert function_getting_tested == "simple_async_add"
assert loop_index == 1
# Line ID will be the actual line number from the source code, not a simple counter
assert iteration_id.startswith(
"simple_async_add_"
) and iteration_id.endswith("_0")
assert runtime > 0
assert verification_type == VerificationType.FUNCTION_CALL.value
assert isinstance(cpu_runtime, int)
unpickled_data = pickle.loads(return_value_blob)
args, kwargs, return_val = unpickled_data
assert args == (5, 3)
assert kwargs == {}
assert return_val == 8
con.close()
@pytest.mark.asyncio
async def test_behavior_async_exception_handling(
self, test_env_setup, temp_db_path
):
@codeflash_behavior_async
async def async_divide(a: int, b: int) -> float:
await asyncio.sleep(0.001)
if b == 0:
raise ValueError("Cannot divide by zero")
return a / b
result = await async_divide(10, 2)
assert result == 5.0
with pytest.raises(ValueError, match="Cannot divide by zero"):
await async_divide(10, 0)
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("SELECT * FROM test_results ORDER BY iteration_id")
rows = cur.fetchall()
assert len(rows) == 2
success_row = rows[0]
success_data = pickle.loads(success_row[7]) # return_value_blob
args, kwargs, return_val = success_data
assert args == (10, 2)
assert return_val == 5.0
# Check exception record
exception_row = rows[1]
exception_data = pickle.loads(exception_row[7]) # return_value_blob
assert isinstance(exception_data, ValueError)
assert str(exception_data) == "Cannot divide by zero"
con.close()
@pytest.mark.asyncio
async def test_performance_async_no_database_storage(
self, test_env_setup, temp_db_path, capsys
):
"""Test performance async decorator doesn't store to database."""
@codeflash_performance_async
async def async_multiply(a: int, b: int) -> int:
"""Async function for performance testing."""
await asyncio.sleep(0.002)
return a * b
result = await async_multiply(4, 7)
assert result == 28
assert not temp_db_path.exists()
captured = capsys.readouterr()
output_lines = captured.out.strip().split("\n")
assert len([line for line in output_lines if "!$######" in line]) == 1
assert (
len(
[
line
for line in output_lines
if "!######" in line and "######!" in line
]
)
== 1
)
closing_tag = [
line
for line in output_lines
if "!######" in line and "######!" in line
][0]
assert "async_multiply" in closing_tag
timing_part = closing_tag.split(":")[-1].replace("######!", "")
timing_value = int(timing_part)
assert timing_value > 0 # Should have positive timing
@pytest.mark.asyncio
async def test_multiple_calls_indexing(self, test_env_setup, temp_db_path):
@codeflash_behavior_async
async def async_increment(value: int) -> int:
await asyncio.sleep(0.001)
return value + 1
# Call the function multiple times
results = []
for i in range(3):
result = await async_increment(i)
results.append(result)
assert results == [1, 2, 3]
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute(
"SELECT iteration_id, return_value FROM test_results ORDER BY iteration_id"
)
rows = cur.fetchall()
assert len(rows) == 3
actual_ids = [row[0] for row in rows]
assert len(actual_ids) == 3
base_pattern = actual_ids[0].rsplit("_", 1)[
0
] # e.g., "async_increment_199"
expected_pattern = [f"{base_pattern}_{i}" for i in range(3)]
assert actual_ids == expected_pattern
for i, (_, return_value_blob) in enumerate(rows):
args, kwargs, return_val = pickle.loads(return_value_blob)
assert args == (i,)
assert return_val == i + 1
con.close()
@pytest.mark.asyncio
async def test_complex_async_function_with_kwargs(
self, test_env_setup, temp_db_path
):
@codeflash_behavior_async
async def complex_async_func(
pos_arg: str,
*args: int,
keyword_arg: str = "default",
**kwargs: str,
) -> dict:
await asyncio.sleep(0.001)
return {
"pos_arg": pos_arg,
"args": args,
"keyword_arg": keyword_arg,
"kwargs": kwargs,
}
result = await complex_async_func(
"hello",
1,
2,
3,
keyword_arg="custom",
extra1="value1",
extra2="value2",
)
expected_result = {
"pos_arg": "hello",
"args": (1, 2, 3),
"keyword_arg": "custom",
"kwargs": {"extra1": "value1", "extra2": "value2"},
}
assert result == expected_result
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("SELECT return_value FROM test_results")
row = cur.fetchone()
stored_args, stored_kwargs, stored_result = pickle.loads(row[0])
assert stored_args == ("hello", 1, 2, 3)
assert stored_kwargs == {
"keyword_arg": "custom",
"extra1": "value1",
"extra2": "value2",
}
assert stored_result == expected_result
con.close()
@pytest.mark.asyncio
async def test_database_schema_validation(
self, test_env_setup, temp_db_path
):
@codeflash_behavior_async
async def schema_test_func() -> str:
return "schema_test"
await schema_test_func()
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("PRAGMA table_info(test_results)")
columns = cur.fetchall()
expected_columns = [
(0, "test_module_path", "TEXT", 0, None, 0),
(1, "test_class_name", "TEXT", 0, None, 0),
(2, "test_function_name", "TEXT", 0, None, 0),
(3, "function_getting_tested", "TEXT", 0, None, 0),
(4, "loop_index", "INTEGER", 0, None, 0),
(5, "iteration_id", "TEXT", 0, None, 0),
(6, "runtime", "INTEGER", 0, None, 0),
(7, "return_value", "BLOB", 0, None, 0),
(8, "verification_type", "TEXT", 0, None, 0),
(9, "cpu_runtime", "INTEGER", 0, None, 0),
]
assert columns == expected_columns
con.close()

View file

@ -13,7 +13,6 @@ from codeflash_python.analysis._coverage import (
from codeflash_python.benchmarking.models import ConcurrencyMetrics
from codeflash_python.context.models import CodeOptimizationContext
from codeflash_python.test_discovery.models import TestType
from codeflash_python.testing._stdout_parsers import parse_concurrency_metrics
from codeflash_python.testing.models import (
FunctionTestInvocation,
InvocationId,
@ -882,43 +881,3 @@ def test_concurrency_ratio_display_formatting() -> None:
assert display_string == "Concurrency ratio: 0.01x \u2192 0.03x (+200.0%)"
def test_parse_concurrency_metrics() -> None:
"""parse_concurrency_metrics extracts metrics from test output."""
stdout = (
"!@######CONC:test_module:TestClass:test_func:"
"my_function:0:10000000:1000000:10######@!\n"
"!@######CONC:test_module:TestClass:test_func:"
"my_function:1:10000000:1000000:10######@!\n"
)
test_results = TestResults(perf_stdout=stdout)
metrics = parse_concurrency_metrics(test_results, "my_function")
assert metrics is not None
assert metrics.sequential_time_ns == 10_000_000
assert metrics.concurrent_time_ns == 1_000_000
assert metrics.concurrency_factor == 10
assert metrics.concurrency_ratio == 10.0
metrics_wrong_func = parse_concurrency_metrics(
test_results, "other_function"
)
assert metrics_wrong_func is None
empty_results = TestResults(perf_stdout="")
metrics_empty = parse_concurrency_metrics(empty_results, "my_function")
assert metrics_empty is None
none_results = TestResults(perf_stdout=None)
metrics_none = parse_concurrency_metrics(none_results, "my_function")
assert metrics_none is None
stdout_no_class = (
"!@######CONC:test_module::test_func:"
"my_function:0:5000000:2500000:10######@!\n"
)
test_results_no_class = TestResults(perf_stdout=stdout_no_class)
metrics_no_class = parse_concurrency_metrics(
test_results_no_class, "my_function"
)
assert metrics_no_class is not None
assert metrics_no_class.concurrency_ratio == 2.0

View file

@ -1,4 +1,3 @@
import os
import sys
import tempfile
from pathlib import Path
@ -750,272 +749,28 @@ async def test_multiple_calls():
assert instrumented_test_code is not None
assert (
"os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'"
"_codeflash_call_site.set('0')"
in instrumented_test_code
)
# Count occurrences of each line_id to verify numbering
line_id_0_count = instrumented_test_code.count(
"os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'"
"_codeflash_call_site.set('0')"
)
line_id_1_count = instrumented_test_code.count(
"os.environ['CODEFLASH_CURRENT_LINE_ID'] = '1'"
"_codeflash_call_site.set('1')"
)
line_id_2_count = instrumented_test_code.count(
"os.environ['CODEFLASH_CURRENT_LINE_ID'] = '2'"
"_codeflash_call_site.set('2')"
)
assert line_id_0_count == 2, (
assert 2 == line_id_0_count, (
f"Expected 2 occurrences of line_id '0', got {line_id_0_count}"
)
assert line_id_1_count == 1, (
assert 1 == line_id_1_count, (
f"Expected 1 occurrence of line_id '1', got {line_id_1_count}"
)
assert line_id_2_count == 1, (
assert 1 == line_id_2_count, (
f"Expected 1 occurrence of line_id '2', got {line_id_2_count}"
)
@pytest.mark.skipif(
sys.platform == "win32", reason="pending support for asyncio on windows"
)
def test_async_behavior_decorator_return_values_and_test_ids():
"""Test that async behavior decorator correctly captures return values, test IDs, and stores data in database."""
import asyncio
import sqlite3
from pathlib import Path
import dill as pickle
from codeflash_python.runtime._codeflash_wrap_decorator import (
codeflash_behavior_async,
)
@codeflash_behavior_async
async def test_async_multiply(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.001) # Small delay to simulate async work
return x * y
test_env = {
"CODEFLASH_TEST_MODULE": "test_module",
"CODEFLASH_TEST_CLASS": None,
"CODEFLASH_TEST_FUNCTION": "test_async_multiply_function",
"CODEFLASH_CURRENT_LINE_ID": "0",
"CODEFLASH_LOOP_INDEX": "1",
"CODEFLASH_TEST_ITERATION": "2",
}
original_env = {k: os.environ.get(k) for k in test_env}
for k, v in test_env.items():
if v is not None:
os.environ[k] = v
elif k in os.environ:
del os.environ[k]
try:
result = asyncio.run(test_async_multiply(6, 7))
assert result == 42, f"Expected return value 42, got {result}"
from codeflash_python.testing._instrumentation import get_run_tmp_file
db_path = get_run_tmp_file(Path("test_return_values_2.sqlite"))
# Verify database exists and has data
assert db_path.exists(), f"Database file not created at {db_path}"
# Read and verify database contents
con = sqlite3.connect(db_path)
cur = con.cursor()
cur.execute("SELECT * FROM test_results")
rows = cur.fetchall()
assert len(rows) == 1, f"Expected 1 database row, got {len(rows)}"
row = rows[0]
(
test_module,
test_class,
test_function,
function_name,
loop_index,
iteration_id,
runtime,
return_value_blob,
verification_type,
cpu_runtime,
) = row
assert test_module == "test_module", (
f"Expected test_module 'test_module', got '{test_module}'"
)
assert test_class is None, (
f"Expected test_class None, got '{test_class}'"
)
assert test_function == "test_async_multiply_function", (
f"Expected test_function 'test_async_multiply_function', got '{test_function}'"
)
assert function_name == "test_async_multiply", (
f"Expected function_name 'test_async_multiply', got '{function_name}'"
)
assert loop_index == 1, f"Expected loop_index 1, got {loop_index}"
assert iteration_id == "0_0", (
f"Expected iteration_id '0_0', got '{iteration_id}'"
)
assert verification_type == "function_call", (
f"Expected verification_type 'function_call', got '{verification_type}'"
)
unpickled_data = pickle.loads(return_value_blob)
args, kwargs, actual_return_value = unpickled_data
assert args == (6, 7), f"Expected args (6, 7), got {args}"
assert kwargs == {}, f"Expected empty kwargs, got {kwargs}"
assert actual_return_value == 42, (
f"Expected stored return value 42, got {actual_return_value}"
)
con.close()
finally:
for k, v in original_env.items():
if v is not None:
os.environ[k] = v
elif k in os.environ:
del os.environ[k]
@pytest.mark.skipif(
sys.platform == "win32", reason="pending support for asyncio on windows"
)
def test_async_decorator_comprehensive_return_values_and_test_ids():
import asyncio
import sqlite3
from pathlib import Path
import dill as pickle
from codeflash_python.runtime._codeflash_wrap_decorator import (
codeflash_behavior_async,
)
from codeflash_python.testing._instrumentation import get_run_tmp_file
@codeflash_behavior_async
async def async_multiply_add(x: int, y: int, z: int = 1) -> int:
"""Async function that multiplies x*y then adds z."""
await asyncio.sleep(0.001)
result = (x * y) + z
return result
test_env = {
"CODEFLASH_TEST_MODULE": "test_comprehensive_module",
"CODEFLASH_TEST_CLASS": "AsyncTestClass",
"CODEFLASH_TEST_FUNCTION": "test_comprehensive_async_function",
"CODEFLASH_CURRENT_LINE_ID": "3",
"CODEFLASH_LOOP_INDEX": "2",
"CODEFLASH_TEST_ITERATION": "3",
}
original_env = {k: os.environ.get(k) for k in test_env}
for k, v in test_env.items():
if v is not None:
os.environ[k] = v
elif k in os.environ:
del os.environ[k]
try:
test_cases = [
{"args": (5, 3), "kwargs": {}, "expected": 16}, # (5 * 3) + 1 = 16
{
"args": (2, 4),
"kwargs": {"z": 10},
"expected": 18,
}, # (2 * 4) + 10 = 18
{"args": (7, 6), "kwargs": {}, "expected": 43}, # (7 * 6) + 1 = 43
]
results = []
for test_case in test_cases:
result = asyncio.run(
async_multiply_add(*test_case["args"], **test_case["kwargs"])
)
results.append(result)
# Verify each return value is exactly correct
assert result == test_case["expected"], (
f"Expected {test_case['expected']}, got {result} for args {test_case['args']}, kwargs {test_case['kwargs']}"
)
db_path = get_run_tmp_file(Path("test_return_values_3.sqlite"))
assert db_path.exists(), f"Database not created at {db_path}"
con = sqlite3.connect(db_path)
cur = con.cursor()
cur.execute(
"SELECT test_module_path, test_class_name, test_function_name, function_getting_tested, loop_index, iteration_id, runtime, return_value, verification_type FROM test_results ORDER BY rowid"
)
rows = cur.fetchall()
assert len(rows) == 3, f"Expected 3 database rows, got {len(rows)}"
for i, (
test_module,
test_class,
test_function,
function_name,
loop_index,
iteration_id,
runtime,
return_value_blob,
verification_type,
) in enumerate(rows):
assert test_module == "test_comprehensive_module", (
f"Row {i}: Expected test_module 'test_comprehensive_module', got '{test_module}'"
)
assert test_class == "AsyncTestClass", (
f"Row {i}: Expected test_class 'AsyncTestClass', got '{test_class}'"
)
assert test_function == "test_comprehensive_async_function", (
f"Row {i}: Expected test_function 'test_comprehensive_async_function', got '{test_function}'"
)
assert function_name == "async_multiply_add", (
f"Row {i}: Expected function_name 'async_multiply_add', got '{function_name}'"
)
assert loop_index == 2, (
f"Row {i}: Expected loop_index 2, got {loop_index}"
)
assert verification_type == "function_call", (
f"Row {i}: Expected verification_type 'function_call', got '{verification_type}'"
)
expected_iteration_id = f"3_{i}"
assert iteration_id == expected_iteration_id, (
f"Row {i}: Expected iteration_id '{expected_iteration_id}', got '{iteration_id}'"
)
args, kwargs, actual_return_value = pickle.loads(return_value_blob)
expected_args = test_cases[i]["args"]
expected_kwargs = test_cases[i]["kwargs"]
expected_return = test_cases[i]["expected"]
assert args == expected_args, (
f"Row {i}: Expected args {expected_args}, got {args}"
)
assert kwargs == expected_kwargs, (
f"Row {i}: Expected kwargs {expected_kwargs}, got {kwargs}"
)
assert actual_return_value == expected_return, (
f"Row {i}: Expected return value {expected_return}, got {actual_return_value}"
)
con.close()
finally:
for k, v in original_env.items():
if v is not None:
os.environ[k] = v
elif k in os.environ:
del os.environ[k]