first pass

This commit is contained in:
Kevin Turcios 2025-09-26 13:53:15 -07:00
parent d34cbbf793
commit fa705e1b5e
10 changed files with 2530 additions and 3 deletions

View file

@ -0,0 +1,43 @@
import asyncio
from typing import List, Union
async def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:
    """Bubble-sort ``lst`` asynchronously and return a copy of the sorted list.

    Test helper: the input list itself is reordered in place, and the value
    returned is a separate copy of it. Progress is reported on stdout.
    """
    print("codeflash stdout: Async sorting list")
    await asyncio.sleep(0.01)
    # Each pass bubbles the largest remaining value to the end, so the
    # unsorted prefix shrinks by one element per pass.
    for unsorted_end in range(len(lst) - 1, -1, -1):
        for idx in range(unsorted_end):
            if lst[idx] > lst[idx + 1]:
                lst[idx], lst[idx + 1] = lst[idx + 1], lst[idx]
    result = lst.copy()
    print(f"result: {result}")
    return result
class AsyncBubbleSorter:
    """Test class exposing bubble sort as an async method."""

    async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:
        """Bubble-sort ``lst`` in place and return a copy of the sorted list."""
        print("codeflash stdout: AsyncBubbleSorter.sorter() called")
        # Short async pause before the (synchronous) sorting work.
        await asyncio.sleep(0.005)
        remaining = len(lst)
        while remaining > 0:
            # After each pass the last ``len(lst) - remaining + 1`` items are final.
            for pos in range(remaining - 1):
                if lst[pos] > lst[pos + 1]:
                    lst[pos], lst[pos + 1] = lst[pos + 1], lst[pos]
            remaining -= 1
        result = lst.copy()
        return result

View file

@ -0,0 +1,16 @@
import time
import asyncio
async def retry_with_backoff(func, max_retries=3):
    """Await ``func()`` until it succeeds, trying at most ``max_retries`` times.

    Returns the result of the first successful call. Raises ``ValueError``
    when ``max_retries`` is below 1, and re-raises the exception from the
    final attempt when every attempt fails.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")
    final_attempt = max_retries - 1
    failure = None
    for attempt in range(max_retries):
        try:
            outcome = await func()
        except Exception as exc:
            failure = exc
            if attempt < final_attempt:
                # NOTE(review): time.sleep blocks the event loop; this looks like
                # a deliberately unoptimized fixture for the e2e test -- confirm
                # before replacing it with asyncio.sleep.
                time.sleep(0.0001 * attempt)
        else:
            return outcome
    raise failure

View file

@ -0,0 +1,6 @@
[tool.codeflash]
disable-telemetry = true
formatter-cmds = ["ruff check --exit-zero --fix $file", "ruff format $file"]
module-root = "."
test-framework = "pytest"
tests-root = "tests"

View file

@ -86,6 +86,7 @@ class FunctionVisitor(cst.CSTVisitor):
parents=list(reversed(ast_parents)),
starting_line=pos.start.line,
ending_line=pos.end.line,
is_async=bool(node.asynchronous),
)
)
@ -103,6 +104,15 @@ class FunctionWithReturnStatement(ast.NodeVisitor):
FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
)
def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
    """Record an async function as a candidate, flagged with ``is_async=True``.

    A function is recorded only when ``function_has_return_statement`` holds
    for it and ``function_is_a_property`` does not.
    """
    if not function_has_return_statement(node):
        return
    if function_is_a_property(node):
        return
    candidate = FunctionToOptimize(
        function_name=node.name,
        file_path=self.file_path,
        parents=self.ast_path[:],  # snapshot of the current parent chain
        is_async=True,
    )
    self.functions.append(candidate)
def generic_visit(self, node: ast.AST) -> None:
if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):
self.ast_path.append(FunctionParent(node.name, node.__class__.__name__))
@ -122,6 +132,7 @@ class FunctionToOptimize:
parents: A list of parent scopes, which could be classes or functions.
starting_line: The starting line number of the function in the file.
ending_line: The ending line number of the function in the file.
is_async: Whether this function is defined as async.
The qualified_name property provides the full name of the function, including
any parent class or function names. The qualified_name_with_modules_from_root
@ -134,6 +145,7 @@ class FunctionToOptimize:
parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef]
starting_line: Optional[int] = None
ending_line: Optional[int] = None
is_async: bool = False
@property
def top_level_parent_name(self) -> str:
@ -147,7 +159,11 @@ class FunctionToOptimize:
@property
def qualified_name(self) -> str:
    """Return the function name prefixed by all of its parent scope names.

    With no parents this is just ``function_name``; otherwise every parent
    name is joined with dots in list order, followed by the function name
    (e.g. ``Outer.Inner.method``).
    """
    if not self.parents:
        return self.function_name
    # Join all parent names with dots so nested classes are handled properly,
    # rather than using only the first parent.
    parent_path = ".".join(parent.name for parent in self.parents)
    return f"{parent_path}.{self.function_name}"
def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
@ -411,11 +427,27 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
)
)
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
    """Mark the target as top-level when this async function matches it.

    Applies only when no class name is being searched for and the node's name
    equals ``self.function_name``; also records whether the function declares
    any parameters of any kind.
    """
    if self.class_name is not None or node.name != self.function_name:
        return
    self.is_top_level = True
    params = node.args
    self.function_has_args = bool(
        params.args or params.kwonlyargs or params.kwarg or params.posonlyargs or params.vararg
    )
def visit_ClassDef(self, node: ast.ClassDef) -> None:
# iterate over the class methods
if node.name == self.class_name:
for body_node in node.body:
if isinstance(body_node, ast.FunctionDef) and body_node.name == self.function_name:
if (
isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
and body_node.name == self.function_name
):
self.is_top_level = True
if any(
isinstance(decorator, ast.Name) and decorator.id == "classmethod"
@ -433,7 +465,7 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
# This way, if we don't have the class name, we can still find the static method
for body_node in node.body:
if (
isinstance(body_node, ast.FunctionDef)
isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
and body_node.name == self.function_name
and body_node.lineno in {self.line_no, self.line_no + 1}
and any(

View file

@ -0,0 +1,27 @@
import os
import pathlib
from end_to_end_test_utilities import CoverageExpectation, TestConfig, run_codeflash_command, run_with_retries
def run_test(expected_improvement_pct: int) -> bool:
    """Run the codeflash command on the ``async_e2e`` sample directory.

    Builds the expectations for ``retry_with_backoff`` in ``main.py`` and
    returns whatever ``run_codeflash_command`` reports.
    """
    coverage = CoverageExpectation(
        function_name="retry_with_backoff",
        expected_coverage=100.0,
        expected_lines=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
    )
    config = TestConfig(
        file_path="main.py",
        expected_unit_tests=0,
        min_improvement_x=0.1,
        coverage_expectations=[coverage],
    )
    # Three levels above this script, then into the sample project directory.
    base_dir = pathlib.Path(__file__).parent.parent.parent
    cwd = (base_dir / "code_to_optimize" / "code_directories" / "async_e2e").resolve()
    return run_codeflash_command(cwd, config, expected_improvement_pct)
if __name__ == "__main__":
exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10))))

View file

@ -0,0 +1,286 @@
import tempfile
from pathlib import Path
import pytest
from codeflash.discovery.functions_to_optimize import (
find_all_functions_in_file,
get_functions_to_optimize,
inspect_top_level_functions_or_methods,
)
from codeflash.verification.verification_utils import TestConfig
@pytest.fixture
def temp_dir():
    """Yield a fresh temporary directory as a ``Path``; it is removed afterwards."""
    with tempfile.TemporaryDirectory() as tmp_name:
        tmp_path = Path(tmp_name)
        yield tmp_path
def test_async_function_detection(temp_dir):
    """Async functions with a return statement are discovered; those without are not."""
    # The embedded module must be valid Python, so its bodies are indented.
    async_function = """
