mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
first pass
This commit is contained in:
parent
d34cbbf793
commit
fa705e1b5e
10 changed files with 2530 additions and 3 deletions
43
code_to_optimize/async_bubble_sort.py
Normal file
43
code_to_optimize/async_bubble_sort.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
import asyncio
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
async def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:
    """Bubble-sort ``lst`` in place and return a sorted copy (async test fixture).

    Prints a marker line, awaits a 10 ms sleep, sorts ``lst`` ascending with
    adjacent swaps, then prints and returns a shallow copy of the sorted list.
    The caller's list is mutated as well.
    """
    print("codeflash stdout: Async sorting list")

    await asyncio.sleep(0.01)

    # Each pass bubbles the largest remaining item up to ``unsorted_end``,
    # so the compared prefix shrinks by one element per pass.
    for unsorted_end in range(len(lst) - 1, -1, -1):
        for left in range(unsorted_end):
            right = left + 1
            if lst[left] > lst[right]:
                lst[left], lst[right] = lst[right], lst[left]

    result = lst.copy()
    print(f"result: {result}")
    return result
|
||||
|
||||
|
||||
class AsyncBubbleSorter:
    """Container for an async bubble-sort method used by the tests."""

    async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:
        """Bubble-sort ``lst`` in place and return a shallow copy of it.

        Prints a marker line and awaits a 5 ms sleep before sorting.
        """
        print("codeflash stdout: AsyncBubbleSorter.sorter() called")

        # Brief asynchronous pause before the actual work.
        await asyncio.sleep(0.005)

        size = len(lst)
        passes_done = 0
        while passes_done < size:
            # Items past ``size - passes_done - 1`` are already in final position.
            for idx in range(size - passes_done - 1):
                if lst[idx] > lst[idx + 1]:
                    lst[idx], lst[idx + 1] = lst[idx + 1], lst[idx]
            passes_done += 1

        return lst.copy()
