Mirror of https://github.com/codeflash-ai/codeflash-agent.git (synced 2026-05-04 18:25:19 +00:00)
fix: capture stdout in async decorator and fix result merger

The async behavior decorator now captures stdout per invocation via io.StringIO into a new `stdout` column in the async_results SQLite table. The result merger prefers data-sourced stdout over XML stdout, fixing the root cause of empty stdout in merged async results. Also fixes a duplicate async parse block in _parse_results.py and CODEFLASH_RUN_TMPDIR propagation to subprocesses, and removes dead async code from _stdout_parsers.py and _wrap_decorator.py.
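Note: the capture pattern described above amounts to swapping sys.stdout for an in-memory buffer around the awaited call and persisting the buffer alongside the invocation row. A minimal, self-contained sketch of the idea (the `stdout` column name matches the diff below; the table name and helper here are illustrative, not codeflash API):

    import asyncio
    import io
    import sqlite3
    import sys

    async def record_with_stdout(coro, con: sqlite3.Connection):
        """Await a coroutine while capturing everything it prints."""
        buf, old = io.StringIO(), sys.stdout
        sys.stdout = buf                 # per-invocation capture starts
        try:
            result = await coro
        finally:
            sys.stdout = old             # always restore the real stdout
        con.execute("INSERT INTO demo_results (stdout) VALUES (?)", (buf.getvalue(),))
        con.commit()
        return result

    async def main():
        con = sqlite3.connect(":memory:")
        con.execute("CREATE TABLE demo_results (stdout TEXT)")

        async def greeter():
            print("hello world")
            return 42

        assert await record_with_stdout(greeter(), con) == 42
        assert con.execute("SELECT stdout FROM demo_results").fetchone()[0] == "hello world\n"

    asyncio.run(main())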
This commit is contained in:
parent 629d7f9f08
commit c9f65aba6b

14 changed files with 172 additions and 1296 deletions
@@ -7,6 +7,7 @@ running concurrency benchmarks, and evaluating async candidates.
 from __future__ import annotations

 import logging
+from pathlib import Path
 from typing import TYPE_CHECKING

 import attrs
@@ -21,8 +22,6 @@ from ..testing._parse_results import parse_test_results
 from ..testing._test_runner import async_run_benchmarking_tests

 if TYPE_CHECKING:
-    from pathlib import Path
-
     from .._model import FunctionToOptimize
     from ..benchmarking.models import ConcurrencyMetrics
     from ..context.models import CodeOptimizationContext
@@ -49,12 +48,16 @@ async def collect_baseline_async_metrics(  # noqa: PLR0913

     Returns an evolved baseline with the metrics attached.
     """
-    from ..testing._stdout_parsers import (  # noqa: PLC0415
-        calculate_function_throughput_from_test_results,
+    from ..runtime._codeflash_wrap_decorator import (  # noqa: PLC0415
+        get_run_tmp_file,
     )
+    from ..testing._async_data_parser import (  # noqa: PLC0415
+        calculate_async_throughput,
+    )

-    async_throughput = calculate_function_throughput_from_test_results(
-        baseline.benchmarking_test_results,
+    async_db = get_run_tmp_file(Path("async_results_0.sqlite"))
+    async_throughput = calculate_async_throughput(
+        async_db,
         func.function_name,
     )
     log.info(
@@ -103,13 +106,16 @@ async def run_concurrency_benchmark(
         return None

    from .._model import TestingMode  # noqa: PLC0415
+    from ..runtime._codeflash_wrap_decorator import (  # noqa: PLC0415
+        get_run_tmp_file,
+    )
+    from ..testing._async_data_parser import (  # noqa: PLC0415
+        parse_async_concurrency_metrics,
+    )
     from ..testing._instrumentation import (  # noqa: PLC0415
         add_async_decorator_to_function,
         revert_instrumented_files,
     )
-    from ..testing._stdout_parsers import (  # noqa: PLC0415
-        parse_concurrency_metrics,
-    )

     originals: dict[Path, str] = {}
     try:
@@ -138,7 +144,7 @@ async def run_concurrency_benchmark(
             max_loops=3,
             target_duration_seconds=5.0,
         )
-        bench_results = parse_test_results(
+        parse_test_results(
            test_xml_path=bench_xml,
            test_files=test_files,
            test_config=ctx.test_cfg,
@@ -155,8 +161,12 @@ async def run_concurrency_benchmark(
        if originals:
            revert_instrumented_files(originals)

-    return parse_concurrency_metrics(
-        bench_results,
+    iteration = 0
+    async_db = get_run_tmp_file(
+        Path(f"async_results_{iteration}.sqlite"),
+    )
+    return parse_async_concurrency_metrics(
+        async_db,
         func.function_name,
     )
@@ -176,8 +186,11 @@ async def evaluate_async_candidate(  # noqa: PLR0913
     Returns *(speedup, acceptance_reason)*. *speedup* is ``None``
     when the candidate is rejected.
     """
-    from ..testing._stdout_parsers import (  # noqa: PLC0415
-        calculate_function_throughput_from_test_results,
+    from ..runtime._codeflash_wrap_decorator import (  # noqa: PLC0415
+        get_run_tmp_file,
     )
+    from ..testing._async_data_parser import (  # noqa: PLC0415
+        calculate_async_throughput,
+    )
     from ..verification._critic import (  # noqa: PLC0415
         get_acceptance_reason,
@@ -189,8 +202,12 @@ async def evaluate_async_candidate(  # noqa: PLR0913
     from ._test_orchestrator import build_test_env  # noqa: PLC0415

     func = fn_input.function
-    candidate_throughput = calculate_function_throughput_from_test_results(
-        bench_results,
+    iteration = 0
+    async_db = get_run_tmp_file(
+        Path(f"async_results_{iteration}.sqlite"),
+    )
+    candidate_throughput = calculate_async_throughput(
+        async_db,
         func.function_name,
     )
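Note: with metrics now read from SQLite, throughput becomes a row count rather than a stdout-marker count. Roughly what such a query could look like — the real calculate_async_throughput lives in _async_data_parser.py and is not shown in this diff, so this is only a guess at its shape (column and mode names follow the async_results schema shown further below):

    import sqlite3
    from pathlib import Path

    def count_completed_invocations(db_path: Path, function_name: str) -> int:
        """Count behavior-mode rows recorded for one function (illustrative)."""
        con = sqlite3.connect(db_path)
        try:
            cur = con.execute(
                "SELECT COUNT(*) FROM async_results "
                "WHERE function_getting_tested = ? AND mode = 'behavior'",
                (function_name,),
            )
            return cur.fetchone()[0]
        finally:
            con.close()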
@@ -12,8 +12,10 @@ import asyncio
 import atexit
 import contextvars
 import gc
+import io
 import os
 import sqlite3
+import sys
 import time
 from enum import Enum
 from functools import wraps
@@ -36,12 +38,22 @@ F = TypeVar("F", bound=Callable[..., Any])


 def get_run_tmp_file(file_path: Path) -> Path:
-    """Return a path inside a persistent per-run temporary directory."""
+    """Return a path inside a persistent per-run temporary directory.
+
+    Uses ``CODEFLASH_RUN_TMPDIR`` if set (subprocess case), otherwise
+    creates a new tmpdir and exports the env var so child processes
+    share the same directory.
+    """
     if not hasattr(get_run_tmp_file, "tmpdir"):
-        get_run_tmp_file.tmpdir = TemporaryDirectory(  # type: ignore[attr-defined]
-            prefix="codeflash_"
-        )
-    return Path(get_run_tmp_file.tmpdir.name) / file_path  # type: ignore[attr-defined]
+        env_dir = os.environ.get("CODEFLASH_RUN_TMPDIR")
+        if env_dir and Path(env_dir).is_dir():
+            get_run_tmp_file.tmpdir = env_dir  # type: ignore[attr-defined]
+        else:
+            td = TemporaryDirectory(prefix="codeflash_")
+            get_run_tmp_file.tmpdir = td.name  # type: ignore[attr-defined]
+            get_run_tmp_file.td_ref = td  # type: ignore[attr-defined]
+            os.environ["CODEFLASH_RUN_TMPDIR"] = td.name
+    return Path(get_run_tmp_file.tmpdir) / file_path  # type: ignore[attr-defined]


 def extract_test_context_from_env() -> tuple[str, str | None, str]:
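Note: the first process to call get_run_tmp_file creates the directory and exports CODEFLASH_RUN_TMPDIR, so instrumented code running in a pytest subprocess writes its async_results_*.sqlite into the same directory the parent later reads. A rough illustration of the intended flow, assuming get_run_tmp_file from this module is in scope (the pytest invocation is hypothetical):

    import os
    import subprocess
    from pathlib import Path

    # Parent: first call creates the tmpdir and sets CODEFLASH_RUN_TMPDIR.
    db_path = get_run_tmp_file(Path("async_results_0.sqlite"))

    # Child: inherits the env var, so the same call resolves to the same file.
    env = os.environ.copy()  # CODEFLASH_RUN_TMPDIR is already exported here
    subprocess.run(["pytest", "tests/"], env=env, check=False)

    # Parent can now open db_path and see rows written by the child.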
@@ -82,7 +94,8 @@ _CREATE_TABLE_SQL = (
     "verification_type TEXT, "
     "sequential_time_ns INTEGER, "
     "concurrent_time_ns INTEGER, "
-    "concurrency_factor INTEGER"
+    "concurrency_factor INTEGER, "
+    "stdout TEXT"
     ")"
 )
@@ -155,6 +168,9 @@ def codeflash_behavior_async(func: F) -> F:
        conn, cur = _get_async_db(db_path)

        exception = None
+        captured_stdout = io.StringIO()
+        old_stdout = sys.stdout
+        sys.stdout = captured_stdout
        counter = loop.time()
        gc.disable()
        try:
@@ -167,6 +183,9 @@ def codeflash_behavior_async(func: F) -> F:
            exception = e
        finally:
            gc.enable()
+            sys.stdout = old_stdout
+
+        stdout_text = captured_stdout.getvalue()

        pickled = (
            pickle.dumps(exception)
@@ -175,7 +194,7 @@ def codeflash_behavior_async(func: F) -> F:
        )
        cur.execute(
            "INSERT INTO async_results VALUES "
-            "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+            "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                test_module_name,
                test_class_name,
@@ -190,6 +209,7 @@ def codeflash_behavior_async(func: F) -> F:
                None,
                None,
                None,
+                stdout_text,
            ),
        )
        conn.commit()
@@ -254,7 +274,7 @@ def codeflash_performance_async(func: F) -> F:

        cur.execute(
            "INSERT INTO async_results VALUES "
-            "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+            "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                test_module_name,
                test_class_name,
@@ -269,6 +289,7 @@ def codeflash_performance_async(func: F) -> F:
                None,
                None,
                None,
+                None,
            ),
        )
        conn.commit()
@@ -324,7 +345,7 @@ def codeflash_concurrency_async(func: F) -> F:

        cur.execute(
            "INSERT INTO async_results VALUES "
-            "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+            "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                test_module_name,
                test_class_name,
@@ -339,6 +360,7 @@ def codeflash_concurrency_async(func: F) -> F:
                sequential_time,
                concurrent_time,
                concurrency_factor,
+                None,
            ),
        )
        conn.commit()
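Note: swapping sys.stdout is process-global, so output printed by any coroutine that runs between the swap and the restore is attributed to the current invocation; that is acceptable here because the decorated call is awaited to completion inside the swap window. The same effect can be written with the standard library, shown purely as a sketch:

    import contextlib
    import io

    async def capture_stdout_of(coro):
        """Await a coroutine while capturing everything printed to stdout."""
        buf = io.StringIO()
        with contextlib.redirect_stdout(buf):  # same global swap, restored on exit
            result = await coro
        return result, buf.getvalue()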
@@ -1,42 +1,47 @@
-"""Async wrapper decorators for behavior, performance, and concurrency testing."""
+"""Shared runtime helpers used by sync instrumentation.
+
+Async decorators have moved to ``_codeflash_async_decorators.py``.
+This module retains ``VerificationType``, ``get_run_tmp_file``, and
+``extract_test_context_from_env`` which are still used by the sync
+capture path (``_codeflash_capture.py``) and multiple test/analysis
+modules.
+"""

 # ruff: noqa: T201, BLE001
 from __future__ import annotations

-import asyncio
-import gc
 import os
 import sqlite3
-import time
 from enum import Enum
-from functools import wraps
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Any, Callable, TypeVar

 import dill as pickle


-class VerificationType(
-    str, Enum
-):  # moved from codeflash/verification/codeflash_capture.py
+class VerificationType(str, Enum):
     """Type of correctness verification for captured test data."""

-    FUNCTION_CALL = "function_call"  # Correctness verification for a test function, checks input values and output values)
-    INIT_STATE_FTO = "init_state_fto"  # Correctness verification for fto class instance attributes after init
-    INIT_STATE_HELPER = "init_state_helper"  # Correctness verification for helper class instance attributes after init
-
-
-F = TypeVar("F", bound=Callable[..., Any])
+    FUNCTION_CALL = "function_call"
+    INIT_STATE_FTO = "init_state_fto"
+    INIT_STATE_HELPER = "init_state_helper"


 def get_run_tmp_file(
     file_path: Path,
-) -> Path:  # moved from codeflash/code_utils/code_utils.py
-    """Return a path inside a persistent per-run temporary directory."""
+) -> Path:
+    """Return a path inside a persistent per-run temporary directory.
+
+    Uses ``CODEFLASH_RUN_TMPDIR`` if set (subprocess case), otherwise
+    creates a new tmpdir and exports the env var so child processes
+    share the same directory.
+    """
     if not hasattr(get_run_tmp_file, "tmpdir"):
-        get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")  # type: ignore[attr-defined]
-    return Path(get_run_tmp_file.tmpdir.name) / file_path  # type: ignore[attr-defined]
+        env_dir = os.environ.get("CODEFLASH_RUN_TMPDIR")
+        if env_dir and Path(env_dir).is_dir():
+            get_run_tmp_file.tmpdir = env_dir  # type: ignore[attr-defined]
+        else:
+            td = TemporaryDirectory(prefix="codeflash_")
+            get_run_tmp_file.tmpdir = td.name  # type: ignore[attr-defined]
+            get_run_tmp_file.td_ref = td  # type: ignore[attr-defined]
+            os.environ["CODEFLASH_RUN_TMPDIR"] = td.name
+    return Path(get_run_tmp_file.tmpdir) / file_path  # type: ignore[attr-defined]


 def extract_test_context_from_env() -> tuple[str, str | None, str]:
@@ -51,191 +56,3 @@ def extract_test_context_from_env() -> tuple[str, str | None, str]:
     raise RuntimeError(  # noqa: TRY003
         "Test context environment variables not set - ensure tests are run through codeflash test runner"  # noqa: EM101
     )
-
-
-def codeflash_behavior_async(func: F) -> F:
-    """Decorator capturing async function return values and timing for behavioral tests."""
-
-    @wraps(func)
-    async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
-        """Await the wrapped coroutine and record its result to SQLite."""
-        loop = asyncio.get_running_loop()
-        function_name = func.__name__
-        line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
-        loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
-        test_module_name, test_class_name, test_name = (
-            extract_test_context_from_env()
-        )
-
-        test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
-
-        if not hasattr(async_wrapper, "index"):
-            async_wrapper.index = {}  # type: ignore[attr-defined]
-        if test_id in async_wrapper.index:  # type: ignore[attr-defined]
-            async_wrapper.index[test_id] += 1  # type: ignore[attr-defined]
-        else:
-            async_wrapper.index[test_id] = 0  # type: ignore[attr-defined]
-
-        codeflash_test_index = async_wrapper.index[test_id]  # type: ignore[attr-defined]
-        invocation_id = f"{line_id}_{codeflash_test_index}"
-        test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
-
-        print(f"!$######{test_stdout_tag}######$!")
-        iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
-        db_path = get_run_tmp_file(
-            Path(f"test_return_values_{iteration}.sqlite")
-        )
-        codeflash_con = sqlite3.connect(db_path)
-        codeflash_cur = codeflash_con.cursor()
-
-        codeflash_cur.execute(
-            "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, "
-            "test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, "
-            "runtime INTEGER, return_value BLOB, verification_type TEXT, cpu_runtime INTEGER)"
-        )
-
-        exception = None
-        counter = loop.time()
-        gc.disable()
-        try:
-            ret = func(
-                *args, **kwargs
-            )  # coroutine creation has some overhead, though it is very small
-            counter = loop.time()
-            return_value = (
-                await ret
-            )  # let's measure the actual execution time of the code
-            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
-        except Exception as e:
-            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
-            exception = e
-        finally:
-            gc.enable()
-
-        print(f"!######{test_stdout_tag}######!")
-        pickled_return_value = (
-            pickle.dumps(exception)
-            if exception
-            else pickle.dumps((args, kwargs, return_value))
-        )
-        codeflash_cur.execute(
-            "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
-            (
-                test_module_name,
-                test_class_name,
-                test_name,
-                function_name,
-                loop_index,
-                invocation_id,
-                codeflash_duration,
-                pickled_return_value,
-                VerificationType.FUNCTION_CALL.value,
-                0,
-            ),
-        )
-        codeflash_con.commit()
-        codeflash_con.close()
-
-        if exception:
-            raise exception
-        return return_value
-
-    return async_wrapper  # type: ignore[return-value]
-
-
-def codeflash_performance_async(func: F) -> F:
-    """Decorator measuring async function execution time for performance tests."""
-
-    @wraps(func)
-    async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
-        """Await the wrapped coroutine and emit its timing via stdout."""
-        loop = asyncio.get_running_loop()
-        function_name = func.__name__
-        line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
-        loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
-
-        test_module_name, test_class_name, test_name = (
-            extract_test_context_from_env()
-        )
-
-        test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
-
-        if not hasattr(async_wrapper, "index"):
-            async_wrapper.index = {}  # type: ignore[attr-defined]
-        if test_id in async_wrapper.index:  # type: ignore[attr-defined]
-            async_wrapper.index[test_id] += 1  # type: ignore[attr-defined]
-        else:
-            async_wrapper.index[test_id] = 0  # type: ignore[attr-defined]
-
-        codeflash_test_index = async_wrapper.index[test_id]  # type: ignore[attr-defined]
-        invocation_id = f"{line_id}_{codeflash_test_index}"
-        test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"

-        print(f"!$######{test_stdout_tag}######$!")
-        exception = None
-        counter = loop.time()
-        gc.disable()
-        try:
-            ret = func(*args, **kwargs)
-            counter = loop.time()
-            return_value = await ret
-            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
-        except Exception as e:
-            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
-            exception = e
-        finally:
-            gc.enable()
-
-        print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
-        if exception:
-            raise exception
-        return return_value
-
-    return async_wrapper  # type: ignore[return-value]
-
-
-def codeflash_concurrency_async(func: F) -> F:
-    """Measures concurrent vs sequential execution performance for async functions."""
-
-    @wraps(func)
-    async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
-        """Run sequential then concurrent executions and emit timing metrics."""
-        function_name = func.__name__
-        concurrency_factor = int(
-            os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10")
-        )
-
-        test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
-        test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "")
-        test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "")
-        loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0")
-
-        # Phase 1: Sequential execution timing
-        gc.disable()
-        try:
-            seq_start = time.perf_counter_ns()
-            for _ in range(concurrency_factor):
-                result = await func(*args, **kwargs)
-            sequential_time = time.perf_counter_ns() - seq_start
-        finally:
-            gc.enable()
-
-        # Phase 2: Concurrent execution timing
-        gc.disable()
-        try:
-            conc_start = time.perf_counter_ns()
-            tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
-            await asyncio.gather(*tasks)
-            concurrent_time = time.perf_counter_ns() - conc_start
-        finally:
-            gc.enable()
-
-        # Output parseable metrics
-        tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}"
-        print(
-            f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!"
-        )
-
-        return result
-
-    return async_wrapper  # type: ignore[return-value]
@@ -23,7 +23,7 @@ _BEHAVIOR_QUERY = (
     "SELECT test_module_path, test_class_name,"
     " test_function_name, function_getting_tested,"
     " loop_index, invocation_id, wall_time_ns,"
-    " return_value, verification_type"
+    " return_value, verification_type, stdout"
     " FROM async_results"
     " WHERE mode = 'behavior'"
 )
@@ -109,6 +109,7 @@ def _process_behavior_row_inner(
     invocation_id = val[5]
     wall_time_ns = val[6]
     verification_type = val[8]
+    stdout_text = val[9] if len(val) > 9 else None

     test_file_path = file_path_from_module_name(
         test_module_path,  # type: ignore[arg-type]
@@ -173,6 +174,7 @@ def _process_behavior_row_inner(
                 if verification_type
                 else None
             ),
+            stdout=stdout_text or None,
         ),
     )
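Note: the len(val) > 9 guard keeps the row parser tolerant of result rows that predate the stdout column — older nine-field rows simply yield None. The same defensive pattern in isolation (the row shapes are illustrative):

    def stdout_of(val: tuple):
        """Return the captured stdout field if the row carries one (index 9)."""
        return val[9] if len(val) > 9 else None

    old_row = ("mod", None, "test_fn", "fn", 1, "0_0", 123, b"", "function_call")
    new_row = old_row + ("hello world\n",)
    assert stdout_of(old_row) is None
    assert stdout_of(new_row) == "hello world\n"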
@@ -11,6 +11,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING

+from ..runtime._codeflash_wrap_decorator import get_run_tmp_file
 from ._async_data_parser import parse_async_behavior_results
 from ._data_parsers import parse_sqlite_test_results
 from ._result_merger import merge_test_results
 from ._stdout_parsers import parse_test_failures_from_stdout
@@ -50,6 +51,20 @@ def parse_test_results(
         sql_file, test_files, test_config
     )

+    # Parse async SQLite results
+    async_sql_file = get_run_tmp_file(
+        Path(f"async_results_{optimization_iteration}.sqlite"),
+    )
+    if async_sql_file.exists():
+        async_results = parse_async_behavior_results(
+            async_sql_file,
+            test_files,
+            test_config,
+        )
+        for inv in async_results:
+            data_results.test_results.append(inv)
+        async_sql_file.unlink(missing_ok=True)
+
     # Clean up deprecated binary pickle file if present
     bin_file = get_run_tmp_file(
         Path(f"test_return_values_{optimization_iteration}.bin"),
@@ -92,7 +92,7 @@ def _merge_single_xml(
                 if data_result.verification_type
                 else None
             ),
-            stdout=xml_result.stdout,
+            stdout=data_result.stdout or xml_result.stdout,
         ),
     )
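Note: because `or` falls through on any falsy value, an empty capture from the decorator still defers to the XML stdout; only a non-empty data-sourced capture wins. Two lines make the semantics concrete:

    assert ("hello\n" or "xml text") == "hello\n"  # data capture present: preferred
    assert ("" or "xml text") == "xml text"        # empty capture: fall back to XML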
@@ -1,14 +1,8 @@
-"""Stdout-based parsing: test failures and performance/concurrency metrics."""
+"""Stdout-based parsing: test failure extraction from pytest output."""

 from __future__ import annotations

 import re
-from typing import TYPE_CHECKING
-
-from ..benchmarking.models import ConcurrencyMetrics
-
-if TYPE_CHECKING:
-    from .models import TestResults

 TEST_HEADER_RE = re.compile(r"_{3,}\s*(.*?)\s*_{3,}$")
@@ -78,82 +72,3 @@ def _collect_failures(
         failures[current_name] = "".join(current_lines)

     return failures
-
-
-# -- Performance and concurrency metrics --
-
-_perf_start_pattern = re.compile(
-    r"!\$######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+)######\$!",
-)
-_perf_end_pattern = re.compile(
-    r"!######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+):([^:]+)######!",
-)
-
-_concurrency_pattern = re.compile(
-    r"!@######CONC:"
-    r"([^:]*):([^:]*):([^:]*):([^:]*):([^:]*)"
-    r":(\d+):(\d+):(\d+)######@!",
-)
-
-
-def calculate_function_throughput_from_test_results(
-    test_results: TestResults,
-    function_name: str,
-) -> int:
-    """Count completed function executions from performance stdout markers."""
-    start_matches = _perf_start_pattern.findall(
-        test_results.perf_stdout or "",
-    )
-    end_matches = _perf_end_pattern.findall(
-        test_results.perf_stdout or "",
-    )
-
-    end_matches_truncated = [m[:5] for m in end_matches]
-    end_matches_set = set(end_matches_truncated)
-
-    count = 0
-    expected_fn_idx = 2
-    for start_match in start_matches:
-        if (
-            start_match in end_matches_set
-            and len(start_match) > expected_fn_idx
-            and start_match[expected_fn_idx] == function_name
-        ):
-            count += 1
-    return count
-
-
-def parse_concurrency_metrics(
-    test_results: TestResults,
-    function_name: str,
-) -> ConcurrencyMetrics | None:
-    """Parse concurrency benchmark results from test output."""
-    if not test_results.perf_stdout:
-        return None
-
-    matches = _concurrency_pattern.findall(test_results.perf_stdout)
-    if not matches:
-        return None
-
-    expected_groups = 8
-    total_seq, total_conc, factor, count = 0, 0, 0, 0
-    for match in matches:
-        if len(match) >= expected_groups and match[3] == function_name:
-            total_seq += int(match[5])
-            total_conc += int(match[6])
-            factor = int(match[7])
-            count += 1
-
-    if count == 0:
-        return None
-
-    avg_seq = total_seq / count
-    avg_conc = total_conc / count
-    ratio = avg_seq / avg_conc if avg_conc > 0 else 1.0
-
-    return ConcurrencyMetrics(
-        sequential_time_ns=int(avg_seq),
-        concurrent_time_ns=int(avg_conc),
-        concurrency_factor=factor,
-        concurrency_ratio=ratio,
-    )
@@ -4,6 +4,7 @@ from __future__ import annotations

 import asyncio
 import logging
+import os
 import shlex
 import subprocess
 import sys
@@ -23,6 +24,13 @@ _PER_FILE_TIMEOUT = 60
 _MAX_TIMEOUT = 600


+def _propagate_tmpdir(env: dict[str, str]) -> None:
+    """Ensure CODEFLASH_RUN_TMPDIR is in the subprocess env."""
+    tmpdir = os.environ.get("CODEFLASH_RUN_TMPDIR")
+    if tmpdir:
+        env["CODEFLASH_RUN_TMPDIR"] = tmpdir
+
+
 def _base_pytest_args(rootdir: Path | None, cwd: Path) -> list[str]:
     """Common pytest args shared across all test runner functions."""
     return [
@@ -48,6 +56,8 @@ def execute_test_subprocess(
     timeout: int = 600,
 ) -> subprocess.CompletedProcess[str]:
     """Execute a subprocess with the given command list."""
+    if env is not None:
+        _propagate_tmpdir(env)
     log.debug(
         "executing test run with command: %s",
         " ".join(cmd_list),
@@ -355,6 +365,8 @@ async def async_execute_test_subprocess(
     timeout: int = 600,
 ) -> subprocess.CompletedProcess[str]:
     """Execute a subprocess asynchronously."""
+    if env is not None:
+        _propagate_tmpdir(env)
     log.debug(
         "executing async test run with command: %s",
         " ".join(cmd_list),
@@ -1,343 +0,0 @@
-from __future__ import annotations
-
-import asyncio
-import os
-import sys
-import time
-
-import pytest
-
-from codeflash_python.benchmarking.models import ConcurrencyMetrics
-from codeflash_python.runtime._codeflash_wrap_decorator import (
-    codeflash_concurrency_async,
-)
-from codeflash_python.testing._stdout_parsers import parse_concurrency_metrics
-from codeflash_python.testing.models import TestResults
-
-
-@pytest.mark.skipif(
-    sys.platform == "win32", reason="pending support for asyncio on windows"
-)
-class TestConcurrencyAsyncDecorator:
-    """Integration tests for codeflash_concurrency_async decorator."""
-
-    @pytest.fixture
-    def concurrency_env_setup(self, request):
-        """Set up environment variables for concurrency testing."""
-        original_env = {}
-        test_env = {
-            "CODEFLASH_LOOP_INDEX": "1",
-            "CODEFLASH_TEST_MODULE": __name__,
-            "CODEFLASH_TEST_CLASS": "TestConcurrencyAsyncDecorator",
-            "CODEFLASH_TEST_FUNCTION": request.node.name,
-            "CODEFLASH_CONCURRENCY_FACTOR": "5",  # Use smaller factor for faster tests
-        }
-
-        for key, value in test_env.items():
-            original_env[key] = os.environ.get(key)
-            os.environ[key] = value
-
-        yield test_env
-
-        for key, original_value in original_env.items():
-            if original_value is None:
-                os.environ.pop(key, None)
-            else:
-                os.environ[key] = original_value
-
-    @pytest.mark.asyncio
-    async def test_concurrency_decorator_nonblocking_function(
-        self, concurrency_env_setup, capsys
-    ):
-        """Test that non-blocking async functions show high concurrency ratio."""
-
-        @codeflash_concurrency_async
-        async def nonblocking_sleep(duration: float) -> str:
-            await asyncio.sleep(duration)
-            return "done"
-
-        result = await nonblocking_sleep(0.01)
-
-        assert result == "done"
-
-        captured = capsys.readouterr()
-        output = captured.out
-
-        # Verify the output format
-        assert "!@######CONC:" in output
-        assert "######@!" in output
-
-        # Parse the output manually to verify format
-        lines = [
-            line
-            for line in output.strip().split("\n")
-            if "!@######CONC:" in line
-        ]
-        assert len(lines) == 1
-
-        line = lines[0]
-        # Format: !@######CONC:{test_module}:{test_class}:{test_function}:{function_name}:{loop_index}:{seq_time}:{conc_time}:{factor}######@!
-        assert "nonblocking_sleep" in line
-        assert ":5######@!" in line  # concurrency factor
-
-        # Extract timing values
-        parts = (
-            line.replace("!@######CONC:", "")
-            .replace("######@!", "")
-            .split(":")
-        )
-        # parts should be: [test_module, test_class, test_function, function_name, loop_index, seq_time, conc_time, factor]
-        assert len(parts) == 8
-
-        seq_time = int(parts[5])
-        conc_time = int(parts[6])
-        factor = int(parts[7])
-
-        assert seq_time > 0
-        assert conc_time > 0
-        assert factor == 5
-
-        # For non-blocking async, concurrent time should be much less than sequential
-        # Sequential runs 5 iterations of 10ms = ~50ms
-        # Concurrent runs 5 iterations in parallel = ~10ms
-        # So ratio should be around 5 (with some overhead tolerance)
-        ratio = seq_time / conc_time if conc_time > 0 else 1.0
-        assert ratio > 2.0, (
-            f"Non-blocking function should have ratio > 2.0, got {ratio}"
-        )
-
-    @pytest.mark.asyncio
-    async def test_concurrency_decorator_blocking_function(
-        self, concurrency_env_setup, capsys
-    ):
-        """Test that blocking functions show low concurrency ratio (~1.0)."""
-
-        @codeflash_concurrency_async
-        async def blocking_sleep(duration: float) -> str:
-            time.sleep(duration)  # Blocking sleep
-            return "done"
-
-        result = await blocking_sleep(0.005)  # 5ms blocking
-
-        assert result == "done"
-
-        captured = capsys.readouterr()
-        output = captured.out
-
-        assert "!@######CONC:" in output
-
-        lines = [
-            line
-            for line in output.strip().split("\n")
-            if "!@######CONC:" in line
-        ]
-        assert len(lines) == 1
-
-        line = lines[0]
-        parts = (
-            line.replace("!@######CONC:", "")
-            .replace("######@!", "")
-            .split(":")
-        )
-        assert len(parts) == 8
-
-        seq_time = int(parts[5])
-        conc_time = int(parts[6])
-
-        # For blocking code, sequential and concurrent times should be similar
-        # Because time.sleep blocks the entire event loop
-        ratio = seq_time / conc_time if conc_time > 0 else 1.0
-        # Blocking code should have ratio close to 1.0 (within reasonable tolerance)
-        assert ratio < 2.0, (
-            f"Blocking function should have ratio < 2.0, got {ratio}"
-        )
-
-    @pytest.mark.asyncio
-    async def test_concurrency_decorator_with_computation(
-        self, concurrency_env_setup, capsys
-    ):
-        """Test concurrency with CPU-bound computation."""
-
-        @codeflash_concurrency_async
-        async def compute_intensive(n: int) -> int:
-            # CPU-bound work (blocked by GIL in concurrent execution)
-            total = 0
-            for i in range(n):
-                total += i * i
-            return total
-
-        result = await compute_intensive(10000)
-
-        assert result == sum(i * i for i in range(10000))
-
-        captured = capsys.readouterr()
-        output = captured.out
-
-        assert "!@######CONC:" in output
-        assert "compute_intensive" in output
-
-
-@pytest.mark.skipif(
-    sys.platform == "win32", reason="pending support for asyncio on windows"
-)
-class TestParseConcurrencyMetrics:
-    """Integration tests for parse_concurrency_metrics function."""
-
-    def test_parse_concurrency_metrics_from_real_output(self):
-        """Test parsing concurrency metrics from simulated stdout."""
-        # Simulate stdout from codeflash_concurrency_async decorator
-        perf_stdout = """Some other output
-!@######CONC:test_module:TestClass:test_func:my_async_func:1:50000000:10000000:5######@!
-More output here
-"""
-        test_results = TestResults(test_results=[], perf_stdout=perf_stdout)
-
-        metrics = parse_concurrency_metrics(test_results, "my_async_func")
-
-        assert metrics is not None
-        assert isinstance(metrics, ConcurrencyMetrics)
-        assert metrics.sequential_time_ns == 50000000
-        assert metrics.concurrent_time_ns == 10000000
-        assert metrics.concurrency_factor == 5
-        assert metrics.concurrency_ratio == 5.0  # 50M / 10M = 5.0
-
-    def test_parse_concurrency_metrics_multiple_entries(self):
-        """Test parsing when multiple concurrency entries exist."""
-        perf_stdout = """!@######CONC:test_module:TestClass:test_func:target_func:1:40000000:10000000:5######@!
-!@######CONC:test_module:TestClass:test_func:target_func:2:60000000:10000000:5######@!
-!@######CONC:test_module:TestClass:test_func:other_func:1:30000000:15000000:5######@!
-"""
-        test_results = TestResults(test_results=[], perf_stdout=perf_stdout)
-
-        metrics = parse_concurrency_metrics(test_results, "target_func")
-
-        assert metrics is not None
-        # Should average the two entries for target_func
-        # (40M + 60M) / 2 = 50M seq, (10M + 10M) / 2 = 10M conc
-        assert metrics.sequential_time_ns == 50000000
-        assert metrics.concurrent_time_ns == 10000000
-        assert metrics.concurrency_ratio == 5.0
-
-    def test_parse_concurrency_metrics_no_match(self):
-        """Test parsing when function name doesn't match."""
-        perf_stdout = """!@######CONC:test_module:TestClass:test_func:other_func:1:50000000:10000000:5######@!
-"""
-        test_results = TestResults(test_results=[], perf_stdout=perf_stdout)
-
-        metrics = parse_concurrency_metrics(test_results, "nonexistent_func")
-
-        assert metrics is None
-
-    def test_parse_concurrency_metrics_empty_stdout(self):
-        """Test parsing with empty stdout."""
-        test_results = TestResults(test_results=[], perf_stdout="")
-
-        metrics = parse_concurrency_metrics(test_results, "any_func")
-
-        assert metrics is None
-
-    def test_parse_concurrency_metrics_none_stdout(self):
-        """Test parsing with None stdout."""
-        test_results = TestResults(test_results=[], perf_stdout=None)
-
-        metrics = parse_concurrency_metrics(test_results, "any_func")
-
-        assert metrics is None
-
-
-@pytest.mark.skipif(
-    sys.platform == "win32", reason="pending support for asyncio on windows"
-)
-class TestConcurrencyRatioComparison:
-    """Test comparing blocking vs non-blocking concurrency ratios."""
-
-    @pytest.fixture
-    def comparison_env_setup(self, request):
-        """Set up environment variables for comparison testing."""
-        original_env = {}
-        test_env = {
-            "CODEFLASH_LOOP_INDEX": "1",
-            "CODEFLASH_TEST_MODULE": __name__,
-            "CODEFLASH_TEST_CLASS": "TestConcurrencyRatioComparison",
-            "CODEFLASH_TEST_FUNCTION": request.node.name,
-            "CODEFLASH_CONCURRENCY_FACTOR": "10",
-        }
-
-        for key, value in test_env.items():
-            original_env[key] = os.environ.get(key)
-            os.environ[key] = value
-
-        yield test_env
-
-        for key, original_value in original_env.items():
-            if original_value is None:
-                os.environ.pop(key, None)
-            else:
-                os.environ[key] = original_value
-
-    @pytest.mark.asyncio
-    async def test_blocking_vs_nonblocking_comparison(
-        self, comparison_env_setup, capsys
-    ):
-        """Compare concurrency ratios between blocking and non-blocking implementations."""
-
-        @codeflash_concurrency_async
-        async def blocking_impl() -> str:
-            time.sleep(0.002)  # 2ms blocking
-            return "blocking"
-
-        @codeflash_concurrency_async
-        async def nonblocking_impl() -> str:
-            await asyncio.sleep(0.002)  # 2ms non-blocking
-            return "nonblocking"
-
-        # Run blocking version
-        await blocking_impl()
-        blocking_output = capsys.readouterr().out
-
-        # Run non-blocking version
-        await nonblocking_impl()
-        nonblocking_output = capsys.readouterr().out
-
-        # Parse blocking metrics
-        blocking_line = [
-            l for l in blocking_output.split("\n") if "!@######CONC:" in l
-        ][0]
-        blocking_parts = (
-            blocking_line.replace("!@######CONC:", "")
-            .replace("######@!", "")
-            .split(":")
-        )
-        blocking_seq = int(blocking_parts[5])
-        blocking_conc = int(blocking_parts[6])
-        blocking_ratio = (
-            blocking_seq / blocking_conc if blocking_conc > 0 else 1.0
-        )
-
-        # Parse non-blocking metrics
-        nonblocking_line = [
-            l for l in nonblocking_output.split("\n") if "!@######CONC:" in l
-        ][0]
-        nonblocking_parts = (
-            nonblocking_line.replace("!@######CONC:", "")
-            .replace("######@!", "")
-            .split(":")
-        )
-        nonblocking_seq = int(nonblocking_parts[5])
-        nonblocking_conc = int(nonblocking_parts[6])
-        nonblocking_ratio = (
-            nonblocking_seq / nonblocking_conc if nonblocking_conc > 0 else 1.0
-        )
-
-        # Non-blocking should have significantly higher concurrency ratio
-        assert nonblocking_ratio > blocking_ratio, (
-            f"Non-blocking ratio ({nonblocking_ratio:.2f}) should be greater than blocking ratio ({blocking_ratio:.2f})"
-        )
-
-        # The difference should be substantial (non-blocking should be at least 2x better)
-        ratio_improvement = (
-            nonblocking_ratio / blocking_ratio if blocking_ratio > 0 else 0
-        )
-        assert ratio_improvement > 2.0, (
-            f"Non-blocking should show >2x improvement in concurrency ratio, got {ratio_improvement:.2f}x"
-        )
@@ -52,7 +52,7 @@ def _create_async_db(
     for row in rows:
         conn.execute(
             "INSERT INTO async_results VALUES "
-            "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+            "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
             row,
         )
     conn.commit()
@@ -100,6 +100,7 @@ class TestParseAsyncBehaviorResults:
                 None,
                 None,
                 None,
+                None,
             ),
         ],
     )
@@ -139,6 +140,7 @@ class TestParseAsyncBehaviorResults:
                 None,
                 None,
                 None,
+                None,
             ),
         ],
     )
@@ -191,6 +193,7 @@ class TestCalculateAsyncThroughput:
                 None,
                 None,
                 None,
+                None,
             ),
             (
                 "mod",
@@ -206,6 +209,7 @@ class TestCalculateAsyncThroughput:
                 None,
                 None,
                 None,
+                None,
             ),
             (
                 "mod",
@@ -221,6 +225,7 @@ class TestCalculateAsyncThroughput:
                 None,
                 None,
                 None,
+                None,
             ),
             (
                 "mod",
@@ -236,6 +241,7 @@ class TestCalculateAsyncThroughput:
                 None,
                 None,
                 None,
+                None,
             ),
         ],
     )
@@ -281,6 +287,7 @@ class TestParseAsyncConcurrencyMetrics:
                 100_000,
                 50_000,
                 10,
+                None,
             ),
             (
                 "mod",
@@ -296,6 +303,7 @@ class TestParseAsyncConcurrencyMetrics:
                 200_000,
                 100_000,
                 10,
+                None,
             ),
         ],
     )
@@ -326,6 +334,7 @@ class TestParseAsyncConcurrencyMetrics:
                 100_000,
                 50_000,
                 10,
+                None,
             ),
         ],
     )
@@ -352,6 +361,7 @@ class TestParseAsyncConcurrencyMetrics:
                 100_000,
                 0,
                 5,
+                None,
             ),
         ],
     )
@@ -406,6 +416,7 @@ class TestParseAsyncBehaviorEdgeCases:
                 None,
                 None,
                 None,
+                None,
             ),
         ],
     )
@@ -449,6 +460,7 @@ class TestParseAsyncBehaviorEdgeCases:
                 None,
                 None,
                 None,
+                None,
             ),
         ],
     )
@@ -485,6 +497,7 @@ class TestParseAsyncBehaviorEdgeCases:
                 None,
                 None,
                 None,
+                None,
             ),
         ],
     )
@@ -517,6 +530,7 @@ class TestParseAsyncBehaviorEdgeCases:
                 None,
                 None,
                 None,
+                None,
             ),
         ],
     )
@@ -554,6 +568,7 @@ class TestParseAsyncBehaviorEdgeCases:
                 None,
                 None,
                 None,
+                None,
             ),
         ],
     )
@@ -227,7 +227,7 @@ class TestBehaviorAsync:
     async def test_no_stdout_output(
         self, env_setup, async_db_path, capsys
     ) -> None:
-        """Behavior decorator emits no stdout."""
+        """Behavior decorator does not leak stdout to outer scope."""

         @codeflash_behavior_async
         async def noop() -> int:
@@ -238,6 +238,28 @@ class TestBehaviorAsync:
         captured = capsys.readouterr()
         assert "" == captured.out

+    @pytest.mark.asyncio
+    async def test_captures_stdout_in_sqlite(
+        self, env_setup, async_db_path
+    ) -> None:
+        """Behavior decorator captures print output into the stdout column."""
+
+        @codeflash_behavior_async
+        async def greeter(name: str) -> str:
+            print(f"hello {name}")
+            return f"hi {name}"
+
+        _codeflash_call_site.set("0")
+        await greeter("world")
+        _close_all_connections()
+
+        con = sqlite3.connect(async_db_path)
+        cur = con.cursor()
+        cur.execute("SELECT stdout FROM async_results")
+        row = cur.fetchone()
+        assert "hello world\n" == row[0]
+        con.close()
+

 @pytest.mark.skipif(
     sys.platform == "win32",
@@ -482,7 +504,7 @@ class TestSchemaValidation:
     """Validate the async_results SQLite schema."""

     def test_table_columns(self, tmp_path) -> None:
-        """async_results table has exactly 13 columns."""
+        """async_results table has exactly 14 columns."""
         db_path = tmp_path / "schema_test.sqlite"
         conn, cur = _get_async_db(db_path)
         cur.execute("PRAGMA table_info(async_results)")
@@ -501,6 +523,7 @@ class TestSchemaValidation:
             "sequential_time_ns",
             "concurrent_time_ns",
             "concurrency_factor",
+            "stdout",
         ]
         actual_names = [col[1] for col in columns]
         assert expected_names == actual_names
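Note: the same PRAGMA the test uses is a quick way to confirm the migration by hand (the database path is illustrative):

    import sqlite3

    con = sqlite3.connect("async_results_0.sqlite")
    cols = [row[1] for row in con.execute("PRAGMA table_info(async_results)")]
    assert len(cols) == 14 and cols[-1] == "stdout"
    con.close()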
@@ -1,333 +0,0 @@
-from __future__ import annotations
-
-import asyncio
-import os
-import sqlite3
-import sys
-from pathlib import Path
-
-import dill as pickle
-import pytest
-
-from codeflash_python.runtime._codeflash_capture import VerificationType
-from codeflash_python.runtime._codeflash_wrap_decorator import (
-    codeflash_behavior_async,
-    codeflash_performance_async,
-)
-
-
-@pytest.mark.skipif(
-    sys.platform == "win32", reason="pending support for asyncio on windows"
-)
-class TestAsyncWrapperSQLiteValidation:
-    @pytest.fixture
-    def test_env_setup(self, request):
-        original_env = {}
-        test_env = {
-            "CODEFLASH_LOOP_INDEX": "1",
-            "CODEFLASH_TEST_ITERATION": "0",
-            "CODEFLASH_TEST_MODULE": __name__,
-            "CODEFLASH_TEST_CLASS": "TestAsyncWrapperSQLiteValidation",
-            "CODEFLASH_TEST_FUNCTION": request.node.name,
-            "CODEFLASH_CURRENT_LINE_ID": "test_unit",
-        }
-
-        for key, value in test_env.items():
-            original_env[key] = os.environ.get(key)
-            os.environ[key] = value
-
-        yield test_env
-
-        for key, original_value in original_env.items():
-            if original_value is None:
-                os.environ.pop(key, None)
-            else:
-                os.environ[key] = original_value
-
-    @pytest.fixture
-    def temp_db_path(self, test_env_setup):
-        iteration = test_env_setup["CODEFLASH_TEST_ITERATION"]
-        from codeflash_python.testing._instrumentation import get_run_tmp_file
-
-        db_path = get_run_tmp_file(
-            Path(f"test_return_values_{iteration}.sqlite")
-        )
-
-        yield db_path
-
-        if db_path.exists():
-            db_path.unlink()
-
-    @pytest.mark.asyncio
-    async def test_behavior_async_basic_function(
-        self, test_env_setup, temp_db_path
-    ):
-        @codeflash_behavior_async
-        async def simple_async_add(a: int, b: int) -> int:
-            await asyncio.sleep(0.001)
-            return a + b
-
-        os.environ["CODEFLASH_CURRENT_LINE_ID"] = "simple_async_add_59"
-        result = await simple_async_add(5, 3)
-
-        assert result == 8
-
-        assert temp_db_path.exists()
-
-        con = sqlite3.connect(temp_db_path)
-        cur = con.cursor()
-
-        cur.execute(
-            "SELECT name FROM sqlite_master WHERE type='table' AND name='test_results'"
-        )
-        assert cur.fetchone() is not None
-
-        cur.execute("SELECT * FROM test_results")
-        rows = cur.fetchall()
-
-        assert len(rows) == 1
-        row = rows[0]
-
-        (
-            test_module_path,
-            test_class_name,
-            test_function_name,
-            function_getting_tested,
-            loop_index,
-            iteration_id,
-            runtime,
-            return_value_blob,
-            verification_type,
-            cpu_runtime,
-        ) = row
-
-        assert test_module_path == __name__
-        assert test_class_name == "TestAsyncWrapperSQLiteValidation"
-        assert test_function_name == "test_behavior_async_basic_function"
-        assert function_getting_tested == "simple_async_add"
-        assert loop_index == 1
-        # Line ID will be the actual line number from the source code, not a simple counter
-        assert iteration_id.startswith(
-            "simple_async_add_"
-        ) and iteration_id.endswith("_0")
-        assert runtime > 0
-        assert verification_type == VerificationType.FUNCTION_CALL.value
-        assert isinstance(cpu_runtime, int)
-
-        unpickled_data = pickle.loads(return_value_blob)
-        args, kwargs, return_val = unpickled_data
-
-        assert args == (5, 3)
-        assert kwargs == {}
-        assert return_val == 8
-
-        con.close()
-
-    @pytest.mark.asyncio
-    async def test_behavior_async_exception_handling(
-        self, test_env_setup, temp_db_path
-    ):
-        @codeflash_behavior_async
-        async def async_divide(a: int, b: int) -> float:
-            await asyncio.sleep(0.001)
-            if b == 0:
-                raise ValueError("Cannot divide by zero")
-            return a / b
-
-        result = await async_divide(10, 2)
-        assert result == 5.0
-
-        with pytest.raises(ValueError, match="Cannot divide by zero"):
-            await async_divide(10, 0)
-
-        con = sqlite3.connect(temp_db_path)
-        cur = con.cursor()
-        cur.execute("SELECT * FROM test_results ORDER BY iteration_id")
-        rows = cur.fetchall()
-
-        assert len(rows) == 2
-
-        success_row = rows[0]
-        success_data = pickle.loads(success_row[7])  # return_value_blob
-        args, kwargs, return_val = success_data
-        assert args == (10, 2)
-        assert return_val == 5.0
-
-        # Check exception record
-        exception_row = rows[1]
-        exception_data = pickle.loads(exception_row[7])  # return_value_blob
-        assert isinstance(exception_data, ValueError)
-        assert str(exception_data) == "Cannot divide by zero"
-
-        con.close()
-
-    @pytest.mark.asyncio
-    async def test_performance_async_no_database_storage(
-        self, test_env_setup, temp_db_path, capsys
-    ):
-        """Test performance async decorator doesn't store to database."""
-
-        @codeflash_performance_async
-        async def async_multiply(a: int, b: int) -> int:
-            """Async function for performance testing."""
-            await asyncio.sleep(0.002)
-            return a * b
-
-        result = await async_multiply(4, 7)
-
-        assert result == 28
-
-        assert not temp_db_path.exists()
-
-        captured = capsys.readouterr()
-        output_lines = captured.out.strip().split("\n")
-
-        assert len([line for line in output_lines if "!$######" in line]) == 1
-        assert (
-            len(
-                [
-                    line
-                    for line in output_lines
-                    if "!######" in line and "######!" in line
-                ]
-            )
-            == 1
-        )
-
-        closing_tag = [
-            line
-            for line in output_lines
-            if "!######" in line and "######!" in line
-        ][0]
-        assert "async_multiply" in closing_tag
-
-        timing_part = closing_tag.split(":")[-1].replace("######!", "")
-        timing_value = int(timing_part)
-        assert timing_value > 0  # Should have positive timing
-
-    @pytest.mark.asyncio
-    async def test_multiple_calls_indexing(self, test_env_setup, temp_db_path):
-        @codeflash_behavior_async
-        async def async_increment(value: int) -> int:
-            await asyncio.sleep(0.001)
-            return value + 1
-
-        # Call the function multiple times
-        results = []
-        for i in range(3):
-            result = await async_increment(i)
-            results.append(result)
-
-        assert results == [1, 2, 3]
-
-        con = sqlite3.connect(temp_db_path)
-        cur = con.cursor()
-        cur.execute(
-            "SELECT iteration_id, return_value FROM test_results ORDER BY iteration_id"
-        )
-        rows = cur.fetchall()
-
-        assert len(rows) == 3
-
-        actual_ids = [row[0] for row in rows]
-        assert len(actual_ids) == 3
-
-        base_pattern = actual_ids[0].rsplit("_", 1)[
-            0
-        ]  # e.g., "async_increment_199"
-        expected_pattern = [f"{base_pattern}_{i}" for i in range(3)]
-        assert actual_ids == expected_pattern
-
-        for i, (_, return_value_blob) in enumerate(rows):
-            args, kwargs, return_val = pickle.loads(return_value_blob)
-            assert args == (i,)
-            assert return_val == i + 1
-
-        con.close()
-
-    @pytest.mark.asyncio
-    async def test_complex_async_function_with_kwargs(
-        self, test_env_setup, temp_db_path
-    ):
-        @codeflash_behavior_async
-        async def complex_async_func(
-            pos_arg: str,
-            *args: int,
-            keyword_arg: str = "default",
-            **kwargs: str,
-        ) -> dict:
-            await asyncio.sleep(0.001)
-            return {
-                "pos_arg": pos_arg,
-                "args": args,
-                "keyword_arg": keyword_arg,
-                "kwargs": kwargs,
-            }
-
-        result = await complex_async_func(
-            "hello",
-            1,
-            2,
-            3,
-            keyword_arg="custom",
-            extra1="value1",
-            extra2="value2",
-        )
-
-        expected_result = {
-            "pos_arg": "hello",
-            "args": (1, 2, 3),
-            "keyword_arg": "custom",
-            "kwargs": {"extra1": "value1", "extra2": "value2"},
-        }
-
-        assert result == expected_result
-
-        con = sqlite3.connect(temp_db_path)
-        cur = con.cursor()
-        cur.execute("SELECT return_value FROM test_results")
-        row = cur.fetchone()
-
-        stored_args, stored_kwargs, stored_result = pickle.loads(row[0])
-
-        assert stored_args == ("hello", 1, 2, 3)
-        assert stored_kwargs == {
-            "keyword_arg": "custom",
-            "extra1": "value1",
-            "extra2": "value2",
-        }
-        assert stored_result == expected_result
-
-        con.close()
-
-    @pytest.mark.asyncio
-    async def test_database_schema_validation(
-        self, test_env_setup, temp_db_path
-    ):
-        @codeflash_behavior_async
-        async def schema_test_func() -> str:
-            return "schema_test"
-
-        await schema_test_func()
-
-        con = sqlite3.connect(temp_db_path)
-        cur = con.cursor()
-
-        cur.execute("PRAGMA table_info(test_results)")
-        columns = cur.fetchall()
-
-        expected_columns = [
-            (0, "test_module_path", "TEXT", 0, None, 0),
-            (1, "test_class_name", "TEXT", 0, None, 0),
-            (2, "test_function_name", "TEXT", 0, None, 0),
-            (3, "function_getting_tested", "TEXT", 0, None, 0),
-            (4, "loop_index", "INTEGER", 0, None, 0),
-            (5, "iteration_id", "TEXT", 0, None, 0),
-            (6, "runtime", "INTEGER", 0, None, 0),
-            (7, "return_value", "BLOB", 0, None, 0),
-            (8, "verification_type", "TEXT", 0, None, 0),
-            (9, "cpu_runtime", "INTEGER", 0, None, 0),
-        ]
-
-        assert columns == expected_columns
-        con.close()
@@ -13,7 +13,6 @@ from codeflash_python.analysis._coverage import (
 from codeflash_python.benchmarking.models import ConcurrencyMetrics
 from codeflash_python.context.models import CodeOptimizationContext
 from codeflash_python.test_discovery.models import TestType
-from codeflash_python.testing._stdout_parsers import parse_concurrency_metrics
 from codeflash_python.testing.models import (
     FunctionTestInvocation,
     InvocationId,
@@ -882,43 +881,3 @@ def test_concurrency_ratio_display_formatting() -> None:
     assert display_string == "Concurrency ratio: 0.01x \u2192 0.03x (+200.0%)"
-
-
-def test_parse_concurrency_metrics() -> None:
-    """parse_concurrency_metrics extracts metrics from test output."""
-    stdout = (
-        "!@######CONC:test_module:TestClass:test_func:"
-        "my_function:0:10000000:1000000:10######@!\n"
-        "!@######CONC:test_module:TestClass:test_func:"
-        "my_function:1:10000000:1000000:10######@!\n"
-    )
-    test_results = TestResults(perf_stdout=stdout)
-
-    metrics = parse_concurrency_metrics(test_results, "my_function")
-    assert metrics is not None
-    assert metrics.sequential_time_ns == 10_000_000
-    assert metrics.concurrent_time_ns == 1_000_000
-    assert metrics.concurrency_factor == 10
-    assert metrics.concurrency_ratio == 10.0
-
-    metrics_wrong_func = parse_concurrency_metrics(
-        test_results, "other_function"
-    )
-    assert metrics_wrong_func is None
-
-    empty_results = TestResults(perf_stdout="")
-    metrics_empty = parse_concurrency_metrics(empty_results, "my_function")
-    assert metrics_empty is None
-
-    none_results = TestResults(perf_stdout=None)
-    metrics_none = parse_concurrency_metrics(none_results, "my_function")
-    assert metrics_none is None
-
-    stdout_no_class = (
-        "!@######CONC:test_module::test_func:"
-        "my_function:0:5000000:2500000:10######@!\n"
-    )
-    test_results_no_class = TestResults(perf_stdout=stdout_no_class)
-    metrics_no_class = parse_concurrency_metrics(
-        test_results_no_class, "my_function"
-    )
-    assert metrics_no_class is not None
-    assert metrics_no_class.concurrency_ratio == 2.0
@@ -1,4 +1,3 @@
-import os
 import sys
 import tempfile
 from pathlib import Path
@@ -750,272 +749,28 @@ async def test_multiple_calls():
     assert instrumented_test_code is not None

     assert (
-        "os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'"
+        "_codeflash_call_site.set('0')"
         in instrumented_test_code
     )

     # Count occurrences of each line_id to verify numbering
     line_id_0_count = instrumented_test_code.count(
-        "os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'"
+        "_codeflash_call_site.set('0')"
     )
     line_id_1_count = instrumented_test_code.count(
-        "os.environ['CODEFLASH_CURRENT_LINE_ID'] = '1'"
+        "_codeflash_call_site.set('1')"
     )
     line_id_2_count = instrumented_test_code.count(
-        "os.environ['CODEFLASH_CURRENT_LINE_ID'] = '2'"
+        "_codeflash_call_site.set('2')"
     )

-    assert line_id_0_count == 2, (
+    assert 2 == line_id_0_count, (
         f"Expected 2 occurrences of line_id '0', got {line_id_0_count}"
     )
-    assert line_id_1_count == 1, (
+    assert 1 == line_id_1_count, (
         f"Expected 1 occurrence of line_id '1', got {line_id_1_count}"
     )
-    assert line_id_2_count == 1, (
+    assert 1 == line_id_2_count, (
         f"Expected 1 occurrence of line_id '2', got {line_id_2_count}"
     )
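Note: the instrumentation now records the call-site id through a contextvars.ContextVar instead of mutating os.environ, so concurrent test tasks cannot clobber each other's line ids. A minimal sketch of why that isolation holds — the name _codeflash_call_site comes from the diff, but its definition site is not shown here, so this declaration is an assumption:

    import asyncio
    import contextvars

    _codeflash_call_site: contextvars.ContextVar[str] = contextvars.ContextVar(
        "_codeflash_call_site", default="?"
    )

    async def invoke_at(site: str) -> str:
        _codeflash_call_site.set(site)     # visible only to this task's context
        await asyncio.sleep(0)             # yield so the tasks interleave
        return _codeflash_call_site.get()  # still this task's own value

    async def main():
        # Each gathered task runs in a copied context, so the ids stay separate.
        assert await asyncio.gather(invoke_at("0"), invoke_at("1")) == ["0", "1"]

    asyncio.run(main())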
|
||||
|
||||
|
||||
-@pytest.mark.skipif(
-    sys.platform == "win32", reason="pending support for asyncio on windows"
-)
-def test_async_behavior_decorator_return_values_and_test_ids():
-    """Test that async behavior decorator correctly captures return values, test IDs, and stores data in database."""
-    import asyncio
-    import sqlite3
-    from pathlib import Path
-
-    import dill as pickle
-
-    from codeflash_python.runtime._codeflash_wrap_decorator import (
-        codeflash_behavior_async,
-    )
-
-    @codeflash_behavior_async
-    async def test_async_multiply(x: int, y: int) -> int:
-        """Simple async function for testing."""
-        await asyncio.sleep(0.001)  # Small delay to simulate async work
-        return x * y
-
-    test_env = {
-        "CODEFLASH_TEST_MODULE": "test_module",
-        "CODEFLASH_TEST_CLASS": None,
-        "CODEFLASH_TEST_FUNCTION": "test_async_multiply_function",
-        "CODEFLASH_CURRENT_LINE_ID": "0",
-        "CODEFLASH_LOOP_INDEX": "1",
-        "CODEFLASH_TEST_ITERATION": "2",
-    }
-
-    original_env = {k: os.environ.get(k) for k in test_env}
-    for k, v in test_env.items():
-        if v is not None:
-            os.environ[k] = v
-        elif k in os.environ:
-            del os.environ[k]
-
-    try:
-        result = asyncio.run(test_async_multiply(6, 7))
-
-        assert result == 42, f"Expected return value 42, got {result}"
-
-        from codeflash_python.testing._instrumentation import get_run_tmp_file
-
-        db_path = get_run_tmp_file(Path("test_return_values_2.sqlite"))
-
-        # Verify database exists and has data
-        assert db_path.exists(), f"Database file not created at {db_path}"
-
-        # Read and verify database contents
-        con = sqlite3.connect(db_path)
-        cur = con.cursor()
-
-        cur.execute("SELECT * FROM test_results")
-        rows = cur.fetchall()
-
-        assert len(rows) == 1, f"Expected 1 database row, got {len(rows)}"
-
-        row = rows[0]
-        (
-            test_module,
-            test_class,
-            test_function,
-            function_name,
-            loop_index,
-            iteration_id,
-            runtime,
-            return_value_blob,
-            verification_type,
-            cpu_runtime,
-        ) = row
-
-        assert test_module == "test_module", (
-            f"Expected test_module 'test_module', got '{test_module}'"
-        )
-        assert test_class is None, (
-            f"Expected test_class None, got '{test_class}'"
-        )
-        assert test_function == "test_async_multiply_function", (
-            f"Expected test_function 'test_async_multiply_function', got '{test_function}'"
-        )
-        assert function_name == "test_async_multiply", (
-            f"Expected function_name 'test_async_multiply', got '{function_name}'"
-        )
-        assert loop_index == 1, f"Expected loop_index 1, got {loop_index}"
-        assert iteration_id == "0_0", (
-            f"Expected iteration_id '0_0', got '{iteration_id}'"
-        )
-        assert verification_type == "function_call", (
-            f"Expected verification_type 'function_call', got '{verification_type}'"
-        )
-        unpickled_data = pickle.loads(return_value_blob)
-        args, kwargs, actual_return_value = unpickled_data
-
-        assert args == (6, 7), f"Expected args (6, 7), got {args}"
-        assert kwargs == {}, f"Expected empty kwargs, got {kwargs}"
-
-        assert actual_return_value == 42, (
-            f"Expected stored return value 42, got {actual_return_value}"
-        )
-
-        con.close()
-
-    finally:
-        for k, v in original_env.items():
-            if v is not None:
-                os.environ[k] = v
-            elif k in os.environ:
-                del os.environ[k]
-
-
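The deleted test above unpacked a 10-column `test_results` row with no stdout field; per the commit message, the decorator now also captures per-invocation stdout via `io.StringIO` into a new `stdout` column. A minimal, self-contained sketch of that capture pattern, assumed from the commit message rather than taken from the decorator's actual code:

# Sketch of per-invocation stdout capture; the decorator and attribute names
# are illustrative, and the attribute stands in for the database write.
import asyncio
import functools
import io
from contextlib import redirect_stdout

def capture_stdout_async(fn):
    @functools.wraps(fn)
    async def wrapper(*args, **kwargs):
        buf = io.StringIO()
        with redirect_stdout(buf):  # capture prints made during this call only
            result = await fn(*args, **kwargs)
        wrapper.last_stdout = buf.getvalue()  # stand-in for the `stdout` column
        return result
    return wrapper

@capture_stdout_async
async def greet(name: str) -> str:
    print(f"hello {name}")
    return name

assert asyncio.run(greet("codeflash")) == "codeflash"
assert greet.last_stdout == "hello codeflash\n"

Note that `redirect_stdout` swaps a process-global stream, so concurrent tasks printing at the same time would interleave into one buffer; per-invocation capture like this is only clean when calls run one at a time, as they do in these tests.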
-@pytest.mark.skipif(
-    sys.platform == "win32", reason="pending support for asyncio on windows"
-)
-def test_async_decorator_comprehensive_return_values_and_test_ids():
-    import asyncio
-    import sqlite3
-    from pathlib import Path
-
-    import dill as pickle
-
-    from codeflash_python.runtime._codeflash_wrap_decorator import (
-        codeflash_behavior_async,
-    )
-    from codeflash_python.testing._instrumentation import get_run_tmp_file
-
-    @codeflash_behavior_async
-    async def async_multiply_add(x: int, y: int, z: int = 1) -> int:
-        """Async function that multiplies x*y then adds z."""
-        await asyncio.sleep(0.001)
-        result = (x * y) + z
-        return result
-
-    test_env = {
-        "CODEFLASH_TEST_MODULE": "test_comprehensive_module",
-        "CODEFLASH_TEST_CLASS": "AsyncTestClass",
-        "CODEFLASH_TEST_FUNCTION": "test_comprehensive_async_function",
-        "CODEFLASH_CURRENT_LINE_ID": "3",
-        "CODEFLASH_LOOP_INDEX": "2",
-        "CODEFLASH_TEST_ITERATION": "3",
-    }
-
-    original_env = {k: os.environ.get(k) for k in test_env}
-    for k, v in test_env.items():
-        if v is not None:
-            os.environ[k] = v
-        elif k in os.environ:
-            del os.environ[k]
-
-    try:
-        test_cases = [
-            {"args": (5, 3), "kwargs": {}, "expected": 16},  # (5 * 3) + 1 = 16
-            {
-                "args": (2, 4),
-                "kwargs": {"z": 10},
-                "expected": 18,
-            },  # (2 * 4) + 10 = 18
-            {"args": (7, 6), "kwargs": {}, "expected": 43},  # (7 * 6) + 1 = 43
-        ]
-
-        results = []
-        for test_case in test_cases:
-            result = asyncio.run(
-                async_multiply_add(*test_case["args"], **test_case["kwargs"])
-            )
-            results.append(result)
-
-            # Verify each return value is exactly correct
-            assert result == test_case["expected"], (
-                f"Expected {test_case['expected']}, got {result} for args {test_case['args']}, kwargs {test_case['kwargs']}"
-            )
-
-        db_path = get_run_tmp_file(Path("test_return_values_3.sqlite"))
-        assert db_path.exists(), f"Database not created at {db_path}"
-
-        con = sqlite3.connect(db_path)
-        cur = con.cursor()
-
-        cur.execute(
-            "SELECT test_module_path, test_class_name, test_function_name, function_getting_tested, loop_index, iteration_id, runtime, return_value, verification_type FROM test_results ORDER BY rowid"
-        )
-        rows = cur.fetchall()
-
-        assert len(rows) == 3, f"Expected 3 database rows, got {len(rows)}"
-
-        for i, (
-            test_module,
-            test_class,
-            test_function,
-            function_name,
-            loop_index,
-            iteration_id,
-            runtime,
-            return_value_blob,
-            verification_type,
-        ) in enumerate(rows):
-            assert test_module == "test_comprehensive_module", (
-                f"Row {i}: Expected test_module 'test_comprehensive_module', got '{test_module}'"
-            )
-            assert test_class == "AsyncTestClass", (
-                f"Row {i}: Expected test_class 'AsyncTestClass', got '{test_class}'"
-            )
-            assert test_function == "test_comprehensive_async_function", (
-                f"Row {i}: Expected test_function 'test_comprehensive_async_function', got '{test_function}'"
-            )
-            assert function_name == "async_multiply_add", (
-                f"Row {i}: Expected function_name 'async_multiply_add', got '{function_name}'"
-            )
-            assert loop_index == 2, (
-                f"Row {i}: Expected loop_index 2, got {loop_index}"
-            )
-            assert verification_type == "function_call", (
-                f"Row {i}: Expected verification_type 'function_call', got '{verification_type}'"
-            )
-
-            expected_iteration_id = f"3_{i}"
-            assert iteration_id == expected_iteration_id, (
-                f"Row {i}: Expected iteration_id '{expected_iteration_id}', got '{iteration_id}'"
-            )
-
-            args, kwargs, actual_return_value = pickle.loads(return_value_blob)
-            expected_args = test_cases[i]["args"]
-            expected_kwargs = test_cases[i]["kwargs"]
-            expected_return = test_cases[i]["expected"]
-
-            assert args == expected_args, (
-                f"Row {i}: Expected args {expected_args}, got {args}"
-            )
-            assert kwargs == expected_kwargs, (
-                f"Row {i}: Expected kwargs {expected_kwargs}, got {kwargs}"
-            )
-            assert actual_return_value == expected_return, (
-                f"Row {i}: Expected return value {expected_return}, got {actual_return_value}"
-            )
-
-        con.close()
-
-    finally:
-        for k, v in original_env.items():
-            if v is not None:
-                os.environ[k] = v
-            elif k in os.environ:
-                del os.environ[k]
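Both deleted tests validated only the database side of an invocation. The commit message adds one more behavior these tests never covered: when merging results, stdout recorded in the async_results database now takes precedence over stdout scraped from the JUnit XML, which was empty in the bug this commit fixes. An illustrative sketch of that preference rule, with hypothetical names:

# Sketch of the stdout-preference rule described in the commit message;
# the function and parameter names are illustrative, not the merger's API.
def merged_stdout(db_stdout: str | None, xml_stdout: str | None) -> str:
    # Data-sourced stdout wins whenever the decorator recorded anything;
    # the XML value is only a fallback.
    if db_stdout:
        return db_stdout
    return xml_stdout or ""

assert merged_stdout("from db\n", "") == "from db\n"
assert merged_stdout(None, "from xml\n") == "from xml\n"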