async def async_function_with_return():
    await some_async_operation()
    return 42

async def async_function_without_return():
    await some_async_operation()
    print("No return")

def regular_function():
    return 10
"""
    file_path = temp_dir / "test_file.py"
    file_path.write_text(async_function)

    functions_found = find_all_functions_in_file(file_path)
    function_names = [fn.function_name for fn in functions_found[file_path]]

    assert "async_function_with_return" in function_names
    assert "regular_function" in function_names
    assert "async_function_without_return" not in function_names
def test_async_method_in_class(temp_dir):
    """Async and sync methods with returns are found and qualified by their class name."""
    # The embedded module must be valid Python, so its bodies are indented.
    code_with_async_method = """
class AsyncClass:
    async def async_method(self):
        await self.do_something()
        return "result"

    async def async_method_no_return(self):
        await self.do_something()
        pass

    def sync_method(self):
        return "sync result"
"""
    file_path = temp_dir / "test_file.py"
    file_path.write_text(code_with_async_method)

    functions_found = find_all_functions_in_file(file_path)
    found_functions = functions_found[file_path]
    function_names = [fn.function_name for fn in found_functions]
    qualified_names = [fn.qualified_name for fn in found_functions]

    assert "async_method" in function_names
    assert "AsyncClass.async_method" in qualified_names
    assert "sync_method" in function_names
    assert "AsyncClass.sync_method" in qualified_names
    assert "async_method_no_return" not in function_names
def test_nested_async_functions(temp_dir):
    """Only the outer functions are reported; async functions nested inside them are not."""
    # The embedded module must be valid Python, so its bodies are indented.
    nested_async = """
async def outer_async():
    async def inner_async():
        return "inner"

    result = await inner_async()
    return result

def outer_sync():
    async def inner_async():
        return "inner from sync"

    return inner_async
"""
    file_path = temp_dir / "test_file.py"
    file_path.write_text(nested_async)

    functions_found = find_all_functions_in_file(file_path)
    function_names = [fn.function_name for fn in functions_found[file_path]]

    assert "outer_async" in function_names
    assert "outer_sync" in function_names
    assert "inner_async" not in function_names
def test_async_staticmethod_and_classmethod(temp_dir):
    """Async static/class methods are discovered, while an async ``@property`` is excluded."""
    # The embedded module must be valid Python, so its bodies are indented.
    async_decorators = """
class MyClass:
    @staticmethod
    async def async_static_method():
        await some_operation()
        return "static result"

    @classmethod
    async def async_class_method(cls):
        await cls.some_operation()
        return "class result"

    @property
    async def async_property(self):
        return await self.get_value()
"""
    file_path = temp_dir / "test_file.py"
    file_path.write_text(async_decorators)

    functions_found = find_all_functions_in_file(file_path)
    function_names = [fn.function_name for fn in functions_found[file_path]]

    assert "async_static_method" in function_names
    assert "async_class_method" in function_names
    assert "async_property" not in function_names
def test_async_generator_functions(temp_dir):
    """Async generators are reported only when they contain a return statement."""
    # The embedded module must be valid Python syntax, so its bodies are indented.
    # NOTE(review): ``return "done"`` inside an async generator is rejected when
    # the module is compiled; this fixture is only ever written to disk and
    # analysed here, never imported -- confirm the discovery code only parses it.
    async_generators = """
async def async_generator_with_return():
    for i in range(10):
        yield i
    return "done"

async def async_generator_no_return():
    for i in range(10):
        yield i

async def regular_async_with_return():
    result = await compute()
    return result
"""
    file_path = temp_dir / "test_file.py"
    file_path.write_text(async_generators)

    functions_found = find_all_functions_in_file(file_path)
    function_names = [fn.function_name for fn in functions_found[file_path]]

    assert "async_generator_with_return" in function_names
    assert "regular_async_with_return" in function_names
    assert "async_generator_no_return" not in function_names
def test_inspect_async_top_level_functions(temp_dir):
    """``inspect_top_level_functions_or_methods`` classifies async functions and methods."""
    # The embedded module must be valid Python, so its bodies are indented.
    code = """
async def top_level_async():
    return 42

class AsyncContainer:
    async def async_method(self):
        async def nested_async():
            return 1
        return await nested_async()

    @staticmethod
    async def async_static():
        return "static"

    @classmethod
    async def async_classmethod(cls):
        return "classmethod"