|
||||
16
code_to_optimize/code_directories/async_e2e/main.py
Normal file
16
code_to_optimize/code_directories/async_e2e/main.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
import time
|
||||
import asyncio
|
||||
|
||||
|
||||
async def retry_with_backoff(func, max_retries=3):
    """Await ``func()`` up to ``max_retries`` times and return the first result.

    Args:
        func: Zero-argument callable whose return value is awaited.
        max_retries: Maximum number of attempts; must be at least 1.

    Returns:
        The value produced by the first ``await func()`` that does not raise.

    Raises:
        ValueError: If ``max_retries`` is below 1.
        Exception: The exception raised by the final attempt when all fail.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    failure = None
    for attempt in range(max_retries):
        try:
            return await func()
        except Exception as exc:
            failure = exc
        # Only reached after a failed attempt; pause before the next try,
        # but not after the last one.
        # NOTE(review): time.sleep blocks the event loop. Presumably deliberate,
        # since this file lives under code_to_optimize -- confirm before
        # replacing it with asyncio.sleep.
        if attempt + 1 < max_retries:
            time.sleep(0.0001 * attempt)

    raise failure
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
[tool.codeflash]
|
||||
disable-telemetry = true
|
||||
formatter-cmds = ["ruff check --exit-zero --fix $file", "ruff format $file"]
|
||||
module-root = "."
|
||||
test-framework = "pytest"
|
||||
tests-root = "tests"
|
||||
|
|
@ -86,6 +86,7 @@ class FunctionVisitor(cst.CSTVisitor):
|
|||
parents=list(reversed(ast_parents)),
|
||||
starting_line=pos.start.line,
|
||||
ending_line=pos.end.line,
|
||||
is_async=bool(node.asynchronous),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -103,6 +104,15 @@ class FunctionWithReturnStatement(ast.NodeVisitor):
|
|||
FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
|
||||
)
|
||||
|
||||
def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
    """Record an ``async def`` node as a candidate function.

    The node is appended to ``self.functions`` only when
    ``function_has_return_statement`` accepts it and
    ``function_is_a_property`` rejects it. The stored entry copies the current
    ``self.ast_path`` as its parents and is flagged ``is_async=True``.
    """
    if not function_has_return_statement(node):
        return
    if function_is_a_property(node):
        return

    candidate = FunctionToOptimize(
        function_name=node.name,
        file_path=self.file_path,
        parents=self.ast_path[:],  # copy so later path changes don't leak in
        is_async=True,
    )
    self.functions.append(candidate)
|
||||
|
||||
def generic_visit(self, node: ast.AST) -> None:
|
||||
if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):
|
||||
self.ast_path.append(FunctionParent(node.name, node.__class__.__name__))
|
||||
|
|
@ -122,6 +132,7 @@ class FunctionToOptimize:
|
|||
parents: A list of parent scopes, which could be classes or functions.
|
||||
starting_line: The starting line number of the function in the file.
|
||||
ending_line: The ending line number of the function in the file.
|
||||
is_async: Whether this function is defined as async.
|
||||
|
||||
The qualified_name property provides the full name of the function, including
|
||||
any parent class or function names. The qualified_name_with_modules_from_root
|
||||
|
|
@ -134,6 +145,7 @@ class FunctionToOptimize:
|
|||
parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef]
|
||||
starting_line: Optional[int] = None
|
||||
ending_line: Optional[int] = None
|
||||
is_async: bool = False
|
||||
|
||||
@property
|
||||
def top_level_parent_name(self) -> str:
|
||||
|
|
@ -147,7 +159,11 @@ class FunctionToOptimize:
|
|||
|
||||
@property
|
||||
def qualified_name(self) -> str:
|
||||
return self.function_name if self.parents == [] else f"{self.parents[0].name}.{self.function_name}"
|
||||
if not self.parents:
|
||||
return self.function_name
|
||||
# Join all parent names with dots to handle nested classes properly
|
||||
parent_path = ".".join(parent.name for parent in self.parents)
|
||||
return f"{parent_path}.{self.function_name}"
|
||||
|
||||
def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
|
||||
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
|
||||
|
|
@ -411,11 +427,27 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
|
|||
)
|
||||
)
|
||||
|
||||
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
    """Flag a matching ``async def`` as top level and note whether it takes arguments.

    Does nothing when a class name is being searched for or when the node's
    name differs from ``self.function_name``. Otherwise sets
    ``self.is_top_level`` and stores in ``self.function_has_args`` whether the
    signature declares any parameter of any kind.
    """
    if self.class_name is not None or node.name != self.function_name:
        return

    self.is_top_level = True
    signature = node.args
    self.function_has_args = bool(
        signature.args
        or signature.kwonlyargs
        or signature.kwarg
        or signature.posonlyargs
        or signature.vararg
    )
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> None:
|
||||
# iterate over the class methods
|
||||
if node.name == self.class_name:
|
||||
for body_node in node.body:
|
||||
if isinstance(body_node, ast.FunctionDef) and body_node.name == self.function_name:
|
||||
if (
|
||||
isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
|
||||
and body_node.name == self.function_name
|
||||
):
|
||||
self.is_top_level = True
|
||||
if any(
|
||||
isinstance(decorator, ast.Name) and decorator.id == "classmethod"
|
||||
|
|
@ -433,7 +465,7 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
|
|||
# This way, if we don't have the class name, we can still find the static method
|
||||
for body_node in node.body:
|
||||
if (
|
||||
isinstance(body_node, ast.FunctionDef)
|
||||
isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
|
||||
and body_node.name == self.function_name
|
||||
and body_node.lineno in {self.line_no, self.line_no + 1}
|
||||
and any(
|
||||
|
|
|
|||
27
tests/scripts/end_to_end_test_async.py
Normal file
27
tests/scripts/end_to_end_test_async.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
import os
|
||||
import pathlib
|
||||
|
||||
from end_to_end_test_utilities import CoverageExpectation, TestConfig, run_codeflash_command, run_with_retries
|
||||
|
||||
|
||||
def run_test(expected_improvement_pct: int) -> bool:
    """Run codeflash against the ``async_e2e`` sample project.

    Builds a ``TestConfig`` for ``main.py`` with a coverage expectation on
    ``retry_with_backoff`` and passes it, the resolved project directory and
    ``expected_improvement_pct`` to ``run_codeflash_command``, whose result is
    returned.
    """
    retry_coverage = CoverageExpectation(
        function_name="retry_with_backoff",
        expected_coverage=100.0,
        expected_lines=list(range(10, 21)),
    )
    config = TestConfig(
        file_path="main.py",
        expected_unit_tests=0,
        min_improvement_x=0.1,
        coverage_expectations=[retry_coverage],
    )

    # This script sits two directories below the repository root.
    repo_root = pathlib.Path(__file__).parent.parent.parent
    cwd = (repo_root / "code_to_optimize" / "code_directories" / "async_e2e").resolve()
    return run_codeflash_command(cwd, config, expected_improvement_pct)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # ``exit`` is injected by the ``site`` module for interactive sessions and
    # is absent under ``python -S``; ``sys.exit`` is always available and
    # raises the same SystemExit with the given status.
    import sys

    # EXPECTED_IMPROVEMENT_PCT overrides the default threshold of 10.
    sys.exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10))))
|
||||
286
tests/test_async_function_discovery.py
Normal file
286
tests/test_async_function_discovery.py
Normal file
|
|
@ -0,0 +1,286 @@
|
|||
import tempfile
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import (
|
||||
find_all_functions_in_file,
|
||||
get_functions_to_optimize,
|
||||
inspect_top_level_functions_or_methods,
|
||||
)
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
@pytest.fixture
def temp_dir():
    """Yield a fresh temporary directory as a ``Path``; it is removed on teardown."""
    scratch = tempfile.TemporaryDirectory()
    try:
        yield Path(scratch.name)
    finally:
        scratch.cleanup()
|
||||
|
||||
|
||||
def test_async_function_detection(temp_dir):
|
||||
async_function = """
|
||||
async def async_function_with_return():
|
||||
await some_async_operation()
|
||||
return 42
|
||||
|
||||
async def async_function_without_return():
|
||||
await some_async_operation()
|
||||
print("No return")
|
||||
|
||||
def regular_function():
|
||||
return 10
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(async_function)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
function_names = [fn.function_name for fn in functions_found[file_path]]
|
||||
|
||||
assert "async_function_with_return" in function_names
|
||||
assert "regular_function" in function_names
|
||||
assert "async_function_without_return" not in function_names
|
||||
|
||||
|
||||
def test_async_method_in_class(temp_dir):
|
||||
code_with_async_method = """
|
||||
class AsyncClass:
|
||||
async def async_method(self):
|
||||
await self.do_something()
|
||||
return "result"
|
||||
|
||||
async def async_method_no_return(self):
|
||||
await self.do_something()
|
||||
pass
|
||||
|
||||
def sync_method(self):
|
||||
return "sync result"
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(code_with_async_method)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
found_functions = functions_found[file_path]
|
||||
function_names = [fn.function_name for fn in found_functions]
|
||||
qualified_names = [fn.qualified_name for fn in found_functions]
|
||||
|
||||
assert "async_method" in function_names
|
||||
assert "AsyncClass.async_method" in qualified_names
|
||||
|
||||
assert "sync_method" in function_names
|
||||
assert "AsyncClass.sync_method" in qualified_names
|
||||
|
||||
assert "async_method_no_return" not in function_names
|
||||
|
||||
|
||||
def test_nested_async_functions(temp_dir):
|
||||
nested_async = """
|
||||
async def outer_async():
|
||||
async def inner_async():
|
||||
return "inner"
|
||||
|
||||
result = await inner_async()
|
||||
return result
|
||||
|
||||
def outer_sync():
|
||||
async def inner_async():
|
||||
return "inner from sync"
|
||||
|
||||
return inner_async
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(nested_async)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
function_names = [fn.function_name for fn in functions_found[file_path]]
|
||||
|
||||
assert "outer_async" in function_names
|
||||
assert "outer_sync" in function_names
|
||||
assert "inner_async" not in function_names
|
||||
|
||||
|
||||
def test_async_staticmethod_and_classmethod(temp_dir):
|
||||
async_decorators = """
|
||||
class MyClass:
|
||||
@staticmethod
|
||||
async def async_static_method():
|
||||
await some_operation()
|
||||
return "static result"
|
||||
|
||||
@classmethod
|
||||
async def async_class_method(cls):
|
||||
await cls.some_operation()
|
||||
return "class result"
|
||||
|
||||
@property
|
||||
async def async_property(self):
|
||||
return await self.get_value()
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(async_decorators)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
function_names = [fn.function_name for fn in functions_found[file_path]]
|
||||
|
||||
assert "async_static_method" in function_names
|
||||
assert "async_class_method" in function_names
|
||||
|
||||
assert "async_property" not in function_names
|
||||
|
||||
|
||||
def test_async_generator_functions(temp_dir):
|
||||
async_generators = """
|
||||
async def async_generator_with_return():
|
||||
for i in range(10):
|
||||
yield i
|
||||
return "done"
|
||||
|
||||
async def async_generator_no_return():
|
||||
for i in range(10):
|
||||
yield i
|
||||
|
||||
async def regular_async_with_return():
|
||||
result = await compute()
|
||||
return result
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(async_generators)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
function_names = [fn.function_name for fn in functions_found[file_path]]
|
||||
|
||||
assert "async_generator_with_return" in function_names
|
||||
assert "regular_async_with_return" in function_names
|
||||
assert "async_generator_no_return" not in function_names
|
||||
|
||||
|
||||
def test_inspect_async_top_level_functions(temp_dir):
|
||||
code = """
|
||||
async def top_level_async():
|
||||
return 42
|
||||
|
||||
class AsyncContainer:
|
||||
async def async_method(self):
|
||||
async def nested_async():
|
||||
return 1
|
||||
return await nested_async()
|
||||
|
||||
@staticmethod
|
||||
async def async_static():
|
||||
return "static"
|
||||
|
||||
@classmethod
|
||||
async def async_classmethod(cls):
|
||||
return "classmethod"
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(code)
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "top_level_async")
|
||||
assert result.is_top_level
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "async_method", class_name="AsyncContainer")
|
||||
assert result.is_top_level
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "nested_async", class_name="AsyncContainer")
|
||||
assert not result.is_top_level
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "async_static", class_name="AsyncContainer")
|
||||
assert result.is_top_level
|
||||
assert result.is_staticmethod
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "async_classmethod", class_name="AsyncContainer")
|
||||
assert result.is_top_level
|
||||
assert result.is_classmethod
|
||||
|
||||
|
||||
def test_get_functions_to_optimize_with_async(temp_dir):
|
||||
mixed_code = """
|
||||
async def async_func_one():
|
||||
return await operation_one()
|
||||
|
||||
def sync_func_one():
|
||||
return operation_one()
|
||||
|
||||
async def async_func_two():
|
||||
print("no return")
|
||||
|
||||
class MixedClass:
|
||||
async def async_method(self):
|
||||
return await self.operation()
|
||||
|
||||
def sync_method(self):
|
||||
return self.operation()
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(mixed_code)
|
||||
|
||||
test_config = TestConfig(
|
||||
tests_root="tests",
|
||||
project_root_path=".",
|
||||
test_framework="pytest",
|
||||
tests_project_rootdir=Path()
|
||||
)
|
||||
|
||||
functions, functions_count, _ = get_functions_to_optimize(
|
||||
optimize_all=None,
|
||||
replay_test=None,
|
||||
file=file_path,
|
||||
only_get_this_function=None,
|
||||
test_cfg=test_config,
|
||||
ignore_paths=[],
|
||||
project_root=file_path.parent,
|
||||
module_root=file_path.parent,
|
||||
)
|
||||
|
||||
assert functions_count == 4
|
||||
|
||||
function_names = [fn.function_name for fn in functions[file_path]]
|
||||
assert "async_func_one" in function_names
|
||||
assert "sync_func_one" in function_names
|
||||
assert "async_method" in function_names
|
||||
assert "sync_method" in function_names
|
||||
|
||||
assert "async_func_two" not in function_names
|
||||
|
||||
|
||||
def test_async_function_parents(temp_dir):
|
||||
complex_structure = """
|
||||
class OuterClass:
|
||||
async def outer_method(self):
|
||||
return 1
|
||||
|
||||
class InnerClass:
|
||||
async def inner_method(self):
|
||||
return 2
|
||||
|
||||
async def module_level_async():
|
||||
class LocalClass:
|
||||
async def local_method(self):
|
||||
return 3
|
||||
return LocalClass()
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(complex_structure)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
found_functions = functions_found[file_path]
|
||||
|
||||
for fn in found_functions:
|
||||
if fn.function_name == "outer_method":
|
||||
assert len(fn.parents) == 1
|
||||
assert fn.parents[0].name == "OuterClass"
|
||||
assert fn.qualified_name == "OuterClass.outer_method"
|
||||
elif fn.function_name == "inner_method":
|
||||
assert len(fn.parents) == 2
|
||||
assert fn.parents[0].name == "OuterClass"
|
||||
assert fn.parents[1].name == "InnerClass"
|
||||
elif fn.function_name == "module_level_async":
|
||||
assert len(fn.parents) == 0
|
||||
assert fn.qualified_name == "module_level_async"
|
||||
1039
tests/test_async_run_and_parse_tests.py
Normal file
1039
tests/test_async_run_and_parse_tests.py
Normal file
File diff suppressed because it is too large
Load diff
285
tests/test_async_wrapper_sqlite_validation.py
Normal file
285
tests/test_async_wrapper_sqlite_validation.py
Normal file
|
|
@ -0,0 +1,285 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import dill as pickle
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import (
|
||||
codeflash_behavior_async,
|
||||
codeflash_performance_async,
|
||||
)
|
||||
from codeflash.verification.codeflash_capture import VerificationType
|
||||
|
||||
|
||||
class TestAsyncWrapperSQLiteValidation:
    """Check what the async codeflash decorators record in SQLite and on stdout.

    Every test runs with the ``CODEFLASH_*`` environment variables set by
    ``test_env_setup`` and reads results from the file named by ``temp_db_path``.
    """

    @staticmethod
    def _fetch_rows(db_path, query):
        """Run ``query`` against ``db_path`` and return all rows.

        The connection is closed even when the query raises, so a failing test
        never leaves an open handle on the file the fixture later deletes.
        """
        con = sqlite3.connect(db_path)
        try:
            return con.execute(query).fetchall()
        finally:
            con.close()

    @pytest.fixture
    def test_env_setup(self, request):
        """Set the ``CODEFLASH_*`` variables for one test, then restore the previous values."""
        test_env = {
            "CODEFLASH_LOOP_INDEX": "1",
            "CODEFLASH_TEST_ITERATION": "0",
            "CODEFLASH_TEST_MODULE": __name__,
            "CODEFLASH_TEST_CLASS": "TestAsyncWrapperSQLiteValidation",
            "CODEFLASH_TEST_FUNCTION": request.node.name,
            "CODEFLASH_CURRENT_LINE_ID": "test_unit",
        }

        # Remember prior values (None means "was unset") before overriding.
        original_env = {key: os.environ.get(key) for key in test_env}
        os.environ.update(test_env)

        yield test_env

        for key, original_value in original_env.items():
            if original_value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = original_value

    @pytest.fixture
    def temp_db_path(self, test_env_setup):
        """Yield the SQLite path for the current iteration and delete the file afterwards."""
        from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file

        iteration = test_env_setup["CODEFLASH_TEST_ITERATION"]
        db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))

        yield db_path

        if db_path.exists():
            db_path.unlink()

    @pytest.mark.asyncio
    async def test_behavior_async_basic_function(self, test_env_setup, temp_db_path):
        """A single decorated call produces exactly one fully populated ``test_results`` row."""

        @codeflash_behavior_async
        async def simple_async_add(a: int, b: int) -> int:
            await asyncio.sleep(0.001)
            return a + b

        os.environ["CODEFLASH_CURRENT_LINE_ID"] = "simple_async_add_59"
        result = await simple_async_add(5, 3)

        assert result == 8
        assert temp_db_path.exists()

        tables = self._fetch_rows(
            temp_db_path, "SELECT name FROM sqlite_master WHERE type='table' AND name='test_results'"
        )
        assert tables

        rows = self._fetch_rows(temp_db_path, "SELECT * FROM test_results")
        assert len(rows) == 1

        (
            test_module_path,
            test_class_name,
            test_function_name,
            function_getting_tested,
            loop_index,
            iteration_id,
            runtime,
            return_value_blob,
            verification_type,
        ) = rows[0]

        assert test_module_path == __name__
        assert test_class_name == "TestAsyncWrapperSQLiteValidation"
        assert test_function_name == "test_behavior_async_basic_function"
        assert function_getting_tested == "simple_async_add"
        assert loop_index == 1
        # Line ID will be the actual line number from the source code, not a simple counter
        assert iteration_id.startswith("simple_async_add_") and iteration_id.endswith("_0")
        assert runtime > 0
        assert verification_type == VerificationType.FUNCTION_CALL.value

        args, kwargs, return_val = pickle.loads(return_value_blob)
        assert args == (5, 3)
        assert kwargs == {}
        assert return_val == 8

    @pytest.mark.asyncio
    async def test_behavior_async_exception_handling(self, test_env_setup, temp_db_path):
        """A successful call and a raising call are both recorded; the exception is re-raised."""

        @codeflash_behavior_async
        async def async_divide(a: int, b: int) -> float:
            await asyncio.sleep(0.001)
            if b == 0:
                raise ValueError("Cannot divide by zero")
            return a / b

        result = await async_divide(10, 2)
        assert result == 5.0

        with pytest.raises(ValueError, match="Cannot divide by zero"):
            await async_divide(10, 0)

        rows = self._fetch_rows(temp_db_path, "SELECT * FROM test_results ORDER BY iteration_id")
        assert len(rows) == 2

        # Column 7 is return_value (the pickled blob).
        success_row, exception_row = rows
        args, _kwargs, return_val = pickle.loads(success_row[7])
        assert args == (10, 2)
        assert return_val == 5.0

        # Check exception record
        exception_data = pickle.loads(exception_row[7])
        assert isinstance(exception_data, ValueError)
        assert str(exception_data) == "Cannot divide by zero"

    @pytest.mark.asyncio
    async def test_performance_async_no_database_storage(self, test_env_setup, temp_db_path, capsys):
        """Test performance async decorator doesn't store to database."""

        @codeflash_performance_async
        async def async_multiply(a: int, b: int) -> int:
            """Async function for performance testing."""
            await asyncio.sleep(0.002)
            return a * b

        result = await async_multiply(4, 7)

        assert result == 28
        assert not temp_db_path.exists()

        captured = capsys.readouterr()
        output_lines = captured.out.strip().split("\n")

        dollar_tag_lines = [line for line in output_lines if "!$######" in line]
        closing_tags = [line for line in output_lines if "!######" in line and "######!" in line]
        assert len(dollar_tag_lines) == 1
        assert len(closing_tags) == 1

        closing_tag = closing_tags[0]
        assert "async_multiply" in closing_tag

        # The timing is the last colon-separated field, followed by "######!".
        timing_part = closing_tag.split(":")[-1].replace("######!", "")
        timing_value = int(timing_part)
        assert timing_value > 0  # Should have positive timing

    @pytest.mark.asyncio
    async def test_multiple_calls_indexing(self, test_env_setup, temp_db_path):
        """Repeated calls get iteration ids sharing one prefix with suffixes 0, 1, 2."""

        @codeflash_behavior_async
        async def async_increment(value: int) -> int:
            await asyncio.sleep(0.001)
            return value + 1

        # Call the function multiple times
        results = []
        for i in range(3):
            results.append(await async_increment(i))

        assert results == [1, 2, 3]

        rows = self._fetch_rows(
            temp_db_path, "SELECT iteration_id, return_value FROM test_results ORDER BY iteration_id"
        )
        assert len(rows) == 3

        actual_ids = [row[0] for row in rows]
        base_pattern = actual_ids[0].rsplit("_", 1)[0]  # e.g., "async_increment_199"
        expected_pattern = [f"{base_pattern}_{i}" for i in range(3)]
        assert actual_ids == expected_pattern

        for i, (_, return_value_blob) in enumerate(rows):
            args, _kwargs, return_val = pickle.loads(return_value_blob)
            assert args == (i,)
            assert return_val == i + 1

    @pytest.mark.asyncio
    async def test_complex_async_function_with_kwargs(self, test_env_setup, temp_db_path):
        """Positional, variadic and keyword arguments are stored as the ``(args, kwargs)`` actually passed."""

        @codeflash_behavior_async
        async def complex_async_func(
            pos_arg: str,
            *args: int,
            keyword_arg: str = "default",
            **kwargs: str,
        ) -> dict:
            await asyncio.sleep(0.001)
            return {
                "pos_arg": pos_arg,
                "args": args,
                "keyword_arg": keyword_arg,
                "kwargs": kwargs,
            }

        result = await complex_async_func(
            "hello",
            1,
            2,
            3,
            keyword_arg="custom",
            extra1="value1",
            extra2="value2",
        )

        expected_result = {
            "pos_arg": "hello",
            "args": (1, 2, 3),
            "keyword_arg": "custom",
            "kwargs": {"extra1": "value1", "extra2": "value2"},
        }
        assert result == expected_result

        rows = self._fetch_rows(temp_db_path, "SELECT return_value FROM test_results")
        assert rows

        stored_args, stored_kwargs, stored_result = pickle.loads(rows[0][0])
        assert stored_args == ("hello", 1, 2, 3)
        assert stored_kwargs == {"keyword_arg": "custom", "extra1": "value1", "extra2": "value2"}
        assert stored_result == expected_result

    @pytest.mark.asyncio
    async def test_database_schema_validation(self, test_env_setup, temp_db_path):
        """``test_results`` has exactly the nine expected columns, in order, with the expected types."""

        @codeflash_behavior_async
        async def schema_test_func() -> str:
            return "schema_test"

        await schema_test_func()

        columns = self._fetch_rows(temp_db_path, "PRAGMA table_info(test_results)")

        # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
        expected_columns = [
            (0, "test_module_path", "TEXT", 0, None, 0),
            (1, "test_class_name", "TEXT", 0, None, 0),
            (2, "test_function_name", "TEXT", 0, None, 0),
            (3, "function_getting_tested", "TEXT", 0, None, 0),
            (4, "loop_index", "INTEGER", 0, None, 0),
            (5, "iteration_id", "TEXT", 0, None, 0),
            (6, "runtime", "INTEGER", 0, None, 0),
            (7, "return_value", "BLOB", 0, None, 0),
            (8, "verification_type", "TEXT", 0, None, 0),
        ]

        assert columns == expected_columns
