Merge pull request #1518 from codeflash-ai/proper-async
refactor: inline async decorators to remove codeflash import dependency
This commit is contained in:
commit
9d23a0ed1c
5 changed files with 398 additions and 222 deletions
|
|
@ -1497,73 +1497,207 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
|
|||
return False
|
||||
|
||||
|
||||
class AsyncDecoratorImportAdder(cst.CSTTransformer):
|
||||
"""Transformer that adds the import for async decorators."""
|
||||
ASYNC_HELPER_INLINE_CODE = """import asyncio
|
||||
import gc
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
def __init__(self, mode: TestingMode = TestingMode.BEHAVIOR) -> None:
|
||||
self.mode = mode
|
||||
self.has_import = False
|
||||
import dill as pickle
|
||||
|
||||
def _get_decorator_name(self) -> str:
|
||||
"""Get the decorator name based on the testing mode."""
|
||||
if self.mode == TestingMode.BEHAVIOR:
|
||||
return "codeflash_behavior_async"
|
||||
if self.mode == TestingMode.CONCURRENCY:
|
||||
return "codeflash_concurrency_async"
|
||||
return "codeflash_performance_async"
|
||||
|
||||
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
|
||||
# Check if the async decorator import is already present
|
||||
if (
|
||||
isinstance(node.module, cst.Attribute)
|
||||
and isinstance(node.module.value, cst.Attribute)
|
||||
and isinstance(node.module.value.value, cst.Name)
|
||||
and node.module.value.value.value == "codeflash"
|
||||
and node.module.value.attr.value == "code_utils"
|
||||
and node.module.attr.value == "codeflash_wrap_decorator"
|
||||
and not isinstance(node.names, cst.ImportStar)
|
||||
):
|
||||
decorator_name = self._get_decorator_name()
|
||||
for import_alias in node.names:
|
||||
if import_alias.name.value == decorator_name:
|
||||
self.has_import = True
|
||||
def get_run_tmp_file(file_path):
|
||||
if not hasattr(get_run_tmp_file, "tmpdir"):
|
||||
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
|
||||
return Path(get_run_tmp_file.tmpdir.name) / file_path
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
# If the import is already there, don't add it again
|
||||
if self.has_import:
|
||||
return updated_node
|
||||
|
||||
# Choose import based on mode
|
||||
decorator_name = self._get_decorator_name()
|
||||
def extract_test_context_from_env():
|
||||
test_module = os.environ["CODEFLASH_TEST_MODULE"]
|
||||
test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
|
||||
test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
|
||||
if test_module and test_function:
|
||||
return (test_module, test_class if test_class else None, test_function)
|
||||
raise RuntimeError(
|
||||
"Test context environment variables not set - ensure tests are run through codeflash test runner"
|
||||
)
|
||||
|
||||
# Parse the import statement into a CST node
|
||||
import_node = cst.parse_statement(f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}")
|
||||
|
||||
# Add the import to the module's body
|
||||
return updated_node.with_changes(body=[import_node, *list(updated_node.body)])
|
||||
def codeflash_behavior_async(func):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
loop = asyncio.get_running_loop()
|
||||
function_name = func.__name__
|
||||
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
|
||||
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
|
||||
test_module_name, test_class_name, test_name = extract_test_context_from_env()
|
||||
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
|
||||
if not hasattr(async_wrapper, "index"):
|
||||
async_wrapper.index = {}
|
||||
if test_id in async_wrapper.index:
|
||||
async_wrapper.index[test_id] += 1
|
||||
else:
|
||||
async_wrapper.index[test_id] = 0
|
||||
codeflash_test_index = async_wrapper.index[test_id]
|
||||
invocation_id = f"{line_id}_{codeflash_test_index}"
|
||||
class_prefix = (test_class_name + ".") if test_class_name else ""
|
||||
test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}"
|
||||
print(f"!$######{test_stdout_tag}######$!")
|
||||
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
|
||||
db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
|
||||
codeflash_con = sqlite3.connect(db_path)
|
||||
codeflash_cur = codeflash_con.cursor()
|
||||
codeflash_cur.execute(
|
||||
"CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, "
|
||||
"test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, "
|
||||
"runtime INTEGER, return_value BLOB, verification_type TEXT)"
|
||||
)
|
||||
exception = None
|
||||
counter = loop.time()
|
||||
gc.disable()
|
||||
try:
|
||||
ret = func(*args, **kwargs)
|
||||
counter = loop.time()
|
||||
return_value = await ret
|
||||
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
|
||||
except Exception as e:
|
||||
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
|
||||
exception = e
|
||||
finally:
|
||||
gc.enable()
|
||||
print(f"!######{test_stdout_tag}######!")
|
||||
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps((args, kwargs, return_value))
|
||||
codeflash_cur.execute(
|
||||
"INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
test_module_name,
|
||||
test_class_name,
|
||||
test_name,
|
||||
function_name,
|
||||
loop_index,
|
||||
invocation_id,
|
||||
codeflash_duration,
|
||||
pickled_return_value,
|
||||
"function_call",
|
||||
),
|
||||
)
|
||||
codeflash_con.commit()
|
||||
codeflash_con.close()
|
||||
if exception:
|
||||
raise exception
|
||||
return return_value
|
||||
return async_wrapper
|
||||
|
||||
|
||||
def codeflash_performance_async(func):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
loop = asyncio.get_running_loop()
|
||||
function_name = func.__name__
|
||||
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
|
||||
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
|
||||
test_module_name, test_class_name, test_name = extract_test_context_from_env()
|
||||
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
|
||||
if not hasattr(async_wrapper, "index"):
|
||||
async_wrapper.index = {}
|
||||
if test_id in async_wrapper.index:
|
||||
async_wrapper.index[test_id] += 1
|
||||
else:
|
||||
async_wrapper.index[test_id] = 0
|
||||
codeflash_test_index = async_wrapper.index[test_id]
|
||||
invocation_id = f"{line_id}_{codeflash_test_index}"
|
||||
class_prefix = (test_class_name + ".") if test_class_name else ""
|
||||
test_stdout_tag = f"{test_module_name}:{class_prefix}{test_name}:{function_name}:{loop_index}:{invocation_id}"
|
||||
print(f"!$######{test_stdout_tag}######$!")
|
||||
exception = None
|
||||
counter = loop.time()
|
||||
gc.disable()
|
||||
try:
|
||||
ret = func(*args, **kwargs)
|
||||
counter = loop.time()
|
||||
return_value = await ret
|
||||
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
|
||||
except Exception as e:
|
||||
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
|
||||
exception = e
|
||||
finally:
|
||||
gc.enable()
|
||||
print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
|
||||
if exception:
|
||||
raise exception
|
||||
return return_value
|
||||
return async_wrapper
|
||||
|
||||
|
||||
def codeflash_concurrency_async(func):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
function_name = func.__name__
|
||||
concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10"))
|
||||
test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
|
||||
test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "")
|
||||
test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "")
|
||||
loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0")
|
||||
gc.disable()
|
||||
try:
|
||||
seq_start = time.perf_counter_ns()
|
||||
for _ in range(concurrency_factor):
|
||||
result = await func(*args, **kwargs)
|
||||
sequential_time = time.perf_counter_ns() - seq_start
|
||||
finally:
|
||||
gc.enable()
|
||||
gc.disable()
|
||||
try:
|
||||
conc_start = time.perf_counter_ns()
|
||||
tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
|
||||
await asyncio.gather(*tasks)
|
||||
concurrent_time = time.perf_counter_ns() - conc_start
|
||||
finally:
|
||||
gc.enable()
|
||||
tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}"
|
||||
print(f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!")
|
||||
return result
|
||||
return async_wrapper
|
||||
"""
|
||||
|
||||
ASYNC_HELPER_FILENAME = "codeflash_async_wrapper.py"
|
||||
|
||||
|
||||
def get_decorator_name_for_mode(mode: TestingMode) -> str:
|
||||
if mode == TestingMode.BEHAVIOR:
|
||||
return "codeflash_behavior_async"
|
||||
if mode == TestingMode.CONCURRENCY:
|
||||
return "codeflash_concurrency_async"
|
||||
return "codeflash_performance_async"
|
||||
|
||||
|
||||
def write_async_helper_file(target_dir: Path) -> Path:
|
||||
"""Write the async decorator helper file to the target directory."""
|
||||
helper_path = target_dir / ASYNC_HELPER_FILENAME
|
||||
if not helper_path.exists():
|
||||
helper_path.write_text(ASYNC_HELPER_INLINE_CODE, "utf-8")
|
||||
return helper_path
|
||||
|
||||
|
||||
def add_async_decorator_to_function(
|
||||
source_path: Path, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
|
||||
source_path: Path,
|
||||
function: FunctionToOptimize,
|
||||
mode: TestingMode = TestingMode.BEHAVIOR,
|
||||
project_root: Path | None = None,
|
||||
) -> bool:
|
||||
"""Add async decorator to an async function definition and write back to file.
|
||||
|
||||
Args:
|
||||
----
|
||||
source_path: Path to the source file to modify in-place.
|
||||
function: The FunctionToOptimize object representing the target async function.
|
||||
mode: The testing mode to determine which decorator to apply.
|
||||
|
||||
Returns:
|
||||
-------
|
||||
Boolean indicating whether the decorator was successfully added.
|
||||
Writes a helper file containing the decorator implementation to project_root (or source directory
|
||||
as fallback) and adds a standard import + decorator to the source file.
|
||||
|
||||
"""
|
||||
if not function.is_async:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Read source code
|
||||
with source_path.open(encoding="utf8") as f:
|
||||
source_code = f.read()
|
||||
|
||||
|
|
@ -1573,10 +1707,14 @@ def add_async_decorator_to_function(
|
|||
decorator_transformer = AsyncDecoratorAdder(function, mode)
|
||||
module = module.visit(decorator_transformer)
|
||||
|
||||
# Add the import if decorator was added
|
||||
if decorator_transformer.added_decorator:
|
||||
import_transformer = AsyncDecoratorImportAdder(mode)
|
||||
module = module.visit(import_transformer)
|
||||
# Write the helper file to project_root (on sys.path) or source dir as fallback
|
||||
helper_dir = project_root if project_root is not None else source_path.parent
|
||||
write_async_helper_file(helper_dir)
|
||||
# Add the import via CST so sort_imports can place it correctly
|
||||
decorator_name = get_decorator_name_for_mode(mode)
|
||||
import_node = cst.parse_statement(f"from codeflash_async_wrapper import {decorator_name}")
|
||||
module = module.with_changes(body=[import_node, *list(module.body)])
|
||||
|
||||
modified_code = sort_imports(code=module.code, float_to_top=True)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1897,6 +1897,7 @@ class FunctionOptimizer:
|
|||
if self.args.override_fixtures:
|
||||
restore_conftest(original_conftest_content)
|
||||
cleanup_paths(paths_to_cleanup)
|
||||
self.cleanup_async_helper_file()
|
||||
return Failure(baseline_result.failure())
|
||||
|
||||
original_code_baseline, test_functions_to_remove = baseline_result.unwrap()
|
||||
|
|
@ -1908,6 +1909,7 @@ class FunctionOptimizer:
|
|||
if self.args.override_fixtures:
|
||||
restore_conftest(original_conftest_content)
|
||||
cleanup_paths(paths_to_cleanup)
|
||||
self.cleanup_async_helper_file()
|
||||
return Failure("The threshold for test confidence was not met.")
|
||||
|
||||
return Success(
|
||||
|
|
@ -2279,6 +2281,13 @@ class FunctionOptimizer:
|
|||
self.write_code_and_helpers(
|
||||
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
|
||||
)
|
||||
self.cleanup_async_helper_file()
|
||||
|
||||
def cleanup_async_helper_file(self) -> None:
|
||||
from codeflash.code_utils.instrument_existing_tests import ASYNC_HELPER_FILENAME
|
||||
|
||||
helper_path = self.project_root / ASYNC_HELPER_FILENAME
|
||||
helper_path.unlink(missing_ok=True)
|
||||
|
||||
def establish_original_code_baseline(
|
||||
self,
|
||||
|
|
@ -2296,7 +2305,10 @@ class FunctionOptimizer:
|
|||
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
|
||||
|
||||
success = add_async_decorator_to_function(
|
||||
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
|
||||
self.function_to_optimize.file_path,
|
||||
self.function_to_optimize,
|
||||
TestingMode.BEHAVIOR,
|
||||
project_root=self.project_root,
|
||||
)
|
||||
|
||||
# Instrument codeflash capture
|
||||
|
|
@ -2361,7 +2373,10 @@ class FunctionOptimizer:
|
|||
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
|
||||
|
||||
add_async_decorator_to_function(
|
||||
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
|
||||
self.function_to_optimize.file_path,
|
||||
self.function_to_optimize,
|
||||
TestingMode.PERFORMANCE,
|
||||
project_root=self.project_root,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -2535,7 +2550,10 @@ class FunctionOptimizer:
|
|||
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
|
||||
|
||||
add_async_decorator_to_function(
|
||||
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
|
||||
self.function_to_optimize.file_path,
|
||||
self.function_to_optimize,
|
||||
TestingMode.BEHAVIOR,
|
||||
project_root=self.project_root,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -2611,7 +2629,10 @@ class FunctionOptimizer:
|
|||
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
|
||||
|
||||
add_async_decorator_to_function(
|
||||
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
|
||||
self.function_to_optimize.file_path,
|
||||
self.function_to_optimize,
|
||||
TestingMode.PERFORMANCE,
|
||||
project_root=self.project_root,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -2974,7 +2995,10 @@ class FunctionOptimizer:
|
|||
try:
|
||||
# Add concurrency decorator to the source function
|
||||
add_async_decorator_to_function(
|
||||
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.CONCURRENCY
|
||||
self.function_to_optimize.file_path,
|
||||
self.function_to_optimize,
|
||||
TestingMode.CONCURRENCY,
|
||||
project_root=self.project_root,
|
||||
)
|
||||
|
||||
# Run the concurrency benchmark tests
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ def run_test(expected_improvement_pct: int) -> bool:
|
|||
CoverageExpectation(
|
||||
function_name="retry_with_backoff",
|
||||
expected_coverage=100.0,
|
||||
expected_lines=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
|
||||
expected_lines=[9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,9 @@ from pathlib import Path
|
|||
import pytest
|
||||
|
||||
from codeflash.code_utils.instrument_existing_tests import (
|
||||
ASYNC_HELPER_FILENAME,
|
||||
add_async_decorator_to_function,
|
||||
get_decorator_name_for_mode,
|
||||
inject_profiling_into_existing_test,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
|
@ -55,16 +57,23 @@ async def test_async_sort():
|
|||
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
|
||||
|
||||
# For async functions, instrument the source module directly with decorators
|
||||
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
|
||||
source_success = add_async_decorator_to_function(
|
||||
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
|
||||
)
|
||||
|
||||
assert source_success
|
||||
|
||||
# Verify the file was modified
|
||||
# Verify the file was modified with exact expected output
|
||||
instrumented_source = fto_path.read_text("utf-8")
|
||||
assert (
|
||||
'''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_behavior_async\n\n\n@codeflash_behavior_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n'''
|
||||
in instrumented_source
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
|
||||
decorated_original = original_code.replace(
|
||||
"async def async_sorter", f"@{decorator_name}\nasync def async_sorter"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_original}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert instrumented_source.strip() == expected.strip()
|
||||
|
||||
# Add codeflash capture
|
||||
instrument_codeflash_capture(func, {}, tests_root)
|
||||
|
|
@ -142,6 +151,9 @@ async def test_async_sort():
|
|||
test_path.unlink()
|
||||
if test_path_perf.exists():
|
||||
test_path_perf.unlink()
|
||||
helper_path = project_root_path / ASYNC_HELPER_FILENAME
|
||||
if helper_path.exists():
|
||||
helper_path.unlink()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
@ -182,7 +194,9 @@ async def test_async_class_sort():
|
|||
is_async=True,
|
||||
)
|
||||
|
||||
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
|
||||
source_success = add_async_decorator_to_function(
|
||||
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
|
||||
)
|
||||
|
||||
assert source_success
|
||||
|
||||
|
|
@ -264,6 +278,9 @@ async def test_async_class_sort():
|
|||
test_path.unlink()
|
||||
if test_path_perf.exists():
|
||||
test_path_perf.unlink()
|
||||
helper_path = project_root_path / ASYNC_HELPER_FILENAME
|
||||
if helper_path.exists():
|
||||
helper_path.unlink()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
@ -294,16 +311,23 @@ async def test_async_perf():
|
|||
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
|
||||
|
||||
# Instrument the source module with async performance decorators
|
||||
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.PERFORMANCE)
|
||||
source_success = add_async_decorator_to_function(
|
||||
fto_path, func, TestingMode.PERFORMANCE, project_root=project_root_path
|
||||
)
|
||||
|
||||
assert source_success
|
||||
|
||||
# Verify the file was modified
|
||||
instrumented_source = fto_path.read_text("utf-8")
|
||||
assert (
|
||||
instrumented_source
|
||||
== '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_performance_async\n\n\n@codeflash_performance_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n'''
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE)
|
||||
decorated_original = original_code.replace(
|
||||
"async def async_sorter", f"@{decorator_name}\nasync def async_sorter"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_original}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert instrumented_source.strip() == expected.strip()
|
||||
|
||||
instrument_codeflash_capture(func, {}, tests_root)
|
||||
|
||||
|
|
@ -359,6 +383,9 @@ async def test_async_perf():
|
|||
# Clean up test files
|
||||
if test_path.exists():
|
||||
test_path.unlink()
|
||||
helper_path = project_root_path / ASYNC_HELPER_FILENAME
|
||||
if helper_path.exists():
|
||||
helper_path.unlink()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
@ -404,68 +431,24 @@ async def async_error_function(lst):
|
|||
function_name="async_error_function", parents=[], file_path=Path(fto_path), is_async=True
|
||||
)
|
||||
|
||||
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
|
||||
source_success = add_async_decorator_to_function(
|
||||
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
|
||||
)
|
||||
|
||||
assert source_success
|
||||
|
||||
# Verify the file was modified
|
||||
instrumented_source = fto_path.read_text("utf-8")
|
||||
|
||||
expected_instrumented_source = """import asyncio
|
||||
from typing import List, Union
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_behavior_async
|
||||
|
||||
|
||||
async def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:
|
||||
\"\"\"
|
||||
Async bubble sort implementation for testing.
|
||||
\"\"\"
|
||||
print("codeflash stdout: Async sorting list")
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
n = len(lst)
|
||||
for i in range(n):
|
||||
for j in range(0, n - i - 1):
|
||||
if lst[j] > lst[j + 1]:
|
||||
lst[j], lst[j + 1] = lst[j + 1], lst[j]
|
||||
|
||||
result = lst.copy()
|
||||
print(f"result: {result}")
|
||||
return result
|
||||
|
||||
|
||||
class AsyncBubbleSorter:
|
||||
\"\"\"Class with async sorting method for testing.\"\"\"
|
||||
|
||||
async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:
|
||||
\"\"\"
|
||||
Async bubble sort implementation within a class.
|
||||
\"\"\"
|
||||
print("codeflash stdout: AsyncBubbleSorter.sorter() called")
|
||||
|
||||
# Add some async delay
|
||||
await asyncio.sleep(0.005)
|
||||
|
||||
n = len(lst)
|
||||
for i in range(n):
|
||||
for j in range(0, n - i - 1):
|
||||
if lst[j] > lst[j + 1]:
|
||||
lst[j], lst[j + 1] = lst[j + 1], lst[j]
|
||||
|
||||
result = lst.copy()
|
||||
return result
|
||||
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_error_function(lst):
|
||||
\"\"\"Async function that raises an error for testing.\"\"\"
|
||||
await asyncio.sleep(0.001) # Small delay
|
||||
raise ValueError("Test error")
|
||||
"""
|
||||
assert expected_instrumented_source == instrumented_source
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
|
||||
decorated_modified = modified_code.replace(
|
||||
"async def async_error_function", f"@{decorator_name}\nasync def async_error_function"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{decorated_modified}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert instrumented_source.strip() == expected.strip()
|
||||
instrument_codeflash_capture(func, {}, tests_root)
|
||||
|
||||
opt = Optimizer(
|
||||
|
|
@ -526,6 +509,9 @@ async def async_error_function(lst):
|
|||
test_path.unlink()
|
||||
if test_path_perf.exists():
|
||||
test_path_perf.unlink()
|
||||
helper_path = project_root_path / ASYNC_HELPER_FILENAME
|
||||
if helper_path.exists():
|
||||
helper_path.unlink()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
@ -563,7 +549,9 @@ async def test_async_multi():
|
|||
|
||||
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
|
||||
|
||||
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
|
||||
source_success = add_async_decorator_to_function(
|
||||
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
|
||||
)
|
||||
|
||||
assert source_success
|
||||
instrument_codeflash_capture(func, {}, tests_root)
|
||||
|
|
@ -636,6 +624,9 @@ async def test_async_multi():
|
|||
test_path.unlink()
|
||||
if test_path_perf.exists():
|
||||
test_path_perf.unlink()
|
||||
helper_path = project_root_path / ASYNC_HELPER_FILENAME
|
||||
if helper_path.exists():
|
||||
helper_path.unlink()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
@ -678,7 +669,9 @@ async def test_async_edge_cases():
|
|||
|
||||
func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True)
|
||||
|
||||
source_success = add_async_decorator_to_function(fto_path, func, TestingMode.BEHAVIOR)
|
||||
source_success = add_async_decorator_to_function(
|
||||
fto_path, func, TestingMode.BEHAVIOR, project_root=project_root_path
|
||||
)
|
||||
|
||||
assert source_success
|
||||
instrument_codeflash_capture(func, {}, tests_root)
|
||||
|
|
@ -753,6 +746,9 @@ async def test_async_edge_cases():
|
|||
test_path.unlink()
|
||||
if test_path_perf.exists():
|
||||
test_path_perf.unlink()
|
||||
helper_path = project_root_path / ASYNC_HELPER_FILENAME
|
||||
if helper_path.exists():
|
||||
helper_path.unlink()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
@ -987,7 +983,9 @@ async def test_mixed_sorting():
|
|||
function_name="async_merge_sort", parents=[], file_path=Path(mixed_fto_path), is_async=True
|
||||
)
|
||||
|
||||
source_success = add_async_decorator_to_function(mixed_fto_path, async_func, TestingMode.BEHAVIOR)
|
||||
source_success = add_async_decorator_to_function(
|
||||
mixed_fto_path, async_func, TestingMode.BEHAVIOR, project_root=project_root_path
|
||||
)
|
||||
|
||||
assert source_success
|
||||
|
||||
|
|
@ -1060,3 +1058,6 @@ async def test_mixed_sorting():
|
|||
test_path.unlink()
|
||||
if test_path_perf.exists():
|
||||
test_path_perf.unlink()
|
||||
helper_path = project_root_path / ASYNC_HELPER_FILENAME
|
||||
if helper_path.exists():
|
||||
helper_path.unlink()
|
||||
|
|
|
|||
|
|
@ -6,7 +6,9 @@ from pathlib import Path
|
|||
import pytest
|
||||
|
||||
from codeflash.code_utils.instrument_existing_tests import (
|
||||
ASYNC_HELPER_FILENAME,
|
||||
add_async_decorator_to_function,
|
||||
get_decorator_name_for_mode,
|
||||
inject_profiling_into_existing_test,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
|
@ -57,20 +59,6 @@ def test_async_decorator_application_behavior_mode(temp_dir):
|
|||
async_function_code = '''
|
||||
import asyncio
|
||||
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
expected_decorated_code = '''
|
||||
import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_behavior_async
|
||||
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -86,7 +74,16 @@ async def async_function(x: int, y: int) -> int:
|
|||
|
||||
assert decorator_added
|
||||
modified_code = test_file.read_text()
|
||||
assert modified_code.strip() == expected_decorated_code.strip()
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
|
||||
code_with_decorator = async_function_code.replace(
|
||||
"async def async_function", f"@{decorator_name}\nasync def async_function"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert modified_code.strip() == expected.strip()
|
||||
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
@ -94,20 +91,6 @@ def test_async_decorator_application_performance_mode(temp_dir):
|
|||
async_function_code = '''
|
||||
import asyncio
|
||||
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
expected_decorated_code = '''
|
||||
import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_performance_async
|
||||
|
||||
|
||||
@codeflash_performance_async
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -123,7 +106,16 @@ async def async_function(x: int, y: int) -> int:
|
|||
|
||||
assert decorator_added
|
||||
modified_code = test_file.read_text()
|
||||
assert modified_code.strip() == expected_decorated_code.strip()
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE)
|
||||
code_with_decorator = async_function_code.replace(
|
||||
"async def async_function", f"@{decorator_name}\nasync def async_function"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert modified_code.strip() == expected.strip()
|
||||
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
@ -132,20 +124,6 @@ def test_async_decorator_application_concurrency_mode(temp_dir):
|
|||
async_function_code = '''
|
||||
import asyncio
|
||||
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
expected_decorated_code = '''
|
||||
import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_concurrency_async
|
||||
|
||||
|
||||
@codeflash_concurrency_async
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
|
|
@ -161,7 +139,16 @@ async def async_function(x: int, y: int) -> int:
|
|||
|
||||
assert decorator_added
|
||||
modified_code = test_file.read_text()
|
||||
assert modified_code.strip() == expected_decorated_code.strip()
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.CONCURRENCY)
|
||||
code_with_decorator = async_function_code.replace(
|
||||
"async def async_function", f"@{decorator_name}\nasync def async_function"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert modified_code.strip() == expected.strip()
|
||||
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
@ -182,27 +169,6 @@ class Calculator:
|
|||
return a - b
|
||||
'''
|
||||
|
||||
expected_decorated_code = '''
|
||||
import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_behavior_async
|
||||
|
||||
|
||||
class Calculator:
|
||||
"""Test class with async methods."""
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_method(self, a: int, b: int) -> int:
|
||||
"""Async method in class."""
|
||||
await asyncio.sleep(0.005)
|
||||
return a ** b
|
||||
|
||||
def sync_method(self, a: int, b: int) -> int:
|
||||
"""Sync method in class."""
|
||||
return a - b
|
||||
'''
|
||||
|
||||
test_file = temp_dir / "test_async.py"
|
||||
test_file.write_text(async_class_code)
|
||||
|
||||
|
|
@ -217,11 +183,21 @@ class Calculator:
|
|||
|
||||
assert decorator_added
|
||||
modified_code = test_file.read_text()
|
||||
assert modified_code.strip() == expected_decorated_code.strip()
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
|
||||
code_with_decorator = async_class_code.replace(
|
||||
" async def async_method", f" @{decorator_name}\n async def async_method"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert modified_code.strip() == expected.strip()
|
||||
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
def test_async_decorator_no_duplicate_application(temp_dir):
|
||||
# Case 1: Old-style import already present — injector should detect and skip
|
||||
already_decorated_code = '''
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async
|
||||
import asyncio
|
||||
|
|
@ -243,6 +219,30 @@ async def async_function(x: int, y: int) -> int:
|
|||
# Should not add duplicate decorator
|
||||
assert not decorator_added
|
||||
|
||||
# Case 2: Inline definition already present — injector should detect and skip
|
||||
already_inline_code = '''
|
||||
import asyncio
|
||||
|
||||
def codeflash_behavior_async(func):
|
||||
return func
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Already decorated async function."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
test_file2 = temp_dir / "test_async2.py"
|
||||
test_file2.write_text(already_inline_code)
|
||||
|
||||
func2 = FunctionToOptimize(function_name="async_function", file_path=test_file2, parents=[], is_async=True)
|
||||
|
||||
decorator_added2 = add_async_decorator_to_function(test_file2, func2, TestingMode.BEHAVIOR)
|
||||
|
||||
# Should not add duplicate decorator
|
||||
assert not decorator_added2
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
def test_inject_profiling_async_function_behavior_mode(temp_dir):
|
||||
|
|
@ -285,11 +285,18 @@ async def test_async_function():
|
|||
|
||||
assert source_success is True
|
||||
|
||||
# Verify the file was modified
|
||||
# Verify the file was modified with exact expected output
|
||||
instrumented_source = source_file.read_text()
|
||||
assert "@codeflash_behavior_async" in instrumented_source
|
||||
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
|
||||
assert "codeflash_behavior_async" in instrumented_source
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
|
||||
code_with_decorator = source_module_code.replace(
|
||||
"async def async_function", f"@{decorator_name}\nasync def async_function"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert instrumented_source.strip() == expected.strip()
|
||||
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
success, instrumented_test_code = inject_profiling_into_existing_test(
|
||||
test_file, [CodePosition(8, 18), CodePosition(11, 19)], func, temp_dir, mode=TestingMode.BEHAVIOR
|
||||
|
|
@ -340,12 +347,18 @@ async def test_async_function():
|
|||
|
||||
assert source_success is True
|
||||
|
||||
# Verify the file was modified
|
||||
# Verify the file was modified with exact expected output
|
||||
instrumented_source = source_file.read_text()
|
||||
assert "@codeflash_performance_async" in instrumented_source
|
||||
# Check for the import with line continuation formatting
|
||||
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
|
||||
assert "codeflash_performance_async" in instrumented_source
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE)
|
||||
code_with_decorator = source_module_code.replace(
|
||||
"async def async_function", f"@{decorator_name}\nasync def async_function"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert instrumented_source.strip() == expected.strip()
|
||||
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
# Now test the full pipeline with source module path
|
||||
success, instrumented_test_code = inject_profiling_into_existing_test(
|
||||
|
|
@ -406,11 +419,16 @@ async def test_mixed_functions():
|
|||
|
||||
# Verify the file was modified
|
||||
instrumented_source = source_file.read_text()
|
||||
assert "@codeflash_behavior_async" in instrumented_source
|
||||
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
|
||||
assert "codeflash_behavior_async" in instrumented_source
|
||||
# Sync function should remain unchanged
|
||||
assert "def sync_function(x: int, y: int) -> int:" in instrumented_source
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
|
||||
code_with_decorator = source_module_code.replace(
|
||||
"async def async_function", f"@{decorator_name}\nasync def async_function"
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert instrumented_source.strip() == expected.strip()
|
||||
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
success, instrumented_test_code = inject_profiling_into_existing_test(
|
||||
test_file, [CodePosition(8, 18), CodePosition(11, 19)], async_func, temp_dir, mode=TestingMode.BEHAVIOR
|
||||
|
|
@ -446,24 +464,19 @@ class OuterClass:
|
|||
|
||||
decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.BEHAVIOR)
|
||||
|
||||
expected_output = """import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_behavior_async
|
||||
|
||||
|
||||
class OuterClass:
|
||||
class InnerClass:
|
||||
@codeflash_behavior_async
|
||||
async def nested_async_method(self, x: int) -> int:
|
||||
\"\"\"Nested async method.\"\"\"
|
||||
await asyncio.sleep(0.001)
|
||||
return x * 2
|
||||
"""
|
||||
|
||||
assert decorator_added
|
||||
modified_code = test_file.read_text()
|
||||
assert modified_code.strip() == expected_output.strip()
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
|
||||
code_with_decorator = nested_async_code.replace(
|
||||
" async def nested_async_method",
|
||||
f" @{decorator_name}\n async def nested_async_method",
|
||||
)
|
||||
code_with_import = f"from codeflash_async_wrapper import {decorator_name}\n{code_with_decorator}"
|
||||
expected = sort_imports(code=code_with_import, float_to_top=True)
|
||||
assert modified_code.strip() == expected.strip()
|
||||
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
|
|
|
|||
Loading…
Reference in a new issue