"""
    file_path = temp_dir / "test_file.py"
    file_path.write_text(code)

    result = inspect_top_level_functions_or_methods(file_path, "top_level_async")
    assert result.is_top_level

    result = inspect_top_level_functions_or_methods(file_path, "async_method", class_name="AsyncContainer")
    assert result.is_top_level

    # A function defined inside a method is not a direct member of the class.
    result = inspect_top_level_functions_or_methods(file_path, "nested_async", class_name="AsyncContainer")
    assert not result.is_top_level

    result = inspect_top_level_functions_or_methods(file_path, "async_static", class_name="AsyncContainer")
    assert result.is_top_level
    assert result.is_staticmethod

    result = inspect_top_level_functions_or_methods(file_path, "async_classmethod", class_name="AsyncContainer")
    assert result.is_top_level
    assert result.is_classmethod
def test_get_functions_to_optimize_with_async(temp_dir):
    """``get_functions_to_optimize`` returns both sync and async functions that have a return."""
    # The embedded module must be valid Python, so its bodies are indented.
    mixed_code = """
async def async_func_one():
    return await operation_one()

def sync_func_one():
    return operation_one()

async def async_func_two():
    print("no return")

class MixedClass:
    async def async_method(self):
        return await self.operation()

    def sync_method(self):
        return self.operation()
"""
    file_path = temp_dir / "test_file.py"
    file_path.write_text(mixed_code)

    test_config = TestConfig(
        tests_root="tests",
        project_root_path=".",
        test_framework="pytest",
        tests_project_rootdir=Path(),
    )
    functions, functions_count, _ = get_functions_to_optimize(
        optimize_all=None,
        replay_test=None,
        file=file_path,
        only_get_this_function=None,
        test_cfg=test_config,
        ignore_paths=[],
        project_root=file_path.parent,
        module_root=file_path.parent,
    )

    # Four functions have a return statement; async_func_two does not.
    assert functions_count == 4
    function_names = [fn.function_name for fn in functions[file_path]]
    assert "async_func_one" in function_names
    assert "sync_func_one" in function_names
    assert "async_method" in function_names
    assert "sync_method" in function_names
    assert "async_func_two" not in function_names
def test_async_function_parents(temp_dir):
    """Parent chains of discovered async functions reflect their class nesting."""
    # The embedded module must be valid Python; InnerClass is nested in
    # OuterClass, as the two-parent assertion below requires.
    complex_structure = """
class OuterClass:
    async def outer_method(self):
        return 1

    class InnerClass:
        async def inner_method(self):
            return 2

async def module_level_async():
    class LocalClass:
        async def local_method(self):
            return 3
    return LocalClass()