|
||||
|
||||
793
tests/test_instrument_async_tests.py
Normal file
793
tests/test_instrument_async_tests.py
Normal file
|
|
@ -0,0 +1,793 @@
|
|||
import tempfile
|
||||
from pathlib import Path
|
||||
import uuid
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.code_utils.instrument_existing_tests import (
|
||||
add_async_decorator_to_function,
|
||||
inject_profiling_into_existing_test,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodePosition, TestingMode
|
||||
|
||||
|
||||
@pytest.fixture
def temp_dir():
    """Provide a temporary directory for test files, removed again on teardown."""
    holder = tempfile.TemporaryDirectory()
    try:
        yield Path(holder.name)
    finally:
        holder.cleanup()
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def unique_test_iteration():
|
||||
# """Provide a unique test iteration ID and clean up database after test."""
|
||||
# # Generate unique iteration ID
|
||||
# iteration_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# # Store original environment variable
|
||||
# original_iteration = os.environ.get("CODEFLASH_TEST_ITERATION")
|
||||
|
||||
# # Set unique iteration for this test
|
||||
# os.environ["CODEFLASH_TEST_ITERATION"] = iteration_id
|
||||
|
||||
# try:
|
||||
# yield iteration_id
|
||||
# finally:
|
||||
# # Cleanup: restore original environment and delete database file
|
||||
# if original_iteration is not None:
|
||||
# os.environ["CODEFLASH_TEST_ITERATION"] = original_iteration
|
||||
# elif "CODEFLASH_TEST_ITERATION" in os.environ:
|
||||
# del os.environ["CODEFLASH_TEST_ITERATION"]
|
||||
|
||||
# # Clean up database file
|
||||
# try:
|
||||
# from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file
|
||||
|
||||
# db_path = get_run_tmp_file(Path(f"test_return_values_{iteration_id}.sqlite"))
|
||||
# if db_path.exists():
|
||||
# db_path.unlink()
|
||||
# except Exception:
|
||||
# pass # Ignore cleanup errors
|
||||
|
||||
|
||||
def test_async_decorator_application_behavior_mode():
|
||||
async_function_code = '''
|
||||
import asyncio
|
||||
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
expected_decorated_code = '''
|
||||
import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_behavior_async
|
||||
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
|
||||
)
|
||||
|
||||
modified_code, decorator_added = add_async_decorator_to_function(async_function_code, func, TestingMode.BEHAVIOR)
|
||||
|
||||
assert decorator_added
|
||||
assert modified_code.strip() == expected_decorated_code.strip()
|
||||
|
||||
|
||||
def test_async_decorator_application_performance_mode():
|
||||
async_function_code = '''
|
||||
import asyncio
|
||||
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
expected_decorated_code = '''
|
||||
import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_performance_async
|
||||
|
||||
|
||||
@codeflash_performance_async
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
|
||||
)
|
||||
|
||||
modified_code, decorator_added = add_async_decorator_to_function(async_function_code, func, TestingMode.PERFORMANCE)
|
||||
|
||||
assert decorator_added
|
||||
assert modified_code.strip() == expected_decorated_code.strip()
|
||||
|
||||
|
||||
def test_async_class_method_decorator_application():
|
||||
async_class_code = '''
|
||||
import asyncio
|
||||
|
||||
class Calculator:
|
||||
"""Test class with async methods."""
|
||||
|
||||
async def async_method(self, a: int, b: int) -> int:
|
||||
"""Async method in class."""
|
||||
await asyncio.sleep(0.005)
|
||||
return a ** b
|
||||
|
||||
def sync_method(self, a: int, b: int) -> int:
|
||||
"""Sync method in class."""
|
||||
return a - b
|
||||
'''
|
||||
|
||||
expected_decorated_code = '''
|
||||
import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_behavior_async
|
||||
|
||||
|
||||
class Calculator:
|
||||
"""Test class with async methods."""
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_method(self, a: int, b: int) -> int:
|
||||
"""Async method in class."""
|
||||
await asyncio.sleep(0.005)
|
||||
return a ** b
|
||||
|
||||
def sync_method(self, a: int, b: int) -> int:
|
||||
"""Sync method in class."""
|
||||
return a - b
|
||||
'''
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="async_method",
|
||||
file_path=Path("test_async.py"),
|
||||
parents=[{"name": "Calculator", "type": "ClassDef"}],
|
||||
is_async=True,
|
||||
)
|
||||
|
||||
modified_code, decorator_added = add_async_decorator_to_function(async_class_code, func, TestingMode.BEHAVIOR)
|
||||
|
||||
assert decorator_added
|
||||
assert modified_code.strip() == expected_decorated_code.strip()
|
||||
|
||||
|
||||
def test_async_decorator_no_duplicate_application():
    """An already-decorated async function must not receive a second decorator.

    The helper is expected to report ``decorator_added == False`` while still
    returning the source with its imports reformatted.
    """
    predecorated_source = '''
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async
import asyncio

@codeflash_behavior_async
async def async_function(x: int, y: int) -> int:
    """Already decorated async function."""
    await asyncio.sleep(0.01)
    return x * y
'''

    wanted_source = '''
import asyncio

from codeflash.code_utils.codeflash_wrap_decorator import \\
    codeflash_behavior_async


@codeflash_behavior_async
async def async_function(x: int, y: int) -> int:
    """Already decorated async function."""
    await asyncio.sleep(0.01)
    return x * y
'''

    target = FunctionToOptimize(
        function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
    )

    rewritten, was_added = add_async_decorator_to_function(predecorated_source, target, TestingMode.BEHAVIOR)

    # No new decorator, but the returned text is still the reformatted module.
    assert not was_added
    assert rewritten.strip() == wanted_source.strip()
