Merge pull request #1518 from codeflash-ai/proper-async

refactor: inline async decorators to remove codeflash import dependency
This commit is contained in:
Kevin Turcios 2026-02-18 10:11:22 +00:00 committed by GitHub
commit 9d23a0ed1c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 398 additions and 222 deletions

View file

@ -1497,73 +1497,207 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
return False
class AsyncDecoratorImportAdder(cst.CSTTransformer):
"""Transformer that adds the import for async decorators."""
ASYNC_HELPER_INLINE_CODE = """import asyncio
import gc
import os
import sqlite3
import time
from functools import wraps
from pathlib import Path
from tempfile import TemporaryDirectory
def __init__(self, mode: TestingMode = TestingMode.BEHAVIOR) -> None:
self.mode = mode
self.has_import = False
import dill as pickle
def _get_decorator_name(self) -> str:
"""Get the decorator name based on the testing mode."""
if self.mode == TestingMode.BEHAVIOR:
return "codeflash_behavior_async"
if self.mode == TestingMode.CONCURRENCY:
return "codeflash_concurrency_async"
return "codeflash_performance_async"
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
# Check if the async decorator import is already present
if (
isinstance(node.module, cst.Attribute)
and isinstance(node.module.value, cst.Attribute)
and isinstance(node.module.value.value, cst.Name)
and node.module.value.value.value == "codeflash"
and node.module.value.attr.value == "code_utils"
and node.module.attr.value == "codeflash_wrap_decorator"
and not isinstance(node.names, cst.ImportStar)
):
decorator_name = self._get_decorator_name()
for import_alias in node.names:
if import_alias.name.value == decorator_name:
self.has_import = True
def get_run_tmp_file(file_path):
if not hasattr(get_run_tmp_file, "tmpdir"):
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
return Path(get_run_tmp_file.tmpdir.name) / file_path
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
# If the import is already there, don't add it again
if self.has_import:
return updated_node
# Choose import based on mode
decorator_name = self._get_decorator_name()
def extract_test_context_from_env():
test_module = os.environ["CODEFLASH_TEST_MODULE"]
test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
if test_module and test_function:
return (test_module, test_class if test_class else None, test_function)
raise RuntimeError(
"Test context environment variables not set - ensure tests are run through codeflash test runner"
)
# Parse the import statement into a CST node
import_node = cst.parse_statement(f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}")
# Add the import to the module's body
return updated_node.with_changes(body=[import_node, *list(updated_node.body)])
def codeflash_behavior_async(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
loop = asyncio.get_running_loop()
function_name = func.__name__
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
test_module_name, test_class_name, test_name = extract_test_context_from_env()
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
if not hasattr(async_wrapper, "index"):
async_wrapper.index = {}
if test_id in async_wrapper.index:
async_wrapper.index[test_id] += 1
else:
async_wrapper.index[test_id] = 0
codeflash_test_index = async_wrapper.index[test_id]
invocation_id = f"{line_id}_{codeflash_test_index}"
class_prefix = (test_class_name + ".") if test_class_name else ""
test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}"
print(f"!$######{test_stdout_tag}######$!")
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
codeflash_con = sqlite3.connect(db_path)
codeflash_cur = codeflash_con.cursor()
codeflash_cur.execute(
"CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, "
"test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, "
"runtime INTEGER, return_value BLOB, verification_type TEXT)"
)
exception = None
counter = loop.time()
gc.disable()
try:
ret = func(*args, **kwargs)
counter = loop.time()
return_value = await ret
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
except Exception as e:
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
exception = e
finally:
gc.enable()
print(f"!######{test_stdout_tag}######!")
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps((args, kwargs, return_value))
codeflash_cur.execute(
"INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
test_module_name,
test_class_name,
test_name,
function_name,
loop_index,
invocation_id,
codeflash_duration,
pickled_return_value,
"function_call",
),
)
codeflash_con.commit()
codeflash_con.close()
if exception:
raise exception
return return_value
return async_wrapper
def codeflash_performance_async(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
loop = asyncio.get_running_loop()
function_name = func.__name__
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
test_module_name, test_class_name, test_name = extract_test_context_from_env()
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
if not hasattr(async_wrapper, "index"):
async_wrapper.index = {}
if test_id in async_wrapper.index:
async_wrapper.index[test_id] += 1
else:
async_wrapper.index[test_id] = 0
codeflash_test_index = async_wrapper.index[test_id]
invocation_id = f"{line_id}_{codeflash_test_index}"
class_prefix = (test_class_name + ".") if test_class_name else ""
test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}"
print(f"!$######{test_stdout_tag}######$!")
exception = None
counter = loop.time()
gc.disable()
try:
ret = func(*args, **kwargs)
counter = loop.time()
return_value = await ret
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
except Exception as e:
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
exception = e
finally:
gc.enable()
print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
if exception:
raise exception
return return_value
return async_wrapper
def codeflash_concurrency_async(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
function_name = func.__name__
concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10"))
test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "")
test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "")
loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0")
gc.disable()
try:
seq_start = time.perf_counter_ns()
for _ in range(concurrency_factor):
result = await func(*args, **kwargs)
sequential_time = time.perf_counter_ns() - seq_start
finally:
gc.enable()
gc.disable()
try:
conc_start = time.perf_counter_ns()
tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
await asyncio.gather(*tasks)
concurrent_time = time.perf_counter_ns() - conc_start
finally:
gc.enable()
tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}"
print(f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!")
return result
return async_wrapper
"""
ASYNC_HELPER_FILENAME = "codeflash_async_wrapper.py"
def get_decorator_name_for_mode(mode: TestingMode) -> str:
if mode == TestingMode.BEHAVIOR:
return "codeflash_behavior_async"
if mode == TestingMode.CONCURRENCY:
return "codeflash_concurrency_async"
return "codeflash_performance_async"
def write_async_helper_file(target_dir: Path) -> Path:
"""Write the async decorator helper file to the target directory."""
helper_path = target_dir / ASYNC_HELPER_FILENAME
if not helper_path.exists():
helper_path.write_text(ASYNC_HELPER_INLINE_CODE, "utf-8")
return helper_path
def add_async_decorator_to_function(
source_path: Path, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
source_path: Path,
function: FunctionToOptimize,
mode: TestingMode = TestingMode.BEHAVIOR,
project_root: Path | None = None,
) -> bool:
"""Add async decorator to an async function definition and write back to file.
Args:
----
source_path: Path to the source file to modify in-place.
function: The FunctionToOptimize object representing the target async function.
mode: The testing mode to determine which decorator to apply.
Returns:
-------
Boolean indicating whether the decorator was successfully added.
Writes a helper file containing the decorator implementation to project_root (or source directory
as fallback) and adds a standard import + decorator to the source file.
"""
if not function.is_async:
return False
try:
# Read source code
with source_path.open(encoding="utf8") as f:
source_code = f.read()
@ -1573,10 +1707,14 @@ def add_async_decorator_to_function(
decorator_transformer = AsyncDecoratorAdder(function, mode)
module = module.visit(decorator_transformer)
# Add the import if decorator was added
if decorator_transformer.added_decorator:
import_transformer = AsyncDecoratorImportAdder(mode)
module = module.visit(import_transformer)
# Write the helper file to project_root (on sys.path) or source dir as fallback
helper_dir = project_root if project_root is not None else source_path.parent
write_async_helper_file(helper_dir)
# Add the import via CST so sort_imports can place it correctly
decorator_name = get_decorator_name_for_mode(mode)
import_node = cst.parse_statement(f"from codeflash_async_wrapper import {decorator_name}")
module = module.with_changes(body=[import_node, *list(module.body)])
modified_code = sort_imports(code=module.code, float_to_top=True)
except Exception as e:

View file

@ -1897,6 +1897,7 @@ class FunctionOptimizer:
if self.args.override_fixtures:
restore_conftest(original_conftest_content)
cleanup_paths(paths_to_cleanup)
self.cleanup_async_helper_file()
return Failure(baseline_result.failure())
original_code_baseline, test_functions_to_remove = baseline_result.unwrap()
@ -1908,6 +1909,7 @@ class FunctionOptimizer:
if self.args.override_fixtures:
restore_conftest(original_conftest_content)
cleanup_paths(paths_to_cleanup)
self.cleanup_async_helper_file()
return Failure("The threshold for test confidence was not met.")
return Success(
@ -2279,6 +2281,13 @@ class FunctionOptimizer:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
self.cleanup_async_helper_file()
def cleanup_async_helper_file(self) -> None:
from codeflash.code_utils.instrument_existing_tests import ASYNC_HELPER_FILENAME
helper_path = self.project_root / ASYNC_HELPER_FILENAME
helper_path.unlink(missing_ok=True)
def establish_original_code_baseline(
self,
@ -2296,7 +2305,10 @@ class FunctionOptimizer:
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
success = add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.BEHAVIOR,
project_root=self.project_root,
)
# Instrument codeflash capture
@ -2361,7 +2373,10 @@ class FunctionOptimizer:
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.PERFORMANCE,
project_root=self.project_root,
)
try:
@ -2535,7 +2550,10 @@ class FunctionOptimizer:
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.BEHAVIOR,
project_root=self.project_root,
)
try:
@ -2611,7 +2629,10 @@ class FunctionOptimizer:
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.PERFORMANCE,
project_root=self.project_root,
)
try:
@ -2974,7 +2995,10 @@ class FunctionOptimizer:
try:
# Add concurrency decorator to the source function
add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.CONCURRENCY
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.CONCURRENCY,
project_root=self.project_root,
)
# Run the concurrency benchmark tests