"""
    file_path = temp_dir / "test_file.py"
    file_path.write_text(complex_structure)

    functions_found = find_all_functions_in_file(file_path)
    found_functions = functions_found[file_path]

    for fn in found_functions:
        if fn.function_name == "outer_method":
            assert len(fn.parents) == 1
            assert fn.parents[0].name == "OuterClass"
            assert fn.qualified_name == "OuterClass.outer_method"
        elif fn.function_name == "inner_method":
            assert len(fn.parents) == 2
            assert fn.parents[0].name == "OuterClass"
            assert fn.parents[1].name == "InnerClass"
        elif fn.function_name == "module_level_async":
            assert len(fn.parents) == 0
            assert fn.qualified_name == "module_level_async"

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,285 @@
from __future__ import annotations
import asyncio
import os
import sqlite3
import tempfile
from pathlib import Path
import pytest
import dill as pickle
from codeflash.code_utils.codeflash_wrap_decorator import (
codeflash_behavior_async,
codeflash_performance_async,
)
from codeflash.verification.codeflash_capture import VerificationType
class TestAsyncWrapperSQLiteValidation:
    """Check what the async codeflash wrappers record in the SQLite results file."""

    @staticmethod
    def _fetch_all(db_path, query):
        """Run ``query`` on the SQLite file at ``db_path`` and return every row.

        The connection is closed in ``finally`` so that a failing query or a
        later assertion cannot leave the database file open.
        """
        con = sqlite3.connect(db_path)
        try:
            return con.execute(query).fetchall()
        finally:
            con.close()

    @pytest.fixture
    def test_env_setup(self, request):
        """Set the CODEFLASH_* environment variables for one test, then restore them."""
        test_env = {
            "CODEFLASH_LOOP_INDEX": "1",
            "CODEFLASH_TEST_ITERATION": "0",
            "CODEFLASH_TEST_MODULE": __name__,
            "CODEFLASH_TEST_CLASS": "TestAsyncWrapperSQLiteValidation",
            "CODEFLASH_TEST_FUNCTION": request.node.name,
            "CODEFLASH_CURRENT_LINE_ID": "test_unit",
        }
        # Remember previous values (None when unset) so teardown can restore them.
        original_env = {key: os.environ.get(key) for key in test_env}
        os.environ.update(test_env)

        yield test_env

        for key, original_value in original_env.items():
            if original_value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = original_value

    @pytest.fixture
    def temp_db_path(self, test_env_setup):
        """Yield the results database path for this iteration and delete the file afterwards."""
        iteration = test_env_setup["CODEFLASH_TEST_ITERATION"]
        from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file

        db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))

        yield db_path

        if db_path.exists():
            db_path.unlink()

    @pytest.mark.asyncio
    async def test_behavior_async_basic_function(self, test_env_setup, temp_db_path):
        """A single decorated call stores exactly one fully populated row."""

        @codeflash_behavior_async
        async def simple_async_add(a: int, b: int) -> int:
            await asyncio.sleep(0.001)
            return a + b

        os.environ['CODEFLASH_CURRENT_LINE_ID'] = 'simple_async_add_59'
        result = await simple_async_add(5, 3)

        assert result == 8
        assert temp_db_path.exists()

        tables = self._fetch_all(
            temp_db_path, "SELECT name FROM sqlite_master WHERE type='table' AND name='test_results'"
        )
        assert tables

        rows = self._fetch_all(temp_db_path, "SELECT * FROM test_results")
        assert len(rows) == 1
        row = rows[0]

        (test_module_path, test_class_name, test_function_name, function_getting_tested,
         loop_index, iteration_id, runtime, return_value_blob, verification_type) = row

        assert test_module_path == __name__
        assert test_class_name == "TestAsyncWrapperSQLiteValidation"
        assert test_function_name == "test_behavior_async_basic_function"
        assert function_getting_tested == "simple_async_add"
        assert loop_index == 1
        # The iteration id is the line id set above plus a per-call index suffix.
        assert iteration_id.startswith("simple_async_add_") and iteration_id.endswith("_0")
        assert runtime > 0
        assert verification_type == VerificationType.FUNCTION_CALL.value

        # The blob was written by the wrapper during this test, so unpickling it is safe here.
        unpickled_data = pickle.loads(return_value_blob)
        args, kwargs, return_val = unpickled_data
        assert args == (5, 3)
        assert kwargs == {}
        assert return_val == 8

    @pytest.mark.asyncio
    async def test_behavior_async_exception_handling(self, test_env_setup, temp_db_path):
        """A raised exception is re-raised to the caller and stored as the row's blob."""

        @codeflash_behavior_async
        async def async_divide(a: int, b: int) -> float:
            await asyncio.sleep(0.001)
            if b == 0:
                raise ValueError("Cannot divide by zero")
            return a / b

        result = await async_divide(10, 2)
        assert result == 5.0

        with pytest.raises(ValueError, match="Cannot divide by zero"):
            await async_divide(10, 0)

        rows = self._fetch_all(temp_db_path, "SELECT * FROM test_results ORDER BY iteration_id")
        assert len(rows) == 2

        success_row = rows[0]
        success_data = pickle.loads(success_row[7])  # return_value_blob
        args, kwargs, return_val = success_data
        assert args == (10, 2)
        assert return_val == 5.0

        # Check exception record
        exception_row = rows[1]
        exception_data = pickle.loads(exception_row[7])  # return_value_blob
        assert isinstance(exception_data, ValueError)
        assert str(exception_data) == "Cannot divide by zero"

    @pytest.mark.asyncio
    async def test_performance_async_no_database_storage(self, test_env_setup, temp_db_path, capsys):
        """Test performance async decorator doesn't store to database."""

        @codeflash_performance_async
        async def async_multiply(a: int, b: int) -> int:
            """Async function for performance testing."""
            await asyncio.sleep(0.002)
            return a * b

        result = await async_multiply(4, 7)

        assert result == 28
        assert not temp_db_path.exists()

        captured = capsys.readouterr()
        output_lines = captured.out.strip().split('\n')
        opening_tags = [line for line in output_lines if "!$######" in line]
        closing_tags = [line for line in output_lines if "!######" in line and "######!" in line]

        assert len(opening_tags) == 1
        assert len(closing_tags) == 1

        closing_tag = closing_tags[0]
        assert "async_multiply" in closing_tag

        # The timing is the last colon-separated field, before the closing marker.
        timing_part = closing_tag.split(":")[-1].replace("######!", "")
        timing_value = int(timing_part)
        assert timing_value > 0  # Should have positive timing

    @pytest.mark.asyncio
    async def test_multiple_calls_indexing(self, test_env_setup, temp_db_path):
        """Repeated calls from the same line get iteration ids suffixed 0, 1, 2."""

        @codeflash_behavior_async
        async def async_increment(value: int) -> int:
            await asyncio.sleep(0.001)
            return value + 1

        # Call the function multiple times
        results = []
        for i in range(3):
            result = await async_increment(i)
            results.append(result)

        assert results == [1, 2, 3]

        rows = self._fetch_all(
            temp_db_path, "SELECT iteration_id, return_value FROM test_results ORDER BY iteration_id"
        )
        assert len(rows) == 3

        actual_ids = [row[0] for row in rows]
        assert len(actual_ids) == 3

        base_pattern = actual_ids[0].rsplit('_', 1)[0]  # e.g., "async_increment_199"
        expected_pattern = [f"{base_pattern}_{i}" for i in range(3)]
        assert actual_ids == expected_pattern

        for i, (_, return_value_blob) in enumerate(rows):
            args, kwargs, return_val = pickle.loads(return_value_blob)
            assert args == (i,)
            assert return_val == i + 1

    @pytest.mark.asyncio
    async def test_complex_async_function_with_kwargs(self, test_env_setup, temp_db_path):
        """Positional, variadic and keyword arguments are all captured in the stored blob."""

        @codeflash_behavior_async
        async def complex_async_func(
            pos_arg: str,
            *args: int,
            keyword_arg: str = "default",
            **kwargs: str
        ) -> dict:
            await asyncio.sleep(0.001)
            return {
                "pos_arg": pos_arg,
                "args": args,
                "keyword_arg": keyword_arg,
                "kwargs": kwargs,
            }

        result = await complex_async_func(
            "hello",
            1, 2, 3,
            keyword_arg="custom",
            extra1="value1",
            extra2="value2"
        )

        expected_result = {
            "pos_arg": "hello",
            "args": (1, 2, 3),
            "keyword_arg": "custom",
            "kwargs": {"extra1": "value1", "extra2": "value2"}
        }
        assert result == expected_result

        rows = self._fetch_all(temp_db_path, "SELECT return_value FROM test_results")
        row = rows[0]
        stored_args, stored_kwargs, stored_result = pickle.loads(row[0])

        assert stored_args == ("hello", 1, 2, 3)
        assert stored_kwargs == {"keyword_arg": "custom", "extra1": "value1", "extra2": "value2"}
        assert stored_result == expected_result

    @pytest.mark.asyncio
    async def test_database_schema_validation(self, test_env_setup, temp_db_path):
        """The ``test_results`` table has the expected columns, in order, with the expected types."""

        @codeflash_behavior_async
        async def schema_test_func() -> str:
            return "schema_test"

        await schema_test_func()

        columns = self._fetch_all(temp_db_path, "PRAGMA table_info(test_results)")

        expected_columns = [
            (0, 'test_module_path', 'TEXT', 0, None, 0),
            (1, 'test_class_name', 'TEXT', 0, None, 0),
            (2, 'test_function_name', 'TEXT', 0, None, 0),
            (3, 'function_getting_tested', 'TEXT', 0, None, 0),
            (4, 'loop_index', 'INTEGER', 0, None, 0),
            (5, 'iteration_id', 'TEXT', 0, None, 0),
            (6, 'runtime', 'INTEGER', 0, None, 0),
            (7, 'return_value', 'BLOB', 0, None, 0),
            (8, 'verification_type', 'TEXT', 0, None, 0)
        ]
        assert columns == expected_columns

View file