|
||||
|
||||
|
||||
def test_inject_profiling_async_function_behavior_mode(temp_dir):
    """BEHAVIOR mode: the source module gets the async decorator, the test file is not instrumented.

    Steps:
      1. Write a module containing one async function and a pytest file that awaits it.
      2. Decorate the module via ``instrument_source_module_with_async_decorators``.
      3. Call ``inject_profiling_into_existing_test`` and check it reports failure
         with no instrumented test code.
    """
    from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators

    module_text = '''
import asyncio

async def async_function(x: int, y: int) -> int:
    """Simple async function for testing."""
    await asyncio.sleep(0.01)
    return x * y
'''

    test_text = '''
import asyncio
import pytest
from my_module import async_function

@pytest.mark.asyncio
async def test_async_function():
    """Test async function behavior."""
    result = await async_function(5, 3)
    assert result == 15

    result2 = await async_function(2, 4)
    assert result2 == 8
'''

    source_file = temp_dir / "my_module.py"
    source_file.write_text(module_text)
    test_file = temp_dir / "test_async.py"
    test_file.write_text(test_text)

    func = FunctionToOptimize(function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True)

    # Step 2: decorate the source module and verify the decorator and its import are present.
    source_success, instrumented_source = instrument_source_module_with_async_decorators(
        source_file, func, TestingMode.BEHAVIOR
    )

    assert source_success is True
    assert instrumented_source is not None
    for fragment in (
        "@codeflash_behavior_async",
        "from codeflash.code_utils.codeflash_wrap_decorator import",
        "codeflash_behavior_async",
    ):
        assert fragment in instrumented_source

    source_file.write_text(instrumented_source)

    # Step 3: with the source decorated, test-level injection is expected to
    # report failure, because async instrumentation lives in the decorator.
    outcome = inject_profiling_into_existing_test(
        test_file, [CodePosition(8, 18), CodePosition(11, 19)], func, temp_dir, "pytest", mode=TestingMode.BEHAVIOR
    )
    success, instrumented_test_code = outcome

    assert success is False
    assert instrumented_test_code is None