View file

@ -13,7 +13,7 @@ def run_test(expected_improvement_pct: int) -> bool:
CoverageExpectation(
function_name="retry_with_backoff",
expected_coverage=100.0,
expected_lines=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
expected_lines=[9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
)
],
)

View file

@ -8,7 +8,9 @@ from pathlib import Path
import pytest
from codeflash.code_utils.instrument_existing_tests import (
ASYNC_HELPER_FILENAME,
add_async_decorator_to_function,
get_decorator_name_for_mode,
inject_profiling_into_existing_test,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
@ -55,16 +57,23 @@ async def test_async_sort():
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
# For async functions, instrument the source module directly with decorators
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
source_success = add_async_decorator_to_function(
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
)
assert source_success
# Verify the file was modified
# Verify the file was modified with exact expected output
instrumented_source = fto_path.read_text("utf-8")
assert (
'''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_behavior_async\n\n\n@codeflash_behavior_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n'''
in instrumented_source
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
decorated_original = original_code.replace(
"async def async_sorter", f"@{decorator_name}\nasync def async_sorter"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_original}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert instrumented_source.strip() == expected.strip()
# Add codeflash capture
instrument_codeflash_capture(func, {}, tests_root)
@ -142,6 +151,9 @@ async def test_async_sort():
test_path.unlink()
if test_path_perf.exists():
test_path_perf.unlink()
helper_path = project_root_path / ASYNC_HELPER_FILENAME
if helper_path.exists():
helper_path.unlink()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
@ -182,7 +194,9 @@ async def test_async_class_sort():
is_async=True,
)
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
source_success = add_async_decorator_to_function(
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
)
assert source_success
@ -264,6 +278,9 @@ async def test_async_class_sort():
test_path.unlink()
if test_path_perf.exists():
test_path_perf.unlink()
helper_path = project_root_path / ASYNC_HELPER_FILENAME
if helper_path.exists():
helper_path.unlink()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
@ -294,16 +311,23 @@ async def test_async_perf():
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
# Instrument the source module with async performance decorators
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.PERFORMANCE)
source_success = add_async_decorator_to_function(
fto_path, func, TestingMode.PERFORMANCE, project_root=project_root_path
)
assert source_success
# Verify the file was modified
instrumented_source = fto_path.read_text("utf-8")
assert (
instrumented_source
== '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_performance_async\n\n\n@codeflash_performance_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n'''
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE)
decorated_original = original_code.replace(
"async def async_sorter", f"@{decorator_name}\nasync def async_sorter"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_original}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert instrumented_source.strip() == expected.strip()
instrument_codeflash_capture(func, {}, tests_root)
@ -359,6 +383,9 @@ async def test_async_perf():
# Clean up test files
if test_path.exists():
test_path.unlink()
helper_path = project_root_path / ASYNC_HELPER_FILENAME
if helper_path.exists():
helper_path.unlink()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
@ -404,68 +431,24 @@ async def async_error_function(lst):
function_name="async_error_function", parents=[], file_path=Path(fto_path), is_async=True
)
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
source_success = add_async_decorator_to_function(
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
)
assert source_success
# Verify the file was modified
instrumented_source = fto_path.read_text("utf-8")
expected_instrumented_source = """import asyncio
from typing import List, Union
from codeflash.code_utils.formatter import sort_imports
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
async def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:
\"\"\"
Async bubble sort implementation for testing.
\"\"\"
print("codeflash stdout: Async sorting list")
await asyncio.sleep(0.01)
n = len(lst)
for i in range(n):
for j in range(0, n - i - 1):
if lst[j] > lst[j + 1]:
lst[j], lst[j + 1] = lst[j + 1], lst[j]
result = lst.copy()
print(f"result: {result}")
return result
class AsyncBubbleSorter:
\"\"\"Class with async sorting method for testing.\"\"\"
async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:
\"\"\"
Async bubble sort implementation within a class.
\"\"\"
print("codeflash stdout: AsyncBubbleSorter.sorter() called")
# Add some async delay
await asyncio.sleep(0.005)
n = len(lst)
for i in range(n):
for j in range(0, n - i - 1):
if lst[j] > lst[j + 1]:
lst[j], lst[j + 1] = lst[j + 1], lst[j]
result = lst.copy()
return result
@codeflash_behavior_async
async def async_error_function(lst):
\"\"\"Async function that raises an error for testing.\"\"\"
await asyncio.sleep(0.001) # Small delay
raise ValueError("Test error")
"""
assert expected_instrumented_source == instrumented_source
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
decorated_modified = modified_code.replace(
"async def async_error_function", f"@{decorator_name}\nasync def async_error_function"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_modified}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert instrumented_source.strip() == expected.strip()
instrument_codeflash_capture(func, {}, tests_root)
opt = Optimizer(
@ -526,6 +509,9 @@ async def async_error_function(lst):
test_path.unlink()
if test_path_perf.exists():
test_path_perf.unlink()
helper_path = project_root_path / ASYNC_HELPER_FILENAME
if helper_path.exists():
helper_path.unlink()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
@ -563,7 +549,9 @@ async def test_async_multi():
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
source_success = add_async_decorator_to_function(
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
)
assert source_success
instrument_codeflash_capture(func, {}, tests_root)
@ -636,6 +624,9 @@ async def test_async_multi():
test_path.unlink()
if test_path_perf.exists():
test_path_perf.unlink()
helper_path = project_root_path / ASYNC_HELPER_FILENAME
if helper_path.exists():
helper_path.unlink()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
@ -678,7 +669,9 @@ async def test_async_edge_cases():
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
source_success = add_async_decorator_to_function(
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
)
assert source_success
instrument_codeflash_capture(func, {}, tests_root)
@ -753,6 +746,9 @@ async def test_async_edge_cases():
test_path.unlink()
if test_path_perf.exists():
test_path_perf.unlink()
helper_path = project_root_path / ASYNC_HELPER_FILENAME
if helper_path.exists():
helper_path.unlink()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
@ -987,7 +983,9 @@ async def test_mixed_sorting():
function_name="async_merge_sort", parents=[], file_path=Path(mixed_fto_path), is_async=True
)
source_success = add_async_decorator_to_function(mixed_fto_path, async_func, TestingMode.BEHAVIOR)
source_success = add_async_decorator_to_function(
mixed_fto_path, async_func, TestingMode.BEHAVIOR, project_root=project_root_path
)
assert source_success
@ -1060,3 +1058,6 @@ async def test_mixed_sorting():
test_path.unlink()
if test_path_perf.exists():
test_path_perf.unlink()
helper_path = project_root_path / ASYNC_HELPER_FILENAME
if helper_path.exists():
helper_path.unlink()