@ -0,0 +1,793 @@
import tempfile
from pathlib import Path
import uuid
import os
import pytest
from codeflash.code_utils.instrument_existing_tests import (
add_async_decorator_to_function,
inject_profiling_into_existing_test,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, TestingMode
@pytest.fixture
def temp_dir():
    """Provide a temporary directory for test files as a ``Path``."""
    with tempfile.TemporaryDirectory() as tmp_name:
        tmp_path = Path(tmp_name)
        yield tmp_path
# @pytest.fixture
# def unique_test_iteration():
# """Provide a unique test iteration ID and clean up database after test."""
# # Generate unique iteration ID
# iteration_id = str(uuid.uuid4())[:8]
# # Store original environment variable
# original_iteration = os.environ.get("CODEFLASH_TEST_ITERATION")
# # Set unique iteration for this test
# os.environ["CODEFLASH_TEST_ITERATION"] = iteration_id
# try:
# yield iteration_id
# finally:
# # Cleanup: restore original environment and delete database file
# if original_iteration is not None:
# os.environ["CODEFLASH_TEST_ITERATION"] = original_iteration
# elif "CODEFLASH_TEST_ITERATION" in os.environ:
# del os.environ["CODEFLASH_TEST_ITERATION"]
# # Clean up database file
# try:
# from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file
# db_path = get_run_tmp_file(Path(f"test_return_values_{iteration_id}.sqlite"))
# if db_path.exists():
# db_path.unlink()
# except Exception:
# pass # Ignore cleanup errors
# Behavior mode: add_async_decorator_to_function is expected to report that it
# added a decorator and to return source containing the
# codeflash_behavior_async import and decorator.
def test_async_decorator_application_behavior_mode():
# Input module: one undecorated async function.
async_function_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
# Expected output; compared after strip(), so only surrounding whitespace is ignored.
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
@codeflash_behavior_async
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
func = FunctionToOptimize(
function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
)
modified_code, decorator_added = add_async_decorator_to_function(async_function_code, func, TestingMode.BEHAVIOR)
assert decorator_added
assert modified_code.strip() == expected_decorated_code.strip()
# Performance mode: same input as the behavior-mode test, but the expected
# import and decorator are codeflash_performance_async.
def test_async_decorator_application_performance_mode():
# Input module: one undecorated async function.
async_function_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
# Expected output; compared after strip(), so only surrounding whitespace is ignored.
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_performance_async
@codeflash_performance_async
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
func = FunctionToOptimize(
function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
)
modified_code, decorator_added = add_async_decorator_to_function(async_function_code, func, TestingMode.PERFORMANCE)
assert decorator_added
assert modified_code.strip() == expected_decorated_code.strip()
# Class case: only the async method named by the FunctionToOptimize should be
# decorated; the sync method in the same class is expected to stay unchanged.
def test_async_class_method_decorator_application():
async_class_code = '''
import asyncio
class Calculator:
"""Test class with async methods."""
async def async_method(self, a: int, b: int) -> int:
"""Async method in class."""
await asyncio.sleep(0.005)
return a ** b
def sync_method(self, a: int, b: int) -> int:
"""Sync method in class."""
return a - b
'''
# Expected output; compared after strip(), so only surrounding whitespace is ignored.
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
class Calculator:
"""Test class with async methods."""
@codeflash_behavior_async
async def async_method(self, a: int, b: int) -> int:
"""Async method in class."""
await asyncio.sleep(0.005)
return a ** b
def sync_method(self, a: int, b: int) -> int:
"""Sync method in class."""
return a - b
'''
# NOTE(review): the parent is passed as a plain dict rather than a
# FunctionParent; presumably FunctionToOptimize coerces it -- confirm.
func = FunctionToOptimize(
function_name="async_method",
file_path=Path("test_async.py"),
parents=[{"name": "Calculator", "type": "ClassDef"}],
is_async=True,
)
modified_code, decorator_added = add_async_decorator_to_function(async_class_code, func, TestingMode.BEHAVIOR)
assert decorator_added
assert modified_code.strip() == expected_decorated_code.strip()
# Already-decorated input: no second decorator should be added
# (decorator_added is False), although the returned source may still be
# reformatted, as the expected text below shows for the import lines.
def test_async_decorator_no_duplicate_application():
already_decorated_code = '''
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async
import asyncio
@codeflash_behavior_async
async def async_function(x: int, y: int) -> int:
"""Already decorated async function."""
await asyncio.sleep(0.01)
return x * y
'''
# Expected output; compared after strip(), so only surrounding whitespace is ignored.
expected_reformatted_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
@codeflash_behavior_async
async def async_function(x: int, y: int) -> int:
"""Already decorated async function."""
await asyncio.sleep(0.01)
return x * y
'''
func = FunctionToOptimize(
function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
)
modified_code, decorator_added = add_async_decorator_to_function(already_decorated_code, func, TestingMode.BEHAVIOR)
assert not decorator_added
assert modified_code.strip() == expected_reformatted_code.strip()
# Behavior mode, two-step flow: first decorate the source module, then try to
# inject profiling into a test file that calls the async function. The second
# step is expected to report failure and return no code.
def test_inject_profiling_async_function_behavior_mode(temp_dir):
source_module_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
source_file = temp_dir / "my_module.py"
source_file.write_text(source_module_code)
# Test module that awaits the target function twice.
async_test_code = '''
import asyncio
import pytest
from my_module import async_function
@pytest.mark.asyncio
async def test_async_function():
"""Test async function behavior."""
result = await async_function(5, 3)
assert result == 15
result2 = await async_function(2, 4)
assert result2 == 8
'''
test_file = temp_dir / "test_async.py"
test_file.write_text(async_test_code)
func = FunctionToOptimize(function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True)
# First instrument the source module
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
source_success, instrumented_source = instrument_source_module_with_async_decorators(
source_file, func, TestingMode.BEHAVIOR
)
assert source_success is True
assert instrumented_source is not None
assert "@codeflash_behavior_async" in instrumented_source
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
assert "codeflash_behavior_async" in instrumented_source
# Persist the decorated source before instrumenting the test file.
source_file.write_text(instrumented_source)
# NOTE(review): the CodePosition values presumably are (line, column) of the two
# calls inside async_test_code -- confirm against the original file layout.
success, instrumented_test_code = inject_profiling_into_existing_test(
test_file, [CodePosition(8, 18), CodePosition(11, 19)], func, temp_dir, "pytest", mode=TestingMode.BEHAVIOR
)
# For async functions, once source is decorated, test injection should fail
# This is expected behavior - async instrumentation happens at the decorator level
assert success is False
assert instrumented_test_code is None
# Performance mode, two-step flow: decorate the source module with
# codeflash_performance_async, then confirm that test-level profiling
# injection reports failure and returns no code.
def test_inject_profiling_async_function_performance_mode(temp_dir):
source_module_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
source_file = temp_dir / "my_module.py"
source_file.write_text(source_module_code)
# Create the test file
async_test_code = '''
import asyncio
import pytest
from my_module import async_function
@pytest.mark.asyncio
async def test_async_function():
"""Test async function performance."""
result = await async_function(5, 3)
assert result == 15
'''
test_file = temp_dir / "test_async.py"
test_file.write_text(async_test_code)
func = FunctionToOptimize(function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True)
# First instrument the source module
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
source_success, instrumented_source = instrument_source_module_with_async_decorators(
source_file, func, TestingMode.PERFORMANCE
)
assert source_success is True
assert instrumented_source is not None
assert "@codeflash_performance_async" in instrumented_source
# Check for the import with line continuation formatting
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
assert "codeflash_performance_async" in instrumented_source
# Write the instrumented source back
source_file.write_text(instrumented_source)
# Now test the full pipeline with source module path
# NOTE(review): CodePosition(8, 18) presumably is the (line, column) of the call
# inside async_test_code -- confirm against the original file layout.
success, instrumented_test_code = inject_profiling_into_existing_test(
test_file, [CodePosition(8, 18)], func, temp_dir, "pytest", mode=TestingMode.PERFORMANCE
)
# For async functions, once source is decorated, test injection should fail
# This is expected behavior - async instrumentation happens at the decorator level
assert success is False
assert instrumented_test_code is None
# Mixed module: decorating the async function must leave the sync function's
# signature line untouched, and test-level injection for the async function is
# expected to report failure.
def test_mixed_sync_async_instrumentation(temp_dir):
source_module_code = '''
import asyncio
def sync_function(x: int, y: int) -> int:
"""Regular sync function."""
return x * y
async def async_function(x: int, y: int) -> int:
"""Simple async function."""
await asyncio.sleep(0.01)
return x * y
'''
source_file = temp_dir / "my_module.py"
source_file.write_text(source_module_code)
# Test module that calls both the sync and the async function.
mixed_test_code = '''
import asyncio
import pytest
from my_module import sync_function, async_function
@pytest.mark.asyncio
async def test_mixed_functions():
"""Test both sync and async functions."""
sync_result = sync_function(10, 5)
assert sync_result == 50
async_result = await async_function(3, 4)
assert async_result == 12
'''
test_file = temp_dir / "test_mixed.py"
test_file.write_text(mixed_test_code)
async_func = FunctionToOptimize(
function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True
)
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
source_success, instrumented_source = instrument_source_module_with_async_decorators(
source_file, async_func, TestingMode.BEHAVIOR
)
assert source_success
assert instrumented_source is not None
assert "@codeflash_behavior_async" in instrumented_source
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
assert "codeflash_behavior_async" in instrumented_source
# Sync function should remain unchanged
assert "def sync_function(x: int, y: int) -> int:" in instrumented_source
# Write instrumented source back
source_file.write_text(instrumented_source)
# NOTE(review): the CodePosition values presumably are (line, column) pairs inside
# mixed_test_code -- confirm against the original file layout.
success, instrumented_test_code = inject_profiling_into_existing_test(
test_file,
[CodePosition(8, 18), CodePosition(11, 19)],
async_func,
temp_dir,
"pytest",
mode=TestingMode.BEHAVIOR,
)
# Async functions should not be instrumented at the test level
assert not success
assert instrumented_test_code is None
# Nested classes: the decorator must land on the async method identified by a
# two-level parent chain (OuterClass -> InnerClass).
def test_async_function_qualified_name_handling():
nested_async_code = '''
import asyncio
class OuterClass:
class InnerClass:
async def nested_async_method(self, x: int) -> int:
"""Nested async method."""
await asyncio.sleep(0.001)
return x * 2
'''
# NOTE(review): parents are passed as plain dicts rather than FunctionParent
# objects; presumably FunctionToOptimize coerces them -- confirm.
func = FunctionToOptimize(
function_name="nested_async_method",
file_path=Path("test_nested.py"),
parents=[{"name": "OuterClass", "type": "ClassDef"}, {"name": "InnerClass", "type": "ClassDef"}],
is_async=True,
)
modified_code, decorator_added = add_async_decorator_to_function(nested_async_code, func, TestingMode.BEHAVIOR)
# Expected output; compared after strip(), so only surrounding whitespace is ignored.
expected_output = (
"""import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
class OuterClass:
class InnerClass:
@codeflash_behavior_async
async def nested_async_method(self, x: int) -> int:
\"\"\"Nested async method.\"\"\"
await asyncio.sleep(0.001)
return x * 2
"""
)
assert modified_code.strip() == expected_output.strip()
def test_async_decorator_with_existing_decorators():
    """Verify the codeflash decorator is stacked above a decorator the function already carries."""
    source_with_decorator = '''
import asyncio
from functools import wraps
def my_decorator(func):
    @wraps(func)
    async def wrapper(*args, **kwargs):
        return await func(*args, **kwargs)
    return wrapper
@my_decorator
async def async_function(x: int, y: int) -> int:
    """Async function with existing decorator."""
    await asyncio.sleep(0.01)
    return x * y
'''
    target = FunctionToOptimize(
        function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
    )

    rewritten, was_added = add_async_decorator_to_function(source_with_decorator, target, TestingMode.BEHAVIOR)
    assert was_added

    codeflash_marker = "@codeflash_behavior_async"
    existing_marker = "@my_decorator"

    # Both decorators must survive the rewrite ...
    assert codeflash_marker in rewritten
    assert existing_marker in rewritten

    # ... and the codeflash one has to appear earlier in the text, i.e. be the
    # outermost decorator on the function.
    assert rewritten.find(codeflash_marker) < rewritten.find(existing_marker)
def test_sync_function_not_affected_by_async_logic():
    """A target flagged ``is_async=False`` must come back unchanged and reported as not decorated."""
    plain_source = '''
def sync_function(x: int, y: int) -> int:
    """Regular sync function."""
    return x + y
'''
    target = FunctionToOptimize(
        function_name="sync_function", file_path=Path("test_sync.py"), parents=[], is_async=False
    )

    outcome = add_async_decorator_to_function(plain_source, target, TestingMode.BEHAVIOR)
    returned_code, was_added = outcome

    assert not was_added
    # The source text must be returned exactly as it was passed in.
    assert returned_code == plain_source
def test_inject_profiling_async_multiple_calls_same_test(temp_dir):
    """Several awaited calls inside one test function must be tagged with line ids 0, 1, 2, ...

    The source module is decorated first, then the test file is instrumented
    using the positions of every ``await async_sorter(...)`` expression.
    """
    import ast

    from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators

    module_source = '''
import asyncio
async def async_sorter(items):
    """Simple async sorter for testing."""
    await asyncio.sleep(0.001)
    return sorted(items)
'''
    source_file = temp_dir / "async_sorter.py"
    source_file.write_text(module_source)

    test_source = """
import asyncio
import pytest
from async_sorter import async_sorter
@pytest.mark.asyncio
async def test_single_call():
    result = await async_sorter([42])
    assert result == [42]
@pytest.mark.asyncio
async def test_multiple_calls():
    result1 = await async_sorter([3, 1, 2])
    result2 = await async_sorter([5, 4])
    result3 = await async_sorter([9, 8, 7, 6])
    assert result1 == [1, 2, 3]
    assert result2 == [4, 5]
    assert result3 == [6, 7, 8, 9]
"""
    test_file = temp_dir / "test_async_sorter.py"
    test_file.write_text(test_source)

    func = FunctionToOptimize(
        function_name="async_sorter", parents=[], file_path=Path("async_sorter.py"), is_async=True
    )

    # Step 1: put the async decorator on the source module and persist it.
    source_success, instrumented_source = instrument_source_module_with_async_decorators(
        source_file, func, TestingMode.BEHAVIOR
    )
    assert source_success
    assert instrumented_source is not None
    assert "@codeflash_behavior_async" in instrumented_source
    source_file.write_text(instrumented_source)

    # Step 2: locate every `await async_sorter(...)` in the test source.
    def _awaits_async_sorter(node):
        """Return True for an ``await`` whose callee is named (or has attribute) ``async_sorter``."""
        if not (isinstance(node, ast.Await) and isinstance(node.value, ast.Call)):
            return False
        callee = node.value.func
        return getattr(callee, "id", None) == "async_sorter" or getattr(callee, "attr", None) == "async_sorter"

    call_positions = [
        CodePosition(node.lineno, node.col_offset)
        for node in ast.walk(ast.parse(test_source))
        if _awaits_async_sorter(node)
    ]
    assert len(call_positions) == 4

    # Step 3: instrument the test file at those positions.
    success, instrumented_test_code = inject_profiling_into_existing_test(
        test_file, call_positions, func, temp_dir, "pytest", mode=TestingMode.BEHAVIOR
    )
    assert success
    assert instrumented_test_code is not None
    assert "os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'" in instrumented_test_code

    def _occurrences(line_id):
        """Count how often the given line id is assigned in the instrumented test."""
        return instrumented_test_code.count(f"os.environ['CODEFLASH_CURRENT_LINE_ID'] = '{line_id}'")

    # Id '0' is expected twice (once per test function); ids '1' and '2' once each.
    zero_hits, one_hits, two_hits = (_occurrences(line_id) for line_id in ("0", "1", "2"))
    assert zero_hits == 2, f"Expected 2 occurrences of line_id '0', got {zero_hits}"
    assert one_hits == 1, f"Expected 1 occurrence of line_id '1', got {one_hits}"
    assert two_hits == 1, f"Expected 1 occurrence of line_id '2', got {two_hits}"
def test_async_behavior_decorator_return_values_and_test_ids():
    """Test that async behavior decorator correctly captures return values, test IDs, and stores data in database.

    A decorated coroutine is run once with the ``CODEFLASH_*`` environment
    variables set to known values. The test then opens the sqlite file that
    ``get_run_tmp_file`` resolves for ``test_return_values_2.sqlite`` and
    checks the single recorded row field by field, including the pickled
    ``(args, kwargs, return_value)`` payload.

    The environment is restored in a ``finally`` block, and the database
    connection is closed even when an assertion fails.
    """
    import asyncio
    import os
    import sqlite3
    from contextlib import closing
    from pathlib import Path

    import dill as pickle

    from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async, get_run_tmp_file

    @codeflash_behavior_async
    async def test_async_multiply(x: int, y: int) -> int:
        """Simple async function for testing."""
        await asyncio.sleep(0.001)  # Small delay to simulate async work
        return x * y

    test_env = {
        "CODEFLASH_TEST_MODULE": "test_module",
        "CODEFLASH_TEST_CLASS": None,  # None means "make sure the variable is unset"
        "CODEFLASH_TEST_FUNCTION": "test_async_multiply_function",
        "CODEFLASH_CURRENT_LINE_ID": "0",
        "CODEFLASH_LOOP_INDEX": "1",
        "CODEFLASH_TEST_ITERATION": "2",
    }

    # Remember the previous values so the finally block can put them back.
    original_env = {k: os.environ.get(k) for k in test_env}
    for k, v in test_env.items():
        if v is not None:
            os.environ[k] = v
        else:
            os.environ.pop(k, None)

    try:
        result = asyncio.run(test_async_multiply(6, 7))
        assert result == 42, f"Expected return value 42, got {result}"

        # NOTE(review): the "_2" suffix presumably mirrors CODEFLASH_TEST_ITERATION — confirm.
        db_path = get_run_tmp_file(Path("test_return_values_2.sqlite"))

        # Verify database exists and has data
        assert db_path.exists(), f"Database file not created at {db_path}"

        # Read the rows inside `closing` so the connection is released even if
        # a later assertion raises (previously close() ran only on success).
        with closing(sqlite3.connect(db_path)) as con:
            cur = con.cursor()
            cur.execute("SELECT * FROM test_results")
            rows = cur.fetchall()

        assert len(rows) == 1, f"Expected 1 database row, got {len(rows)}"

        (
            test_module,
            test_class,
            test_function,
            function_name,
            loop_index,
            iteration_id,
            runtime,
            return_value_blob,
            verification_type,
        ) = rows[0]

        assert test_module == "test_module", f"Expected test_module 'test_module', got '{test_module}'"
        assert test_class is None, f"Expected test_class None, got '{test_class}'"
        assert test_function == "test_async_multiply_function", (
            f"Expected test_function 'test_async_multiply_function', got '{test_function}'"
        )
        assert function_name == "test_async_multiply", (
            f"Expected function_name 'test_async_multiply', got '{function_name}'"
        )
        assert loop_index == 1, f"Expected loop_index 1, got {loop_index}"
        assert iteration_id == "0_0", f"Expected iteration_id '0_0', got '{iteration_id}'"
        assert verification_type == "function_call", (
            f"Expected verification_type 'function_call', got '{verification_type}'"
        )

        # The blob was written by the decorator in this same process, so
        # unpickling it here does not touch untrusted data.
        args, kwargs, actual_return_value = pickle.loads(return_value_blob)

        assert args == (6, 7), f"Expected args (6, 7), got {args}"
        assert kwargs == {}, f"Expected empty kwargs, got {kwargs}"
        assert actual_return_value == 42, f"Expected stored return value 42, got {actual_return_value}"
    finally:
        for k, v in original_env.items():
            if v is not None:
                os.environ[k] = v
            else:
                os.environ.pop(k, None)
def test_async_decorator_comprehensive_return_values_and_test_ids():
    """Check that several decorated calls each produce a correctly tagged database row.

    Three calls (positional-only, with a keyword argument, positional-only)
    are made to a decorated coroutine while the ``CODEFLASH_*`` environment
    variables hold known values. The test then reads
    ``test_return_values_3.sqlite`` (resolved through ``get_run_tmp_file``)
    and checks that there are exactly three rows, in call order, with
    iteration ids ``3_0``, ``3_1`` and ``3_2`` and pickled
    ``(args, kwargs, return_value)`` payloads matching each call.

    The environment is restored in a ``finally`` block, and the database
    connection is closed even when an assertion fails.
    """
    import asyncio
    import os
    import sqlite3
    from contextlib import closing
    from pathlib import Path

    import dill as pickle

    from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async, get_run_tmp_file

    @codeflash_behavior_async
    async def async_multiply_add(x: int, y: int, z: int = 1) -> int:
        """Async function that multiplies x*y then adds z."""
        await asyncio.sleep(0.001)
        result = (x * y) + z
        return result

    test_env = {
        "CODEFLASH_TEST_MODULE": "test_comprehensive_module",
        "CODEFLASH_TEST_CLASS": "AsyncTestClass",
        "CODEFLASH_TEST_FUNCTION": "test_comprehensive_async_function",
        "CODEFLASH_CURRENT_LINE_ID": "3",
        "CODEFLASH_LOOP_INDEX": "2",
        "CODEFLASH_TEST_ITERATION": "3",
    }

    # Remember the previous values so the finally block can put them back.
    original_env = {k: os.environ.get(k) for k in test_env}
    for k, v in test_env.items():
        if v is not None:
            os.environ[k] = v
        else:
            os.environ.pop(k, None)

    try:
        test_cases = [
            {"args": (5, 3), "kwargs": {}, "expected": 16},  # (5 * 3) + 1 = 16
            {"args": (2, 4), "kwargs": {"z": 10}, "expected": 18},  # (2 * 4) + 10 = 18
            {"args": (7, 6), "kwargs": {}, "expected": 43},  # (7 * 6) + 1 = 43
        ]

        for test_case in test_cases:
            result = asyncio.run(async_multiply_add(*test_case["args"], **test_case["kwargs"]))

            # Verify each return value is exactly correct
            assert result == test_case["expected"], (
                f"Expected {test_case['expected']}, got {result} for args {test_case['args']}, kwargs {test_case['kwargs']}"
            )

        # NOTE(review): the "_3" suffix presumably mirrors CODEFLASH_TEST_ITERATION — confirm.
        db_path = get_run_tmp_file(Path("test_return_values_3.sqlite"))
        assert db_path.exists(), f"Database not created at {db_path}"

        # Fetch inside `closing` so the connection is released even if a later
        # assertion raises (previously close() ran only on success).
        with closing(sqlite3.connect(db_path)) as con:
            cur = con.cursor()
            cur.execute(
                "SELECT test_module_path, test_class_name, test_function_name, function_getting_tested, loop_index, iteration_id, runtime, return_value, verification_type FROM test_results ORDER BY rowid"
            )
            rows = cur.fetchall()

        assert len(rows) == 3, f"Expected 3 database rows, got {len(rows)}"

        # Rows are ordered by rowid, so row i is compared against test_cases[i].
        for i, (
            test_module,
            test_class,
            test_function,
            function_name,
            loop_index,
            iteration_id,
            runtime,
            return_value_blob,
            verification_type,
        ) in enumerate(rows):
            assert test_module == "test_comprehensive_module", (
                f"Row {i}: Expected test_module 'test_comprehensive_module', got '{test_module}'"
            )
            assert test_class == "AsyncTestClass", f"Row {i}: Expected test_class 'AsyncTestClass', got '{test_class}'"
            assert test_function == "test_comprehensive_async_function", (
                f"Row {i}: Expected test_function 'test_comprehensive_async_function', got '{test_function}'"
            )
            assert function_name == "async_multiply_add", (
                f"Row {i}: Expected function_name 'async_multiply_add', got '{function_name}'"
            )
            assert loop_index == 2, f"Row {i}: Expected loop_index 2, got {loop_index}"
            assert verification_type == "function_call", (
                f"Row {i}: Expected verification_type 'function_call', got '{verification_type}'"
            )

            expected_iteration_id = f"3_{i}"
            assert iteration_id == expected_iteration_id, (
                f"Row {i}: Expected iteration_id '{expected_iteration_id}', got '{iteration_id}'"
            )

            # The blob was written by the decorator in this same process, so
            # unpickling it here does not touch untrusted data.
            args, kwargs, actual_return_value = pickle.loads(return_value_blob)
            expected_args = test_cases[i]["args"]
            expected_kwargs = test_cases[i]["kwargs"]
            expected_return = test_cases[i]["expected"]

            assert args == expected_args, f"Row {i}: Expected args {expected_args}, got {args}"
            assert kwargs == expected_kwargs, f"Row {i}: Expected kwargs {expected_kwargs}, got {kwargs}"
            assert actual_return_value == expected_return, (
                f"Row {i}: Expected return value {expected_return}, got {actual_return_value}"
            )
    finally:
        for k, v in original_env.items():
            if v is not None:
                os.environ[k] = v
            else:
                os.environ.pop(k, None)