|
||||
|
||||
|
||||
def test_inject_profiling_async_function_performance_mode(temp_dir):
    """PERFORMANCE mode: the source module gets ``codeflash_performance_async``, the test file is not instrumented.

    Same flow as the behavior-mode test, using the performance decorator and a
    single awaited call in the test file.
    """
    from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators

    module_text = '''
import asyncio

async def async_function(x: int, y: int) -> int:
    """Simple async function for testing."""
    await asyncio.sleep(0.01)
    return x * y
'''

    test_text = '''
import asyncio
import pytest
from my_module import async_function

@pytest.mark.asyncio
async def test_async_function():
    """Test async function performance."""
    result = await async_function(5, 3)
    assert result == 15
'''

    # Lay down the module under test and the pytest file that exercises it.
    source_file = temp_dir / "my_module.py"
    source_file.write_text(module_text)
    test_file = temp_dir / "test_async.py"
    test_file.write_text(test_text)

    func = FunctionToOptimize(function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True)

    # Decorate the source module first.
    source_success, instrumented_source = instrument_source_module_with_async_decorators(
        source_file, func, TestingMode.PERFORMANCE
    )

    assert source_success is True
    assert instrumented_source is not None
    # The import may be wrapped with a line continuation, so the prefix and
    # the decorator name are checked separately.
    for fragment in (
        "@codeflash_performance_async",
        "from codeflash.code_utils.codeflash_wrap_decorator import",
        "codeflash_performance_async",
    ):
        assert fragment in instrumented_source

    # Persist the decorated module before running the test-injection step.
    source_file.write_text(instrumented_source)

    outcome = inject_profiling_into_existing_test(
        test_file, [CodePosition(8, 18)], func, temp_dir, "pytest", mode=TestingMode.PERFORMANCE
    )
    success, instrumented_test_code = outcome

    # With the source decorated, test-level injection is expected to report
    # failure, because async instrumentation lives in the decorator.
    assert success is False
    assert instrumented_test_code is None