View file

@ -6,7 +6,9 @@ from pathlib import Path
import pytest
from codeflash.code_utils.instrument_existing_tests import (
ASYNC_HELPER_FILENAME,
add_async_decorator_to_function,
get_decorator_name_for_mode,
inject_profiling_into_existing_test,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
@ -57,20 +59,6 @@ def test_async_decorator_application_behavior_mode(temp_dir):
async_function_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
@codeflash_behavior_async
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
@ -86,7 +74,16 @@ async def async_function(x: int, y: int) -> int:
assert decorator_added
modified_code = test_file.read_text()
assert modified_code.strip() == expected_decorated_code.strip()
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
code_with_decorator = async_function_code.replace(
"async def async_function", f"@{decorator_name}\nasync def async_function"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert modified_code.strip() == expected.strip()
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
@ -94,20 +91,6 @@ def test_async_decorator_application_performance_mode(temp_dir):
async_function_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_performance_async
@codeflash_performance_async
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
@ -123,7 +106,16 @@ async def async_function(x: int, y: int) -> int:
assert decorator_added
modified_code = test_file.read_text()
assert modified_code.strip() == expected_decorated_code.strip()
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE)
code_with_decorator = async_function_code.replace(
"async def async_function", f"@{decorator_name}\nasync def async_function"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert modified_code.strip() == expected.strip()
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
@ -132,20 +124,6 @@ def test_async_decorator_application_concurrency_mode(temp_dir):
async_function_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_concurrency_async
@codeflash_concurrency_async
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
@ -161,7 +139,16 @@ async def async_function(x: int, y: int) -> int:
assert decorator_added
modified_code = test_file.read_text()
assert modified_code.strip() == expected_decorated_code.strip()
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.CONCURRENCY)
code_with_decorator = async_function_code.replace(
"async def async_function", f"@{decorator_name}\nasync def async_function"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert modified_code.strip() == expected.strip()
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
@ -182,27 +169,6 @@ class Calculator:
return a - b
'''
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
class Calculator:
"""Test class with async methods."""
@codeflash_behavior_async
async def async_method(self, a: int, b: int) -> int:
"""Async method in class."""
await asyncio.sleep(0.005)
return a ** b
def sync_method(self, a: int, b: int) -> int:
"""Sync method in class."""
return a - b
'''
test_file = temp_dir / "test_async.py"
test_file.write_text(async_class_code)
@ -217,11 +183,21 @@ class Calculator:
assert decorator_added
modified_code = test_file.read_text()
assert modified_code.strip() == expected_decorated_code.strip()
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
code_with_decorator = async_class_code.replace(
" async def async_method", f" @{decorator_name}\n async def async_method"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert modified_code.strip() == expected.strip()
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_decorator_no_duplicate_application(temp_dir):
# Case 1: Old-style import already present — injector should detect and skip
already_decorated_code = '''
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async
import asyncio
@ -243,6 +219,30 @@ async def async_function(x: int, y: int) -> int:
# Should not add duplicate decorator
assert not decorator_added
# Case 2: Inline definition already present — injector should detect and skip
already_inline_code = '''
import asyncio
def codeflash_behavior_async(func):
return func
@codeflash_behavior_async
async def async_function(x: int, y: int) -> int:
"""Already decorated async function."""
await asyncio.sleep(0.01)
return x * y
'''
test_file2 = temp_dir / "test_async2.py"
test_file2.write_text(already_inline_code)
func2 = FunctionToOptimize(function_name="async_function", file_path=test_file2, parents=[], is_async=True)
decorator_added2 = add_async_decorator_to_function(test_file2, func2, TestingMode.BEHAVIOR)
# Should not add duplicate decorator
assert not decorator_added2
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_inject_profiling_async_function_behavior_mode(temp_dir):
@ -285,11 +285,18 @@ async def test_async_function():
assert source_success is True
# Verify the file was modified
# Verify the file was modified with exact expected output
instrumented_source = source_file.read_text()
assert "@codeflash_behavior_async" in instrumented_source
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
assert "codeflash_behavior_async" in instrumented_source
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
code_with_decorator = source_module_code.replace(
"async def async_function", f"@{decorator_name}\nasync def async_function"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert instrumented_source.strip() == expected.strip()
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
success, instrumented_test_code = inject_profiling_into_existing_test(
test_file, [CodePosition(8, 18), CodePosition(11, 19)], func, temp_dir, mode=TestingMode.BEHAVIOR
@ -340,12 +347,18 @@ async def test_async_function():
assert source_success is True
# Verify the file was modified
# Verify the file was modified with exact expected output
instrumented_source = source_file.read_text()
assert "@codeflash_performance_async" in instrumented_source
# Check for the import with line continuation formatting
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
assert "codeflash_performance_async" in instrumented_source
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE)
code_with_decorator = source_module_code.replace(
"async def async_function", f"@{decorator_name}\nasync def async_function"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert instrumented_source.strip() == expected.strip()
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
# Now test the full pipeline with source module path
success, instrumented_test_code = inject_profiling_into_existing_test(
@ -406,11 +419,16 @@ async def test_mixed_functions():
# Verify the file was modified
instrumented_source = source_file.read_text()
assert "@codeflash_behavior_async" in instrumented_source
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
assert "codeflash_behavior_async" in instrumented_source
# Sync function should remain unchanged
assert "def sync_function(x: int, y: int) -> int:" in instrumented_source
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
code_with_decorator = source_module_code.replace(
"async def async_function", f"@{decorator_name}\nasync def async_function"
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert instrumented_source.strip() == expected.strip()
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
success, instrumented_test_code = inject_profiling_into_existing_test(
test_file, [CodePosition(8, 18), CodePosition(11, 19)], async_func, temp_dir, mode=TestingMode.BEHAVIOR
@ -446,24 +464,19 @@ class OuterClass:
decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.BEHAVIOR)
expected_output = """import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
class OuterClass:
class InnerClass:
@codeflash_behavior_async
async def nested_async_method(self, x: int) -> int:
\"\"\"Nested async method.\"\"\"
await asyncio.sleep(0.001)
return x * 2
"""
assert decorator_added
modified_code = test_file.read_text()
assert modified_code.strip() == expected_output.strip()
from codeflash.code_utils.formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
code_with_decorator = nested_async_code.replace(
" async def nested_async_method",
f" @{decorator_name}\n async def nested_async_method",
)
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
expected = sort_imports(code=code_with_import, float_to_top=True)
assert modified_code.strip() == expected.strip()
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")