Convert remaining sync test runner callers to async

Replace all sync test runner calls (run_behavioral_tests,
run_benchmarking_tests, run_line_profile_tests) with their async
counterparts throughout the pipeline. This eliminates the
ThreadPoolExecutor in _baseline.py in favor of asyncio.gather(),
and makes _async_bench.py, _candidate_gen.py, and
_function_optimizer.py fully async. Also adds
async_run_line_profile_tests to _test_runner.py and coverage support
to async_run_behavioral_tests.
Kevin Turcios 2026-04-23 01:46:01 -05:00
parent a292698a1d
commit 92e39d6923
9 changed files with 208 additions and 90 deletions
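
For orientation, a minimal sketch of the concurrency pattern this commit swaps in: the old two-worker ThreadPoolExecutor dispatch versus awaiting the async runners with asyncio.gather(). The runner functions are passed in as parameters here because this abbreviates the real signatures; the actual call sites are in the _baseline.py hunk further down.

import asyncio
from concurrent.futures import ThreadPoolExecutor

def run_both_sync(run_line_profile_tests, run_benchmarking_tests, **kwargs):
    # Old shape: the two sync runners are pushed onto a two-worker thread
    # pool so the line-profile and benchmarking subprocesses run concurrently.
    with ThreadPoolExecutor(max_workers=2) as pool:
        lp_future = pool.submit(run_line_profile_tests, **kwargs)
        bm_future = pool.submit(run_benchmarking_tests, **kwargs)
        return lp_future.result(), bm_future.result()

async def run_both_async(async_run_line_profile_tests, async_run_benchmarking_tests, **kwargs):
    # New shape: the async runners are awaited together; asyncio.gather()
    # returns results in argument order, so unpacking stays (lp, bm).
    return await asyncio.gather(
        async_run_line_profile_tests(**kwargs),
        async_run_benchmarking_tests(**kwargs),
    )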


@ -18,7 +18,7 @@ from codeflash_core import (
from ..context.pipeline import get_code_optimization_context
from ..testing._parse_results import parse_test_results
from ..testing._test_runner import run_benchmarking_tests
from ..testing._test_runner import async_run_benchmarking_tests
if TYPE_CHECKING:
from pathlib import Path
@ -37,7 +37,7 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
def collect_baseline_async_metrics( # noqa: PLR0913
async def collect_baseline_async_metrics( # noqa: PLR0913
baseline: OriginalCodeBaseline,
func: FunctionToOptimize,
code_context: CodeOptimizationContext,
@ -62,7 +62,7 @@ def collect_baseline_async_metrics( # noqa: PLR0913
async_throughput,
)
concurrency_metrics = run_concurrency_benchmark(
concurrency_metrics = await run_concurrency_benchmark(
func=func,
code_context=code_context,
test_env=test_env,
@ -86,7 +86,7 @@ def collect_baseline_async_metrics( # noqa: PLR0913
)
def run_concurrency_benchmark(
async def run_concurrency_benchmark(
func: FunctionToOptimize,
code_context: CodeOptimizationContext,
test_env: dict[str, str],
@ -129,7 +129,7 @@ def run_concurrency_benchmark(
if test_files is None:
return None
bench_xml, bench_result = run_benchmarking_tests(
bench_xml, bench_result = await async_run_benchmarking_tests(
test_files=test_files,
test_env=test_env,
cwd=ctx.project_root,
@ -161,7 +161,7 @@ def run_concurrency_benchmark(
)
def evaluate_async_candidate( # noqa: PLR0913
async def evaluate_async_candidate( # noqa: PLR0913
cid: str,
fn_input: FunctionInput,
baseline: OriginalCodeBaseline,
@ -194,7 +194,7 @@ def evaluate_async_candidate( # noqa: PLR0913
func.function_name,
)
candidate_concurrency = run_concurrency_benchmark(
candidate_concurrency = await run_concurrency_benchmark(
func,
get_code_optimization_context(func, ctx.project_root),
build_test_env(fn_input, ctx.project_root, ctx.test_cfg),


@ -159,7 +159,7 @@ async def run_tests_and_benchmark( # noqa: PLR0913
optimized_runtime = 0
if is_async and evaluate_async_fn is not None:
return evaluate_async_fn(
return await evaluate_async_fn(
cid,
fn_input,
baseline,
@ -427,7 +427,7 @@ def log_evaluation_results(
if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Awaitable, Callable
_EvalAsyncFn = Callable[
[
@ -438,5 +438,5 @@ if TYPE_CHECKING:
TestResults,
int, # optimized_runtime
],
float | None,
Awaitable[float | None],
]
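
The alias change matters because calling an async def returns a coroutine, and a coroutine is an Awaitable; a tiny illustration of the subtyping (hypothetical one-argument alias, not the pipeline's real signature):

from collections.abc import Awaitable, Callable

# Hypothetical alias, just to show why the return type must be wrapped in Awaitable.
EvalFn = Callable[[str], Awaitable[float | None]]

async def evaluate(candidate_id: str) -> float | None:
    return 1.0

fn: EvalFn = evaluate  # accepted: evaluate(...) returns a coroutine awaiting to float | None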


@ -32,7 +32,7 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
def generate_candidates(
def generate_candidates( # noqa: PLR0913
ctx: OptimizationContext,
function_trace_id: str,
fn_input: FunctionInput,
@ -99,7 +99,7 @@ def generate_candidates(
return candidates
def generate_lp_candidates( # noqa: C901, PLR0913
async def generate_lp_candidates( # noqa: C901, PLR0913
ctx: OptimizationContext,
function_trace_id: str,
test_files: TestFiles | None,
@ -127,7 +127,7 @@ def generate_lp_candidates( # noqa: C901, PLR0913
)
from ..context.models import CodeStringsMarkdown # noqa: PLC0415
from ..testing._test_runner import ( # noqa: PLC0415
run_line_profile_tests,
async_run_line_profile_tests,
)
from ._function_optimizer import ( # noqa: PLC0415
is_numerical_code,
@ -142,7 +142,7 @@ def generate_lp_candidates( # noqa: C901, PLR0913
for helper in code_context.helper_functions:
hp = _Path(helper.file_path)
if hp not in files_to_restore:
files_to_restore[hp] = hp.read_text("utf-8")
files_to_restore[hp] = hp.read_text("utf-8") # noqa: ASYNC240
baseline_lp_markdown = ""
try:
@ -154,7 +154,7 @@ def generate_lp_candidates( # noqa: C901, PLR0913
if test_files is None:
return [], ""
run_line_profile_tests(
await async_run_line_profile_tests(
test_files=test_files,
test_env=test_env,
cwd=ctx.project_root,


@ -629,7 +629,7 @@ class PythonFunctionOptimizer:
# 4. Establish baseline first so we can send runtime
# data to the AI service for better-informed candidates.
baseline = establish_original_code_baseline(
baseline = await establish_original_code_baseline(
test_files=self.test_files,
test_config=self.ctx.test_cfg,
test_env=test_env,
@ -662,7 +662,7 @@ class PythonFunctionOptimizer:
# 3a. Collect async metrics if function is async.
if func.is_async:
baseline = collect_baseline_async_metrics(
baseline = await collect_baseline_async_metrics(
baseline=baseline,
func=func,
code_context=code_context,
@ -679,7 +679,7 @@ class PythonFunctionOptimizer:
)
# 4b. Line-profiler-guided candidates.
lp_cands, lp_md = generate_lp_candidates(
lp_cands, lp_md = await generate_lp_candidates(
ctx=self.ctx,
function_trace_id=self.function_trace_id,
test_files=self.test_files,
@ -968,7 +968,7 @@ class PythonFunctionOptimizer:
"""
from ._async_bench import evaluate_async_candidate # noqa: PLC0415
def _eval( # noqa: PLR0913
async def _eval( # noqa: PLR0913
cid: str,
fn_input: FunctionInput,
baseline: OriginalCodeBaseline,
@ -976,7 +976,7 @@ class PythonFunctionOptimizer:
bench_results: TestResults,
optimized_runtime: int,
) -> float | None:
speedup, reason = evaluate_async_candidate(
speedup, reason = await evaluate_async_candidate(
cid=cid,
fn_input=fn_input,
baseline=baseline,


@ -37,7 +37,7 @@ def _build_capabilities() -> dict[str, object]:
parse_test_results,
)
from ..testing._test_runner import ( # noqa: PLC0415
run_behavioral_tests,
async_run_behavioral_tests,
)
from ..verification._verification import ( # noqa: PLC0415
compare_test_results,
@ -49,7 +49,7 @@ def _build_capabilities() -> dict[str, object]:
"discover_functions": discover_functions,
"extract_context": get_code_optimization_context,
"replace_code": replace_functions_in_file,
"run_tests": run_behavioral_tests,
"run_tests": async_run_behavioral_tests,
"parse_results": parse_test_results,
"compare_results": compare_test_results,
# Optional capabilities.
@ -78,10 +78,10 @@ def _lazy_generate_tests() -> object:
def _lazy_run_benchmarks() -> object:
"""Placeholder — actual binding happens at call site."""
from ..testing._test_runner import ( # noqa: PLC0415
run_benchmarking_tests,
async_run_benchmarking_tests,
)
return run_benchmarking_tests
return async_run_benchmarking_tests
@attrs.frozen


@ -2,6 +2,9 @@
from ._parse_results import parse_test_results
from ._test_runner import (
async_run_behavioral_tests,
async_run_benchmarking_tests,
async_run_line_profile_tests,
run_behavioral_tests,
run_benchmarking_tests,
run_line_profile_tests,
@ -22,6 +25,9 @@ __all__ = [
"TestFile",
"TestFiles",
"TestResults",
"async_run_behavioral_tests",
"async_run_benchmarking_tests",
"async_run_line_profile_tests",
"parse_test_results",
"run_behavioral_tests",
"run_benchmarking_tests",


@ -398,19 +398,16 @@ async def async_run_behavioral_tests( # noqa: PLR0913
cwd: Path,
pytest_cmd: str = "pytest",
timeout: int | None = None,
enable_coverage: bool = False, # noqa: FBT001, FBT002
rootdir: Path | None = None,
result_file_name: str = "pytest_results.xml",
) -> tuple[
Path,
subprocess.CompletedProcess[str],
None,
None,
Path | None,
Path | None,
]:
"""Async version of :func:`run_behavioral_tests`.
Coverage is never needed during parallel candidate evaluation,
so the *enable_coverage* parameter is omitted.
"""
"""Async version of :func:`run_behavioral_tests` with coverage support."""
blocklisted_plugins = [
"benchmark",
"codspeed",
@ -461,21 +458,64 @@ async def async_run_behavioral_tests( # noqa: PLR0913
pytest_test_env["PYTEST_PLUGINS"] = (
"codeflash_python.testing._pytest_plugin"
)
coverage_database_file: Path | None = None
coverage_config_file: Path | None = None
blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins]
subprocess_timeout = _subprocess_timeout(len(test_file_paths))
results = await async_execute_test_subprocess(
pytest_cmd_list
+ common_args
+ blocklist_args
+ result_args
+ test_file_paths,
cwd=cwd,
env=pytest_test_env,
timeout=subprocess_timeout,
)
if enable_coverage:
from ..analysis._coverage import ( # noqa: PLC0415
prepare_coverage_files,
)
from ..verification._baseline import ( # noqa: PLC0415
jit_disabled_env,
)
return result_file_path, results, None, None
coverage_database_file, coverage_config_file = prepare_coverage_files()
pytest_test_env.update(jit_disabled_env())
coverage_cmd = [
sys.executable,
"-m",
"coverage",
"run",
f"--rcfile={coverage_config_file.as_posix()}",
"-m",
*shlex.split(pytest_cmd),
]
cov_blocklist = [
f"-p no:{p}" for p in blocklisted_plugins if p != "cov"
]
results = await async_execute_test_subprocess(
coverage_cmd
+ common_args
+ cov_blocklist
+ result_args
+ test_file_paths,
cwd=cwd,
env=pytest_test_env,
timeout=subprocess_timeout,
)
else:
results = await async_execute_test_subprocess(
pytest_cmd_list
+ common_args
+ blocklist_args
+ result_args
+ test_file_paths,
cwd=cwd,
env=pytest_test_env,
timeout=subprocess_timeout,
)
return (
result_file_path,
results,
coverage_database_file,
coverage_config_file,
)
async def async_run_benchmarking_tests( # noqa: PLR0913
@ -550,3 +590,72 @@ async def async_run_benchmarking_tests( # noqa: PLR0913
timeout=_subprocess_timeout(len(test_file_paths)),
)
return result_file_path, results
async def async_run_line_profile_tests( # noqa: PLR0913
test_files: TestFiles,
test_env: dict[str, str],
cwd: Path,
pytest_cmd: str = "pytest",
timeout: int | None = None,
result_file_name: str = "pytest_results.xml",
rootdir: Path | None = None,
) -> tuple[Path, subprocess.CompletedProcess[str]]:
"""Async version of :func:`run_line_profile_tests`."""
blocklisted_plugins = [
"codspeed",
"cov",
"benchmark",
"profiling",
"xdist",
"sugar",
]
pytest_cmd_list = [
sys.executable,
"-m",
*shlex.split(pytest_cmd),
]
test_file_paths = list(
{
str(tf.benchmarking_file_path)
for tf in test_files.test_files
if tf.benchmarking_file_path
}
)
pytest_args = [
*_base_pytest_args(rootdir, cwd),
"--codeflash_loops_scope=session",
"--codeflash_min_loops=1",
"--codeflash_max_loops=1",
"--codeflash_seconds=10.0",
]
if timeout is not None:
pytest_args.append(f"--timeout={timeout}")
result_file_path = get_run_tmp_file(
Path(result_file_name),
)
result_args = [
f"--junitxml={result_file_path.as_posix()}",
"-o",
"junit_logging=all",
]
lp_test_env = test_env.copy()
lp_test_env["PYTEST_PLUGINS"] = "codeflash_python.testing._pytest_plugin"
lp_test_env["LINE_PROFILE"] = "1"
blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins]
results = await async_execute_test_subprocess(
pytest_cmd_list
+ pytest_args
+ blocklist_args
+ result_args
+ test_file_paths,
cwd=cwd,
env=lp_test_env,
timeout=_subprocess_timeout(len(test_file_paths)),
)
return result_file_path, results
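
A caller-side sketch of the coverage path added above. This assumes test_files, test_env, and project_root already exist in the caller; the module path matches the patch targets used in the tests below.

from codeflash_python.testing._test_runner import async_run_behavioral_tests

async def baseline_behavioral_run(test_files, test_env, project_root):
    # With enable_coverage=True the run goes through ``coverage run -m pytest``
    # and the last two tuple members point at the coverage database and rcfile;
    # with the default enable_coverage=False they come back as None.
    xml_path, completed, cov_db, cov_cfg = await async_run_behavioral_tests(
        test_files=test_files,
        test_env=test_env,
        cwd=project_root,
        enable_coverage=True,
    )
    return xml_path, completed, cov_db, cov_cfg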


@ -8,7 +8,6 @@ from __future__ import annotations
import ast
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@ -183,7 +182,7 @@ def jit_disabled_env() -> dict[str, str]:
}
def establish_original_code_baseline( # noqa: PLR0913
async def establish_original_code_baseline( # noqa: PLR0913
test_files: TestFiles,
test_config: TestConfig,
test_env: dict[str, str],
@ -206,9 +205,9 @@ def establish_original_code_baseline( # noqa: PLR0913
from ..test_discovery.models import TestType # noqa: PLC0415
from ..testing._parse_results import parse_test_results # noqa: PLC0415
from ..testing._test_runner import ( # noqa: PLC0415
run_behavioral_tests,
run_benchmarking_tests,
run_line_profile_tests,
async_run_behavioral_tests,
async_run_benchmarking_tests,
async_run_line_profile_tests,
)
from .models import OriginalCodeBaseline # noqa: PLC0415
@ -218,14 +217,17 @@ def establish_original_code_baseline( # noqa: PLR0913
if precomputed_behavioral is not None:
behavioral_results = precomputed_behavioral
else:
xml_path, run_result, coverage_database_file, coverage_config_file = (
run_behavioral_tests(
test_files=test_files,
test_env=test_env,
cwd=cwd,
pytest_cmd=test_config.pytest_cmd,
enable_coverage=True,
)
(
xml_path,
run_result,
coverage_database_file,
coverage_config_file,
) = await async_run_behavioral_tests(
test_files=test_files,
test_env=test_env,
cwd=cwd,
pytest_cmd=test_config.pytest_cmd,
enable_coverage=True,
)
behavioral_results = parse_test_results(
test_xml_path=xml_path,
@ -255,26 +257,27 @@ def establish_original_code_baseline( # noqa: PLR0913
# are independent subprocesses that don't share state.
originals = add_async_perf_decorator(async_function, cwd)
try:
with ThreadPoolExecutor(max_workers=2) as pool:
lp_future = pool.submit(
run_line_profile_tests,
import asyncio # noqa: PLC0415
(
(lp_xml_path, lp_run_result),
(bm_xml_path, bm_run_result),
) = await asyncio.gather(
async_run_line_profile_tests(
test_files=test_files,
test_env=test_env,
cwd=cwd,
pytest_cmd=test_config.pytest_cmd,
result_file_name="pytest_lp_results.xml",
)
bm_future = pool.submit(
run_benchmarking_tests,
),
async_run_benchmarking_tests(
test_files=test_files,
test_env=test_env,
cwd=cwd,
pytest_cmd=test_config.pytest_cmd,
result_file_name="pytest_bm_results.xml",
)
lp_xml_path, lp_run_result = lp_future.result()
bm_xml_path, bm_run_result = bm_future.result()
),
)
finally:
revert_async_decorator(originals)


@ -494,10 +494,10 @@ class TestEstablishOriginalCodeBaseline:
return test_files, test_config, test_env
@patch("codeflash_python.testing._parse_results.parse_test_results")
@patch("codeflash_python.testing._test_runner.run_line_profile_tests")
@patch("codeflash_python.testing._test_runner.run_benchmarking_tests")
@patch("codeflash_python.testing._test_runner.run_behavioral_tests")
def test_successful_baseline(
@patch("codeflash_python.testing._test_runner.async_run_line_profile_tests")
@patch("codeflash_python.testing._test_runner.async_run_benchmarking_tests")
@patch("codeflash_python.testing._test_runner.async_run_behavioral_tests")
async def test_successful_baseline(
self,
mock_run_behavioral: MagicMock,
mock_run_benchmarking: MagicMock,
@ -536,7 +536,7 @@ class TestEstablishOriginalCodeBaseline:
benchmarking_results,
]
result = establish_original_code_baseline(
result = await establish_original_code_baseline(
test_files=test_files,
test_config=test_config,
test_env=test_env,
@ -551,10 +551,10 @@ class TestEstablishOriginalCodeBaseline:
assert result.runtime > 0
@patch("codeflash_python.testing._parse_results.parse_test_results")
@patch("codeflash_python.testing._test_runner.run_line_profile_tests")
@patch("codeflash_python.testing._test_runner.run_benchmarking_tests")
@patch("codeflash_python.testing._test_runner.run_behavioral_tests")
def test_empty_behavioral_returns_none(
@patch("codeflash_python.testing._test_runner.async_run_line_profile_tests")
@patch("codeflash_python.testing._test_runner.async_run_benchmarking_tests")
@patch("codeflash_python.testing._test_runner.async_run_behavioral_tests")
async def test_empty_behavioral_returns_none(
self,
mock_run_behavioral: MagicMock,
mock_run_benchmarking: MagicMock,
@ -575,7 +575,7 @@ class TestEstablishOriginalCodeBaseline:
)
mock_parse_results.return_value = TestResults()
result = establish_original_code_baseline(
result = await establish_original_code_baseline(
test_files=test_files,
test_config=test_config,
test_env=test_env,
@ -585,10 +585,10 @@ class TestEstablishOriginalCodeBaseline:
assert result is None
@patch("codeflash_python.testing._parse_results.parse_test_results")
@patch("codeflash_python.testing._test_runner.run_line_profile_tests")
@patch("codeflash_python.testing._test_runner.run_benchmarking_tests")
@patch("codeflash_python.testing._test_runner.run_behavioral_tests")
def test_zero_benchmark_runtime_returns_none(
@patch("codeflash_python.testing._test_runner.async_run_line_profile_tests")
@patch("codeflash_python.testing._test_runner.async_run_benchmarking_tests")
@patch("codeflash_python.testing._test_runner.async_run_behavioral_tests")
async def test_zero_benchmark_runtime_returns_none(
self,
mock_run_behavioral: MagicMock,
mock_run_benchmarking: MagicMock,
@ -631,7 +631,7 @@ class TestEstablishOriginalCodeBaseline:
zero_benchmarking,
]
result = establish_original_code_baseline(
result = await establish_original_code_baseline(
test_files=test_files,
test_config=test_config,
test_env=test_env,
@ -641,10 +641,10 @@ class TestEstablishOriginalCodeBaseline:
assert result is None
@patch("codeflash_python.testing._parse_results.parse_test_results")
@patch("codeflash_python.testing._test_runner.run_line_profile_tests")
@patch("codeflash_python.testing._test_runner.run_benchmarking_tests")
@patch("codeflash_python.testing._test_runner.run_behavioral_tests")
def test_precomputed_behavioral_skips_behavioral_run(
@patch("codeflash_python.testing._test_runner.async_run_line_profile_tests")
@patch("codeflash_python.testing._test_runner.async_run_benchmarking_tests")
@patch("codeflash_python.testing._test_runner.async_run_behavioral_tests")
async def test_precomputed_behavioral_skips_behavioral_run(
self,
mock_run_behavioral: MagicMock,
mock_run_benchmarking: MagicMock,
@ -676,7 +676,7 @@ class TestEstablishOriginalCodeBaseline:
benchmarking_results,
]
result = establish_original_code_baseline(
result = await establish_original_code_baseline(
test_files=test_files,
test_config=test_config,
test_env=test_env,
@ -689,10 +689,10 @@ class TestEstablishOriginalCodeBaseline:
assert precomputed is result.behavior_test_results
@patch("codeflash_python.testing._parse_results.parse_test_results")
@patch("codeflash_python.testing._test_runner.run_line_profile_tests")
@patch("codeflash_python.testing._test_runner.run_benchmarking_tests")
@patch("codeflash_python.testing._test_runner.run_behavioral_tests")
def test_failed_regression_in_functions_to_remove(
@patch("codeflash_python.testing._test_runner.async_run_line_profile_tests")
@patch("codeflash_python.testing._test_runner.async_run_benchmarking_tests")
@patch("codeflash_python.testing._test_runner.async_run_behavioral_tests")
async def test_failed_regression_in_functions_to_remove(
self,
mock_run_behavioral: MagicMock,
mock_run_benchmarking: MagicMock,
@ -770,7 +770,7 @@ class TestEstablishOriginalCodeBaseline:
benchmarking_results,
]
result = establish_original_code_baseline(
result = await establish_original_code_baseline(
test_files=test_files,
test_config=test_config,
test_env=test_env,