|
||||
|
||||
|
||||
def test_mixed_sync_async_instrumentation(temp_dir):
    """A module holding both sync and async functions: only the async target is decorated.

    Also checks that, once the source is decorated, test-level profiling
    injection for the async target reports failure and returns no code.
    """
    from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators

    module_text = '''
import asyncio

def sync_function(x: int, y: int) -> int:
    """Regular sync function."""
    return x * y

async def async_function(x: int, y: int) -> int:
    """Simple async function."""
    await asyncio.sleep(0.01)
    return x * y
'''

    test_text = '''
import asyncio
import pytest
from my_module import sync_function, async_function

@pytest.mark.asyncio
async def test_mixed_functions():
    """Test both sync and async functions."""
    sync_result = sync_function(10, 5)
    assert sync_result == 50

    async_result = await async_function(3, 4)
    assert async_result == 12
'''

    source_file = temp_dir / "my_module.py"
    source_file.write_text(module_text)
    test_file = temp_dir / "test_mixed.py"
    test_file.write_text(test_text)

    async_func = FunctionToOptimize(
        function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True
    )

    source_success, instrumented_source = instrument_source_module_with_async_decorators(
        source_file, async_func, TestingMode.BEHAVIOR
    )

    assert source_success
    assert instrumented_source is not None
    for fragment in (
        "@codeflash_behavior_async",
        "from codeflash.code_utils.codeflash_wrap_decorator import",
        "codeflash_behavior_async",
        # The sync function's signature must still be present verbatim.
        "def sync_function(x: int, y: int) -> int:",
    ):
        assert fragment in instrumented_source

    # Store the decorated module so the injection step sees it on disk.
    source_file.write_text(instrumented_source)

    positions = [CodePosition(8, 18), CodePosition(11, 19)]
    success, instrumented_test_code = inject_profiling_into_existing_test(
        test_file, positions, async_func, temp_dir, "pytest", mode=TestingMode.BEHAVIOR
    )

    # Async targets are not instrumented at the test level.
    assert not success
    assert instrumented_test_code is None
|
||||
|
||||
|
||||
def test_async_function_qualified_name_handling():
    """An async method nested two classes deep is located and decorated.

    The target is described by its full parent chain (``OuterClass`` ->
    ``InnerClass``). The result must carry the behavior decorator on the
    nested method and the matching import at module level.
    """
    nested_async_code = '''
import asyncio

class OuterClass:
    class InnerClass:
        async def nested_async_method(self, x: int) -> int:
            """Nested async method."""
            await asyncio.sleep(0.001)
            return x * 2
'''

    func = FunctionToOptimize(
        function_name="nested_async_method",
        file_path=Path("test_nested.py"),
        parents=[{"name": "OuterClass", "type": "ClassDef"}, {"name": "InnerClass", "type": "ClassDef"}],
        is_async=True,
    )

    modified_code, decorator_added = add_async_decorator_to_function(nested_async_code, func, TestingMode.BEHAVIOR)

    expected_output = (
        """import asyncio

from codeflash.code_utils.codeflash_wrap_decorator import \\
    codeflash_behavior_async


class OuterClass:
    class InnerClass:
        @codeflash_behavior_async
        async def nested_async_method(self, x: int) -> int:
            \"\"\"Nested async method.\"\"\"
            await asyncio.sleep(0.001)
            return x * 2
"""
    )

    # The flag was previously unpacked but never checked; the sibling tests
    # assert it, and the expected output only makes sense if it is truthy.
    assert decorator_added
    assert modified_code.strip() == expected_output.strip()
|
||||
|
||||
|
||||
def test_async_decorator_with_existing_decorators():
    """The codeflash decorator is placed ahead of decorators the function already has."""
    source_with_decorator = '''
import asyncio
from functools import wraps

def my_decorator(func):
    @wraps(func)
    async def wrapper(*args, **kwargs):
        return await func(*args, **kwargs)
    return wrapper

@my_decorator
async def async_function(x: int, y: int) -> int:
    """Async function with existing decorator."""
    await asyncio.sleep(0.01)
    return x * y
'''

    target = FunctionToOptimize(
        function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
    )

    rewritten, was_added = add_async_decorator_to_function(source_with_decorator, target, TestingMode.BEHAVIOR)

    assert was_added

    # Both decorators must be present in the output...
    codeflash_marker = "@codeflash_behavior_async"
    user_marker = "@my_decorator"
    assert codeflash_marker in rewritten
    assert user_marker in rewritten

    # ...and the codeflash one must appear earlier in the text.
    assert rewritten.find(codeflash_marker) < rewritten.find(user_marker)
|
||||
|
||||
|
||||
def test_sync_function_not_affected_by_async_logic():
    """A target flagged ``is_async=False`` comes back unchanged and undecorated."""
    plain_source = '''
def sync_function(x: int, y: int) -> int:
    """Regular sync function."""
    return x + y
'''

    target = FunctionToOptimize(
        function_name="sync_function", file_path=Path("test_sync.py"), parents=[], is_async=False
    )

    rewritten, was_added = add_async_decorator_to_function(plain_source, target, TestingMode.BEHAVIOR)

    assert not was_added
    # Exact equality (no strip): the source must be returned as given.
    assert rewritten == plain_source
|
||||
|
||||
def test_inject_profiling_async_multiple_calls_same_test(temp_dir):
    """Awaited calls to the target are numbered per test function, starting from 0.

    The test file has one test with a single call and one with three calls.
    After injection, the ``CODEFLASH_CURRENT_LINE_ID`` assignments should show
    id '0' twice (once per test function) and ids '1' and '2' once each.
    """
    import ast

    from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators

    source_module_code = '''
import asyncio

async def async_sorter(items):
    """Simple async sorter for testing."""
    await asyncio.sleep(0.001)
    return sorted(items)
'''

    test_code_multiple_calls = """
import asyncio
import pytest
from async_sorter import async_sorter

@pytest.mark.asyncio
async def test_single_call():
    result = await async_sorter([42])
    assert result == [42]

@pytest.mark.asyncio
async def test_multiple_calls():
    result1 = await async_sorter([3, 1, 2])
    result2 = await async_sorter([5, 4])
    result3 = await async_sorter([9, 8, 7, 6])
    assert result1 == [1, 2, 3]
    assert result2 == [4, 5]
    assert result3 == [6, 7, 8, 9]
"""

    source_file = temp_dir / "async_sorter.py"
    source_file.write_text(source_module_code)
    test_file = temp_dir / "test_async_sorter.py"
    test_file.write_text(test_code_multiple_calls)

    func = FunctionToOptimize(
        function_name="async_sorter", parents=[], file_path=Path("async_sorter.py"), is_async=True
    )

    # Decorate the source module before touching the test file.
    source_success, instrumented_source = instrument_source_module_with_async_decorators(
        source_file, func, TestingMode.BEHAVIOR
    )

    assert source_success
    assert instrumented_source is not None
    assert "@codeflash_behavior_async" in instrumented_source

    source_file.write_text(instrumented_source)

    def _calls_sorter(call):
        # Matches both a bare name (`async_sorter(...)`) and an attribute access (`mod.async_sorter(...)`).
        callee = call.func
        return getattr(callee, "id", None) == "async_sorter" or getattr(callee, "attr", None) == "async_sorter"

    # Collect the position of every `await async_sorter(...)` in the test source.
    call_positions = [
        CodePosition(node.lineno, node.col_offset)
        for node in ast.walk(ast.parse(test_code_multiple_calls))
        if isinstance(node, ast.Await) and isinstance(node.value, ast.Call) and _calls_sorter(node.value)
    ]

    assert len(call_positions) == 4

    success, instrumented_test_code = inject_profiling_into_existing_test(
        test_file, call_positions, func, temp_dir, "pytest", mode=TestingMode.BEHAVIOR
    )

    assert success
    assert instrumented_test_code is not None

    assert "os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'" in instrumented_test_code

    # Tally how often each line id is assigned to confirm the numbering.
    line_id_0_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'")
    line_id_1_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '1'")
    line_id_2_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '2'")

    assert line_id_0_count == 2, f"Expected 2 occurrences of line_id '0', got {line_id_0_count}"
    assert line_id_1_count == 1, f"Expected 1 occurrence of line_id '1', got {line_id_1_count}"
    assert line_id_2_count == 1, f"Expected 1 occurrence of line_id '2', got {line_id_2_count}"
|
||||
|
||||
|
||||
|
||||
def test_async_behavior_decorator_return_values_and_test_ids():
    """Test that async behavior decorator correctly captures return values, test IDs, and stores data in database.

    The decorated coroutine is run once with the CODEFLASH_* environment
    variables set, then the sqlite file written by the decorator is read back
    and its single row is checked field by field. The environment is restored
    in ``finally`` whether or not an assertion fails.
    """
    import asyncio
    import os
    import sqlite3
    from contextlib import closing
    from pathlib import Path

    import dill as pickle

    from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async, get_run_tmp_file

    # NOTE: the inner name is asserted below as `function_name`; do not rename it.
    @codeflash_behavior_async
    async def test_async_multiply(x: int, y: int) -> int:
        """Simple async function for testing."""
        await asyncio.sleep(0.001)  # Small delay to simulate async work
        return x * y

    test_env = {
        "CODEFLASH_TEST_MODULE": "test_module",
        "CODEFLASH_TEST_CLASS": None,  # None means "must be unset" for this run
        "CODEFLASH_TEST_FUNCTION": "test_async_multiply_function",
        "CODEFLASH_CURRENT_LINE_ID": "0",
        "CODEFLASH_LOOP_INDEX": "1",
        "CODEFLASH_TEST_ITERATION": "2",
    }

    # Snapshot the current values so they can be restored afterwards.
    original_env = {k: os.environ.get(k) for k in test_env}
    for k, v in test_env.items():
        if v is not None:
            os.environ[k] = v
        elif k in os.environ:
            del os.environ[k]

    try:
        result = asyncio.run(test_async_multiply(6, 7))

        assert result == 42, f"Expected return value 42, got {result}"

        # presumably the "_2" suffix mirrors CODEFLASH_TEST_ITERATION above — confirm against the decorator
        db_path = get_run_tmp_file(Path("test_return_values_2.sqlite"))

        # Verify database exists and has data
        assert db_path.exists(), f"Database file not created at {db_path}"

        # `closing` guarantees the connection is released even when an
        # assertion below fails; a bare sqlite3 connection used as a context
        # manager only manages the transaction and does not close.
        with closing(sqlite3.connect(db_path)) as con:
            cur = con.cursor()
            cur.execute("SELECT * FROM test_results")
            rows = cur.fetchall()

        assert len(rows) == 1, f"Expected 1 database row, got {len(rows)}"

        (
            test_module,
            test_class,
            test_function,
            function_name,
            loop_index,
            iteration_id,
            _runtime,  # timing is not deterministic, so it is not asserted
            return_value_blob,
            verification_type,
        ) = rows[0]

        assert test_module == "test_module", f"Expected test_module 'test_module', got '{test_module}'"
        assert test_class is None, f"Expected test_class None, got '{test_class}'"
        assert test_function == "test_async_multiply_function", (
            f"Expected test_function 'test_async_multiply_function', got '{test_function}'"
        )
        assert function_name == "test_async_multiply", (
            f"Expected function_name 'test_async_multiply', got '{function_name}'"
        )
        assert loop_index == 1, f"Expected loop_index 1, got {loop_index}"
        assert iteration_id == "0_0", f"Expected iteration_id '0_0', got '{iteration_id}'"
        assert verification_type == "function_call", (
            f"Expected verification_type 'function_call', got '{verification_type}'"
        )

        # The blob was written by the decorator during this test run, so
        # unpickling it here does not touch untrusted input.
        args, kwargs, actual_return_value = pickle.loads(return_value_blob)

        assert args == (6, 7), f"Expected args (6, 7), got {args}"
        assert kwargs == {}, f"Expected empty kwargs, got {kwargs}"

        assert actual_return_value == 42, f"Expected stored return value 42, got {actual_return_value}"

    finally:
        # Restore every variable to its pre-test state (or remove it if it was unset).
        for k, v in original_env.items():
            if v is not None:
                os.environ[k] = v
            elif k in os.environ:
                del os.environ[k]
|
||||
|
||||
|
||||
def test_async_decorator_comprehensive_return_values_and_test_ids():
    """Run the behavior decorator over several calls and verify every recorded row.

    Three invocations (positional only, and with a keyword override) are made
    under a fixed set of CODEFLASH_* environment variables. The sqlite file
    written by the decorator must then contain one row per call, in call order,
    with matching identifiers and pickled ``(args, kwargs, return_value)``.
    The environment is restored in ``finally``.
    """
    import asyncio
    import os
    import sqlite3
    from contextlib import closing
    from pathlib import Path

    import dill as pickle

    from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async, get_run_tmp_file

    @codeflash_behavior_async
    async def async_multiply_add(x: int, y: int, z: int = 1) -> int:
        """Async function that multiplies x*y then adds z."""
        await asyncio.sleep(0.001)
        result = (x * y) + z
        return result

    test_env = {
        "CODEFLASH_TEST_MODULE": "test_comprehensive_module",
        "CODEFLASH_TEST_CLASS": "AsyncTestClass",
        "CODEFLASH_TEST_FUNCTION": "test_comprehensive_async_function",
        "CODEFLASH_CURRENT_LINE_ID": "3",
        "CODEFLASH_LOOP_INDEX": "2",
        "CODEFLASH_TEST_ITERATION": "3",
    }

    # Snapshot the current values so they can be restored afterwards.
    original_env = {k: os.environ.get(k) for k in test_env}
    for k, v in test_env.items():
        if v is not None:
            os.environ[k] = v
        elif k in os.environ:
            del os.environ[k]

    try:
        test_cases = [
            {"args": (5, 3), "kwargs": {}, "expected": 16},  # (5 * 3) + 1 = 16
            {"args": (2, 4), "kwargs": {"z": 10}, "expected": 18},  # (2 * 4) + 10 = 18
            {"args": (7, 6), "kwargs": {}, "expected": 43},  # (7 * 6) + 1 = 43
        ]

        for test_case in test_cases:
            result = asyncio.run(async_multiply_add(*test_case["args"], **test_case["kwargs"]))

            # Verify each return value is exactly correct
            assert result == test_case["expected"], (
                f"Expected {test_case['expected']}, got {result} for args {test_case['args']}, kwargs {test_case['kwargs']}"
            )

        # presumably the "_3" suffix mirrors CODEFLASH_TEST_ITERATION above — confirm against the decorator
        db_path = get_run_tmp_file(Path("test_return_values_3.sqlite"))
        assert db_path.exists(), f"Database not created at {db_path}"

        # `closing` guarantees the connection is released even when an
        # assertion below fails; a bare sqlite3 connection used as a context
        # manager only manages the transaction and does not close.
        with closing(sqlite3.connect(db_path)) as con:
            cur = con.cursor()
            cur.execute(
                "SELECT test_module_path, test_class_name, test_function_name, function_getting_tested, loop_index, iteration_id, runtime, return_value, verification_type FROM test_results ORDER BY rowid"
            )
            rows = cur.fetchall()

        assert len(rows) == 3, f"Expected 3 database rows, got {len(rows)}"

        for i, (
            test_module,
            test_class,
            test_function,
            function_name,
            loop_index,
            iteration_id,
            _runtime,  # timing is not deterministic, so it is not asserted
            return_value_blob,
            verification_type,
        ) in enumerate(rows):
            assert test_module == "test_comprehensive_module", (
                f"Row {i}: Expected test_module 'test_comprehensive_module', got '{test_module}'"
            )
            assert test_class == "AsyncTestClass", f"Row {i}: Expected test_class 'AsyncTestClass', got '{test_class}'"
            assert test_function == "test_comprehensive_async_function", (
                f"Row {i}: Expected test_function 'test_comprehensive_async_function', got '{test_function}'"
            )
            assert function_name == "async_multiply_add", (
                f"Row {i}: Expected function_name 'async_multiply_add', got '{function_name}'"
            )
            assert loop_index == 2, f"Row {i}: Expected loop_index 2, got {loop_index}"
            assert verification_type == "function_call", (
                f"Row {i}: Expected verification_type 'function_call', got '{verification_type}'"
            )

            # The id is expected to be "<line id>_<call index>" for successive calls.
            expected_iteration_id = f"3_{i}"
            assert iteration_id == expected_iteration_id, (
                f"Row {i}: Expected iteration_id '{expected_iteration_id}', got '{iteration_id}'"
            )

            # The blob was written by the decorator during this test run, so
            # unpickling it here does not touch untrusted input.
            args, kwargs, actual_return_value = pickle.loads(return_value_blob)
            expected_args = test_cases[i]["args"]
            expected_kwargs = test_cases[i]["kwargs"]
            expected_return = test_cases[i]["expected"]

            assert args == expected_args, f"Row {i}: Expected args {expected_args}, got {args}"
            assert kwargs == expected_kwargs, f"Row {i}: Expected kwargs {expected_kwargs}, got {kwargs}"
            assert actual_return_value == expected_return, (
                f"Row {i}: Expected return value {expected_return}, got {actual_return_value}"
            )

    finally:
        # Restore every variable to its pre-test state (or remove it if it was unset).
        for k, v in original_env.items():
            if v is not None:
                os.environ[k] = v
            elif k in os.environ:
                del os.environ[k]
|
||||
Loading…
Reference in a new issue