From fa705e1b5e02642bb8af76fbaab98d76b601e782 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Fri, 26 Sep 2025 13:53:15 -0700 Subject: [PATCH] first pass --- code_to_optimize/async_bubble_sort.py | 43 + .../code_directories/async_e2e/main.py | 16 + .../code_directories/async_e2e/pyproject.toml | 6 + .../async_e2e/tests/__init__.py | 0 codeflash/discovery/functions_to_optimize.py | 38 +- tests/scripts/end_to_end_test_async.py | 27 + tests/test_async_function_discovery.py | 286 +++++ tests/test_async_run_and_parse_tests.py | 1039 +++++++++++++++++ tests/test_async_wrapper_sqlite_validation.py | 285 +++++ tests/test_instrument_async_tests.py | 793 +++++++++++++ 10 files changed, 2530 insertions(+), 3 deletions(-) create mode 100644 code_to_optimize/async_bubble_sort.py create mode 100644 code_to_optimize/code_directories/async_e2e/main.py create mode 100644 code_to_optimize/code_directories/async_e2e/pyproject.toml create mode 100644 code_to_optimize/code_directories/async_e2e/tests/__init__.py create mode 100644 tests/scripts/end_to_end_test_async.py create mode 100644 tests/test_async_function_discovery.py create mode 100644 tests/test_async_run_and_parse_tests.py create mode 100644 tests/test_async_wrapper_sqlite_validation.py create mode 100644 tests/test_instrument_async_tests.py diff --git a/code_to_optimize/async_bubble_sort.py b/code_to_optimize/async_bubble_sort.py new file mode 100644 index 000000000..b87455299 --- /dev/null +++ b/code_to_optimize/async_bubble_sort.py @@ -0,0 +1,43 @@ +import asyncio +from typing import List, Union + + +async def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]: + """ + Async bubble sort implementation for testing. + """ + print("codeflash stdout: Async sorting list") + + await asyncio.sleep(0.01) + + n = len(lst) + for i in range(n): + for j in range(0, n - i - 1): + if lst[j] > lst[j + 1]: + lst[j], lst[j + 1] = lst[j + 1], lst[j] + + result = lst.copy() + print(f"result: {result}") + return result + + +class AsyncBubbleSorter: + """Class with async sorting method for testing.""" + + async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]: + """ + Async bubble sort implementation within a class. + """ + print("codeflash stdout: AsyncBubbleSorter.sorter() called") + + # Add some async delay + await asyncio.sleep(0.005) + + n = len(lst) + for i in range(n): + for j in range(0, n - i - 1): + if lst[j] > lst[j + 1]: + lst[j], lst[j + 1] = lst[j + 1], lst[j] + + result = lst.copy() + return result diff --git a/code_to_optimize/code_directories/async_e2e/main.py b/code_to_optimize/code_directories/async_e2e/main.py new file mode 100644 index 000000000..317068a1c --- /dev/null +++ b/code_to_optimize/code_directories/async_e2e/main.py @@ -0,0 +1,16 @@ +import time +import asyncio + + +async def retry_with_backoff(func, max_retries=3): + if max_retries < 1: + raise ValueError("max_retries must be at least 1") + last_exception = None + for attempt in range(max_retries): + try: + return await func() + except Exception as e: + last_exception = e + if attempt < max_retries - 1: + time.sleep(0.0001 * attempt) + raise last_exception diff --git a/code_to_optimize/code_directories/async_e2e/pyproject.toml b/code_to_optimize/code_directories/async_e2e/pyproject.toml new file mode 100644 index 000000000..d77155a9d --- /dev/null +++ b/code_to_optimize/code_directories/async_e2e/pyproject.toml @@ -0,0 +1,6 @@ +[tool.codeflash] +disable-telemetry = true +formatter-cmds = ["ruff check --exit-zero --fix $file", "ruff format $file"] +module-root = "." +test-framework = "pytest" +tests-root = "tests" \ No newline at end of file diff --git a/code_to_optimize/code_directories/async_e2e/tests/__init__.py b/code_to_optimize/code_directories/async_e2e/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index bea01027b..dacbb1df6 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -86,6 +86,7 @@ class FunctionVisitor(cst.CSTVisitor): parents=list(reversed(ast_parents)), starting_line=pos.start.line, ending_line=pos.end.line, + is_async=bool(node.asynchronous), ) ) @@ -103,6 +104,15 @@ class FunctionWithReturnStatement(ast.NodeVisitor): FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:]) ) + def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None: + # Check if the async function has a return statement and add it to the list + if function_has_return_statement(node) and not function_is_a_property(node): + self.functions.append( + FunctionToOptimize( + function_name=node.name, file_path=self.file_path, parents=self.ast_path[:], is_async=True + ) + ) + def generic_visit(self, node: ast.AST) -> None: if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)): self.ast_path.append(FunctionParent(node.name, node.__class__.__name__)) @@ -122,6 +132,7 @@ class FunctionToOptimize: parents: A list of parent scopes, which could be classes or functions. starting_line: The starting line number of the function in the file. ending_line: The ending line number of the function in the file. + is_async: Whether this function is defined as async. The qualified_name property provides the full name of the function, including any parent class or function names. The qualified_name_with_modules_from_root @@ -134,6 +145,7 @@ class FunctionToOptimize: parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef] starting_line: Optional[int] = None ending_line: Optional[int] = None + is_async: bool = False @property def top_level_parent_name(self) -> str: @@ -147,7 +159,11 @@ class FunctionToOptimize: @property def qualified_name(self) -> str: - return self.function_name if self.parents == [] else f"{self.parents[0].name}.{self.function_name}" + if not self.parents: + return self.function_name + # Join all parent names with dots to handle nested classes properly + parent_path = ".".join(parent.name for parent in self.parents) + return f"{parent_path}.{self.function_name}" def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str: return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}" @@ -411,11 +427,27 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor): ) ) + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + if self.class_name is None and node.name == self.function_name: + self.is_top_level = True + self.function_has_args = any( + ( + bool(node.args.args), + bool(node.args.kwonlyargs), + bool(node.args.kwarg), + bool(node.args.posonlyargs), + bool(node.args.vararg), + ) + ) + def visit_ClassDef(self, node: ast.ClassDef) -> None: # iterate over the class methods if node.name == self.class_name: for body_node in node.body: - if isinstance(body_node, ast.FunctionDef) and body_node.name == self.function_name: + if ( + isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef)) + and body_node.name == self.function_name + ): self.is_top_level = True if any( isinstance(decorator, ast.Name) and decorator.id == "classmethod" @@ -433,7 +465,7 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor): # This way, if we don't have the class name, we can still find the static method for body_node in node.body: if ( - isinstance(body_node, ast.FunctionDef) + isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef)) and body_node.name == self.function_name and body_node.lineno in {self.line_no, self.line_no + 1} and any( diff --git a/tests/scripts/end_to_end_test_async.py b/tests/scripts/end_to_end_test_async.py new file mode 100644 index 000000000..5aed8f8ca --- /dev/null +++ b/tests/scripts/end_to_end_test_async.py @@ -0,0 +1,27 @@ +import os +import pathlib + +from end_to_end_test_utilities import CoverageExpectation, TestConfig, run_codeflash_command, run_with_retries + + +def run_test(expected_improvement_pct: int) -> bool: + config = TestConfig( + file_path="main.py", + expected_unit_tests=0, + min_improvement_x=0.1, + coverage_expectations=[ + CoverageExpectation( + function_name="retry_with_backoff", + expected_coverage=100.0, + expected_lines=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], + ) + ], + ) + cwd = ( + pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "code_directories" / "async_e2e" + ).resolve() + return run_codeflash_command(cwd, config, expected_improvement_pct) + + +if __name__ == "__main__": + exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10)))) \ No newline at end of file diff --git a/tests/test_async_function_discovery.py b/tests/test_async_function_discovery.py new file mode 100644 index 000000000..259d9ee24 --- /dev/null +++ b/tests/test_async_function_discovery.py @@ -0,0 +1,286 @@ +import tempfile +from pathlib import Path +import pytest + +from codeflash.discovery.functions_to_optimize import ( + find_all_functions_in_file, + get_functions_to_optimize, + inspect_top_level_functions_or_methods, +) +from codeflash.verification.verification_utils import TestConfig + + +@pytest.fixture +def temp_dir(): + with tempfile.TemporaryDirectory() as temp: + yield Path(temp) + + +def test_async_function_detection(temp_dir): + async_function = """ +async def async_function_with_return(): + await some_async_operation() + return 42 + +async def async_function_without_return(): + await some_async_operation() + print("No return") + +def regular_function(): + return 10 +""" + + file_path = temp_dir / "test_file.py" + file_path.write_text(async_function) + functions_found = find_all_functions_in_file(file_path) + + function_names = [fn.function_name for fn in functions_found[file_path]] + + assert "async_function_with_return" in function_names + assert "regular_function" in function_names + assert "async_function_without_return" not in function_names + + +def test_async_method_in_class(temp_dir): + code_with_async_method = """ +class AsyncClass: + async def async_method(self): + await self.do_something() + return "result" + + async def async_method_no_return(self): + await self.do_something() + pass + + def sync_method(self): + return "sync result" +""" + + file_path = temp_dir / "test_file.py" + file_path.write_text(code_with_async_method) + functions_found = find_all_functions_in_file(file_path) + + found_functions = functions_found[file_path] + function_names = [fn.function_name for fn in found_functions] + qualified_names = [fn.qualified_name for fn in found_functions] + + assert "async_method" in function_names + assert "AsyncClass.async_method" in qualified_names + + assert "sync_method" in function_names + assert "AsyncClass.sync_method" in qualified_names + + assert "async_method_no_return" not in function_names + + +def test_nested_async_functions(temp_dir): + nested_async = """ +async def outer_async(): + async def inner_async(): + return "inner" + + result = await inner_async() + return result + +def outer_sync(): + async def inner_async(): + return "inner from sync" + + return inner_async +""" + + file_path = temp_dir / "test_file.py" + file_path.write_text(nested_async) + functions_found = find_all_functions_in_file(file_path) + + function_names = [fn.function_name for fn in functions_found[file_path]] + + assert "outer_async" in function_names + assert "outer_sync" in function_names + assert "inner_async" not in function_names + + +def test_async_staticmethod_and_classmethod(temp_dir): + async_decorators = """ +class MyClass: + @staticmethod + async def async_static_method(): + await some_operation() + return "static result" + + @classmethod + async def async_class_method(cls): + await cls.some_operation() + return "class result" + + @property + async def async_property(self): + return await self.get_value() +""" + + file_path = temp_dir / "test_file.py" + file_path.write_text(async_decorators) + functions_found = find_all_functions_in_file(file_path) + + function_names = [fn.function_name for fn in functions_found[file_path]] + + assert "async_static_method" in function_names + assert "async_class_method" in function_names + + assert "async_property" not in function_names + + +def test_async_generator_functions(temp_dir): + async_generators = """ +async def async_generator_with_return(): + for i in range(10): + yield i + return "done" + +async def async_generator_no_return(): + for i in range(10): + yield i + +async def regular_async_with_return(): + result = await compute() + return result +""" + + file_path = temp_dir / "test_file.py" + file_path.write_text(async_generators) + functions_found = find_all_functions_in_file(file_path) + + function_names = [fn.function_name for fn in functions_found[file_path]] + + assert "async_generator_with_return" in function_names + assert "regular_async_with_return" in function_names + assert "async_generator_no_return" not in function_names + + +def test_inspect_async_top_level_functions(temp_dir): + code = """ +async def top_level_async(): + return 42 + +class AsyncContainer: + async def async_method(self): + async def nested_async(): + return 1 + return await nested_async() + + @staticmethod + async def async_static(): + return "static" + + @classmethod + async def async_classmethod(cls): + return "classmethod" +""" + + file_path = temp_dir / "test_file.py" + file_path.write_text(code) + + result = inspect_top_level_functions_or_methods(file_path, "top_level_async") + assert result.is_top_level + + result = inspect_top_level_functions_or_methods(file_path, "async_method", class_name="AsyncContainer") + assert result.is_top_level + + result = inspect_top_level_functions_or_methods(file_path, "nested_async", class_name="AsyncContainer") + assert not result.is_top_level + + result = inspect_top_level_functions_or_methods(file_path, "async_static", class_name="AsyncContainer") + assert result.is_top_level + assert result.is_staticmethod + + result = inspect_top_level_functions_or_methods(file_path, "async_classmethod", class_name="AsyncContainer") + assert result.is_top_level + assert result.is_classmethod + + +def test_get_functions_to_optimize_with_async(temp_dir): + mixed_code = """ +async def async_func_one(): + return await operation_one() + +def sync_func_one(): + return operation_one() + +async def async_func_two(): + print("no return") + +class MixedClass: + async def async_method(self): + return await self.operation() + + def sync_method(self): + return self.operation() +""" + + file_path = temp_dir / "test_file.py" + file_path.write_text(mixed_code) + + test_config = TestConfig( + tests_root="tests", + project_root_path=".", + test_framework="pytest", + tests_project_rootdir=Path() + ) + + functions, functions_count, _ = get_functions_to_optimize( + optimize_all=None, + replay_test=None, + file=file_path, + only_get_this_function=None, + test_cfg=test_config, + ignore_paths=[], + project_root=file_path.parent, + module_root=file_path.parent, + ) + + assert functions_count == 4 + + function_names = [fn.function_name for fn in functions[file_path]] + assert "async_func_one" in function_names + assert "sync_func_one" in function_names + assert "async_method" in function_names + assert "sync_method" in function_names + + assert "async_func_two" not in function_names + + +def test_async_function_parents(temp_dir): + complex_structure = """ +class OuterClass: + async def outer_method(self): + return 1 + + class InnerClass: + async def inner_method(self): + return 2 + +async def module_level_async(): + class LocalClass: + async def local_method(self): + return 3 + return LocalClass() +""" + + file_path = temp_dir / "test_file.py" + file_path.write_text(complex_structure) + functions_found = find_all_functions_in_file(file_path) + + found_functions = functions_found[file_path] + + for fn in found_functions: + if fn.function_name == "outer_method": + assert len(fn.parents) == 1 + assert fn.parents[0].name == "OuterClass" + assert fn.qualified_name == "OuterClass.outer_method" + elif fn.function_name == "inner_method": + assert len(fn.parents) == 2 + assert fn.parents[0].name == "OuterClass" + assert fn.parents[1].name == "InnerClass" + elif fn.function_name == "module_level_async": + assert len(fn.parents) == 0 + assert fn.qualified_name == "module_level_async" \ No newline at end of file diff --git a/tests/test_async_run_and_parse_tests.py b/tests/test_async_run_and_parse_tests.py new file mode 100644 index 000000000..b83be5c5a --- /dev/null +++ b/tests/test_async_run_and_parse_tests.py @@ -0,0 +1,1039 @@ +from __future__ import annotations + +import os +from argparse import Namespace +from pathlib import Path + +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.models.models import CodePosition, FunctionParent, TestFile, TestFiles, TestingMode, TestType +from codeflash.optimization.optimizer import Optimizer +from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture +from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators, inject_profiling_into_existing_test + +def test_async_bubble_sort_behavior_results() -> None: + test_code = """import asyncio +import pytest +from code_to_optimize.async_bubble_sort import async_sorter + + +@pytest.mark.asyncio +async def test_async_sort(): + input = [5, 4, 3, 2, 1, 0] + output = await async_sorter(input) + assert output == [0, 1, 2, 3, 4, 5] + + input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0] + output = await async_sorter(input) + assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]""" + + test_path = ( + Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_bubble_sort_temp.py" + ).resolve() + test_path_perf = ( + Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_bubble_sort_perf_temp.py" + ).resolve() + fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/async_bubble_sort.py").resolve() + original_code = fto_path.read_text("utf-8") + + try: + # Write test file + with test_path.open("w") as f: + f.write(test_code) + + tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve() + project_root_path = (Path(__file__).parent / "..").resolve() + + # Create async function to optimize + func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True) + + # For async functions, instrument the source module directly with decorators + source_success, instrumented_source = instrument_source_module_with_async_decorators( + fto_path, func, TestingMode.BEHAVIOR + ) + + assert source_success + assert instrumented_source is not None + assert '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_behavior_async\n\n\n@codeflash_behavior_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n''' in instrumented_source + + # Write the instrumented source back + fto_path.write_text(instrumented_source, "utf-8") + + # Add codeflash capture + instrument_codeflash_capture(func, {}, tests_root) + + # Create optimizer + opt = Optimizer( + Namespace( + project_root=project_root_path, + disable_telemetry=True, + tests_root=tests_root, + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=project_root_path, + ) + ) + + test_env = os.environ.copy() + test_env["CODEFLASH_TEST_ITERATION"] = "0" + test_env["CODEFLASH_LOOP_INDEX"] = "1" + test_env["CODEFLASH_TEST_MODULE"] = "code_to_optimize.tests.pytest.test_async_bubble_sort_temp" + test_env["CODEFLASH_TEST_CLASS"] = "" + test_env["CODEFLASH_TEST_FUNCTION"] = "test_async_sort" + test_env["CODEFLASH_CURRENT_LINE_ID"] = "0" + test_type = TestType.EXISTING_UNIT_TEST + + # Create function optimizer and set up test files + func_optimizer = opt.create_function_optimizer(func) + func_optimizer.test_files = TestFiles( + test_files=[ + TestFile( + instrumented_behavior_file_path=test_path, + test_type=test_type, + original_file_path=test_path, + benchmarking_file_path=test_path_perf, + ) + ] + ) + + test_results, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + + assert test_results is not None + assert test_results.test_results is not None + + results_list = test_results.test_results + assert results_list[0].id.function_getting_tested == "async_sorter" + assert results_list[0].id.test_class_name is None + assert results_list[0].id.test_function_name == "test_async_sort" + assert results_list[0].did_pass + assert results_list[0].runtime is None or results_list[0].runtime >= 0 + + expected_stdout = "codeflash stdout: Async sorting list\nresult: [0, 1, 2, 3, 4, 5]\n" + assert expected_stdout == results_list[0].stdout + + + assert results_list[1].id.function_getting_tested == "async_sorter" + assert results_list[1].id.test_function_name == "test_async_sort" + assert results_list[1].did_pass + + expected_stdout2 = "codeflash stdout: Async sorting list\nresult: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]\n" + assert expected_stdout2 == results_list[1].stdout + + finally: + # Restore original code + fto_path.write_text(original_code, "utf-8") + # Clean up test files + if test_path.exists(): + test_path.unlink() + if test_path_perf.exists(): + test_path_perf.unlink() + + +def test_async_class_method_behavior_results() -> None: + """Test async class method behavior with run_and_parse_tests.""" + test_code = """import asyncio +import pytest +from code_to_optimize.async_bubble_sort import AsyncBubbleSorter + + +@pytest.mark.asyncio +async def test_async_class_sort(): + sorter = AsyncBubbleSorter() + input = [3, 1, 4, 1, 5] + output = await sorter.sorter(input) + assert output == [1, 1, 3, 4, 5]""" + + test_path = ( + Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_class_bubble_sort_temp.py" + ).resolve() + test_path_perf = ( + Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_class_bubble_sort_perf_temp.py" + ).resolve() + fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/async_bubble_sort.py").resolve() + original_code = fto_path.read_text("utf-8") + + try: + with test_path.open("w") as f: + f.write(test_code) + + tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve() + project_root_path = (Path(__file__).parent / "..").resolve() + + func = FunctionToOptimize( + function_name="sorter", + parents=[FunctionParent("AsyncBubbleSorter", "ClassDef")], + file_path=Path(fto_path), + is_async=True, + ) + + source_success, instrumented_source = instrument_source_module_with_async_decorators( + fto_path, func, TestingMode.BEHAVIOR + ) + + assert source_success + assert instrumented_source is not None + assert "@codeflash_behavior_async" in instrumented_source + + fto_path.write_text(instrumented_source, "utf-8") + + instrument_codeflash_capture(func, {}, tests_root) + + opt = Optimizer( + Namespace( + project_root=project_root_path, + disable_telemetry=True, + tests_root=tests_root, + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=project_root_path, + ) + ) + + test_env = os.environ.copy() + test_env["CODEFLASH_TEST_ITERATION"] = "0" + test_env["CODEFLASH_LOOP_INDEX"] = "1" + test_env["CODEFLASH_TEST_MODULE"] = "code_to_optimize.tests.pytest.test_async_class_bubble_sort_temp" + test_env["CODEFLASH_TEST_CLASS"] = "" + test_env["CODEFLASH_TEST_FUNCTION"] = "test_async_class_sort" + test_env["CODEFLASH_CURRENT_LINE_ID"] = "0" + test_type = TestType.EXISTING_UNIT_TEST + + func_optimizer = opt.create_function_optimizer(func) + func_optimizer.test_files = TestFiles( + test_files=[ + TestFile( + instrumented_behavior_file_path=test_path, + test_type=test_type, + original_file_path=test_path, + benchmarking_file_path=test_path_perf, + ) + ] + ) + + test_results, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + + + assert test_results is not None + assert test_results.test_results is not None + + results_list = test_results.test_results + assert len(results_list) == 2, f"Expected 2 results but got {len(results_list)}: {[r.id.function_getting_tested for r in results_list]}" + + init_result = results_list[0] + sorter_result = results_list[1] + + + assert sorter_result.id.function_getting_tested == "sorter" + assert sorter_result.id.test_class_name is None + assert sorter_result.id.test_function_name == "test_async_class_sort" + assert sorter_result.did_pass + assert sorter_result.runtime is None or sorter_result.runtime >= 0 + + expected_stdout = "codeflash stdout: AsyncBubbleSorter.sorter() called\n" + assert expected_stdout == sorter_result.stdout + + assert ".__init__" in init_result.id.function_getting_tested + assert init_result.did_pass + + finally: + fto_path.write_text(original_code, "utf-8") + if test_path.exists(): + test_path.unlink() + if test_path_perf.exists(): + test_path_perf.unlink() + + +def test_async_function_performance_mode() -> None: + test_code = """import asyncio +import pytest +from code_to_optimize.async_bubble_sort import async_sorter + + +@pytest.mark.asyncio +async def test_async_perf(): + input = [8, 7, 6, 5, 4, 3, 2, 1] + output = await async_sorter(input) + assert output == [1, 2, 3, 4, 5, 6, 7, 8]""" + + test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_perf_temp.py").resolve() + fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/async_bubble_sort.py").resolve() + original_code = fto_path.read_text("utf-8") + + try: + with test_path.open("w") as f: + f.write(test_code) + + tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve() + project_root_path = (Path(__file__).parent / "..").resolve() + + # Create async function to optimize + func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True) + + # Instrument the source module with async performance decorators + source_success, instrumented_source = instrument_source_module_with_async_decorators( + fto_path, func, TestingMode.PERFORMANCE + ) + + assert source_success + assert instrumented_source is not None + assert '''import asyncio\nfrom typing import List, Union\n\nfrom codeflash.code_utils.codeflash_wrap_decorator import \\\n codeflash_performance_async\n\n\n@codeflash_performance_async\nasync def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation for testing.\n """\n print("codeflash stdout: Async sorting list")\n \n await asyncio.sleep(0.01)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n print(f"result: {result}")\n return result\n\n\nclass AsyncBubbleSorter:\n """Class with async sorting method for testing."""\n \n async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:\n """\n Async bubble sort implementation within a class.\n """\n print("codeflash stdout: AsyncBubbleSorter.sorter() called")\n \n # Add some async delay\n await asyncio.sleep(0.005)\n \n n = len(lst)\n for i in range(n):\n for j in range(0, n - i - 1):\n if lst[j] > lst[j + 1]:\n lst[j], lst[j + 1] = lst[j + 1], lst[j]\n \n result = lst.copy()\n return result\n''' == instrumented_source + + fto_path.write_text(instrumented_source, "utf-8") + + instrument_codeflash_capture(func, {}, tests_root) + + opt = Optimizer( + Namespace( + project_root=project_root_path, + disable_telemetry=True, + tests_root=tests_root, + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=project_root_path, + ) + ) + + test_env = os.environ.copy() + test_env["CODEFLASH_TEST_ITERATION"] = "0" + test_env["CODEFLASH_LOOP_INDEX"] = "1" + test_env["CODEFLASH_TEST_MODULE"] = "code_to_optimize.tests.pytest.test_async_perf_temp" + test_env["CODEFLASH_TEST_CLASS"] = "" + test_env["CODEFLASH_TEST_FUNCTION"] = "test_async_perf" + test_env["CODEFLASH_CURRENT_LINE_ID"] = "0" + test_type = TestType.EXISTING_UNIT_TEST + + func_optimizer = opt.create_function_optimizer(func) + func_optimizer.test_files = TestFiles( + test_files=[ + TestFile( + instrumented_behavior_file_path=test_path, + test_type=test_type, + original_file_path=test_path, + benchmarking_file_path=test_path, # Same file for perf + ) + ] + ) + + test_results, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.PERFORMANCE, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + + assert test_results is not None + assert test_results.test_results is not None + + finally: + # Restore original code + fto_path.write_text(original_code, "utf-8") + # Clean up test files + if test_path.exists(): + test_path.unlink() + + + +def test_async_function_error_handling() -> None: + test_code = """import asyncio +import pytest +from code_to_optimize.async_bubble_sort import async_error_function + + +@pytest.mark.asyncio +async def test_async_error(): + with pytest.raises(ValueError, match="Test error"): + await async_error_function([1, 2, 3])""" + + test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_error_temp.py").resolve() + test_path_perf = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_error_perf_temp.py").resolve() + fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/async_bubble_sort.py").resolve() + original_code = fto_path.read_text("utf-8") + + try: + error_func_code = """ + +async def async_error_function(lst): + \"\"\"Async function that raises an error for testing.\"\"\" + await asyncio.sleep(0.001) # Small delay + raise ValueError("Test error") +""" + + modified_code = original_code + error_func_code + fto_path.write_text(modified_code, "utf-8") + + with test_path.open("w") as f: + f.write(test_code) + + tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve() + project_root_path = (Path(__file__).parent / "..").resolve() + + func = FunctionToOptimize(function_name="async_error_function", parents=[], file_path=Path(fto_path), is_async=True) + + source_success, instrumented_source = instrument_source_module_with_async_decorators( + fto_path, func, TestingMode.BEHAVIOR + ) + + assert source_success + assert instrumented_source is not None + + expected_instrumented_source = """import asyncio +from typing import List, Union + +from codeflash.code_utils.codeflash_wrap_decorator import \\ + codeflash_behavior_async + + +async def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]: + \"\"\" + Async bubble sort implementation for testing. + \"\"\" + print("codeflash stdout: Async sorting list") + + await asyncio.sleep(0.01) + + n = len(lst) + for i in range(n): + for j in range(0, n - i - 1): + if lst[j] > lst[j + 1]: + lst[j], lst[j + 1] = lst[j + 1], lst[j] + + result = lst.copy() + print(f"result: {result}") + return result + + +class AsyncBubbleSorter: + \"\"\"Class with async sorting method for testing.\"\"\" + + async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]: + \"\"\" + Async bubble sort implementation within a class. + \"\"\" + print("codeflash stdout: AsyncBubbleSorter.sorter() called") + + # Add some async delay + await asyncio.sleep(0.005) + + n = len(lst) + for i in range(n): + for j in range(0, n - i - 1): + if lst[j] > lst[j + 1]: + lst[j], lst[j + 1] = lst[j + 1], lst[j] + + result = lst.copy() + return result + + +@codeflash_behavior_async +async def async_error_function(lst): + \"\"\"Async function that raises an error for testing.\"\"\" + await asyncio.sleep(0.001) # Small delay + raise ValueError("Test error") +""" + assert expected_instrumented_source == instrumented_source + + fto_path.write_text(instrumented_source, "utf-8") + instrument_codeflash_capture(func, {}, tests_root) + + opt = Optimizer( + Namespace( + project_root=project_root_path, + disable_telemetry=True, + tests_root=tests_root, + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=project_root_path, + ) + ) + + test_env = os.environ.copy() + test_env["CODEFLASH_TEST_ITERATION"] = "0" + test_env["CODEFLASH_LOOP_INDEX"] = "1" + test_env["CODEFLASH_TEST_MODULE"] = "code_to_optimize.tests.pytest.test_async_error_temp" + test_env["CODEFLASH_TEST_CLASS"] = "" + test_env["CODEFLASH_TEST_FUNCTION"] = "test_async_error" + test_env["CODEFLASH_CURRENT_LINE_ID"] = "0" + test_type = TestType.EXISTING_UNIT_TEST + + func_optimizer = opt.create_function_optimizer(func) + func_optimizer.test_files = TestFiles( + test_files=[ + TestFile( + instrumented_behavior_file_path=test_path, + test_type=test_type, + original_file_path=test_path, + benchmarking_file_path=test_path_perf, + ) + ] + ) + + test_results, _ = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + + assert test_results is not None + assert test_results.test_results is not None + assert len(test_results.test_results) >= 1 + + result = test_results.test_results[0] + assert result.id.function_getting_tested == "async_error_function" + assert result.did_pass + assert result.runtime is None or result.runtime >= 0 + + finally: + fto_path.write_text(original_code, "utf-8") + if test_path.exists(): + test_path.unlink() + if test_path_perf.exists(): + test_path_perf.unlink() + + +def test_async_multiple_iterations() -> None: + test_code = """import asyncio +import pytest +from code_to_optimize.async_bubble_sort import async_sorter + + +@pytest.mark.asyncio +async def test_async_multi(): + input1 = [5, 4, 3] + output1 = await async_sorter(input1) + assert output1 == [3, 4, 5] + + input2 = [9, 7] + output2 = await async_sorter(input2) + assert output2 == [7, 9]""" + + test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_multi_temp.py").resolve() + test_path_perf = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_multi_perf_temp.py").resolve() + fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/async_bubble_sort.py").resolve() + original_code = fto_path.read_text("utf-8") + + try: + with test_path.open("w") as f: + f.write(test_code) + + tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve() + project_root_path = (Path(__file__).parent / "..").resolve() + + func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True) + + source_success, instrumented_source = instrument_source_module_with_async_decorators( + fto_path, func, TestingMode.BEHAVIOR + ) + + assert source_success + fto_path.write_text(instrumented_source, "utf-8") + instrument_codeflash_capture(func, {}, tests_root) + + opt = Optimizer( + Namespace( + project_root=project_root_path, + disable_telemetry=True, + tests_root=tests_root, + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=project_root_path, + ) + ) + + test_env = os.environ.copy() + test_env["CODEFLASH_TEST_ITERATION"] = "0" + test_env["CODEFLASH_LOOP_INDEX"] = "3" + test_env["CODEFLASH_TEST_MODULE"] = "code_to_optimize.tests.pytest.test_async_multi_temp" + test_env["CODEFLASH_TEST_CLASS"] = "" + test_env["CODEFLASH_TEST_FUNCTION"] = "test_async_multi" + test_env["CODEFLASH_CURRENT_LINE_ID"] = "0" + test_type = TestType.EXISTING_UNIT_TEST + + func_optimizer = opt.create_function_optimizer(func) + func_optimizer.test_files = TestFiles( + test_files=[ + TestFile( + instrumented_behavior_file_path=test_path, + test_type=test_type, + original_file_path=test_path, + benchmarking_file_path=test_path_perf, + ) + ] + ) + + test_results, _ = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=2, + pytest_max_loops=5, + testing_time=0.2, + ) + + assert test_results is not None + assert test_results.test_results is not None + assert len(test_results.test_results) >= 2 + + results_list = test_results.test_results + function_calls = [r for r in results_list if r.id.function_getting_tested == "async_sorter"] + assert len(function_calls) == 2 + + first_call = function_calls[0] + second_call = function_calls[1] + + assert first_call.stdout == "codeflash stdout: Async sorting list\nresult: [3, 4, 5]\n" + assert second_call.stdout == "codeflash stdout: Async sorting list\nresult: [7, 9]\n" + + assert first_call.did_pass + assert second_call.did_pass + assert first_call.runtime is None or first_call.runtime >= 0 + assert second_call.runtime is None or second_call.runtime >= 0 + + finally: + fto_path.write_text(original_code, "utf-8") + if test_path.exists(): + test_path.unlink() + if test_path_perf.exists(): + test_path_perf.unlink() + + +def test_async_empty_input_edge_cases() -> None: + test_code = """import asyncio +import pytest +from code_to_optimize.async_bubble_sort import async_sorter + + +@pytest.mark.asyncio +async def test_async_edge_cases(): + # Empty list + empty = [] + result_empty = await async_sorter(empty) + assert result_empty == [] + + # Single item + single = [42] + result_single = await async_sorter(single) + assert result_single == [42] + + # Already sorted + sorted_list = [1, 2, 3, 4] + result_sorted = await async_sorter(sorted_list) + assert result_sorted == [1, 2, 3, 4]""" + + test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_edge_temp.py").resolve() + test_path_perf = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_async_edge_perf_temp.py").resolve() + fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/async_bubble_sort.py").resolve() + original_code = fto_path.read_text("utf-8") + + try: + with test_path.open("w") as f: + f.write(test_code) + + tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve() + project_root_path = (Path(__file__).parent / "..").resolve() + + func = FunctionToOptimize(function_name="async_sorter", parents=[], file_path=Path(fto_path), is_async=True) + + source_success, instrumented_source = instrument_source_module_with_async_decorators( + fto_path, func, TestingMode.BEHAVIOR + ) + + assert source_success + fto_path.write_text(instrumented_source, "utf-8") + instrument_codeflash_capture(func, {}, tests_root) + + opt = Optimizer( + Namespace( + project_root=project_root_path, + disable_telemetry=True, + tests_root=tests_root, + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=project_root_path, + ) + ) + + test_env = os.environ.copy() + test_env["CODEFLASH_TEST_ITERATION"] = "0" + test_env["CODEFLASH_LOOP_INDEX"] = "1" + test_env["CODEFLASH_TEST_MODULE"] = "code_to_optimize.tests.pytest.test_async_edge_temp" + test_env["CODEFLASH_TEST_CLASS"] = "" + test_env["CODEFLASH_TEST_FUNCTION"] = "test_async_edge_cases" + test_env["CODEFLASH_CURRENT_LINE_ID"] = "0" + test_type = TestType.EXISTING_UNIT_TEST + + func_optimizer = opt.create_function_optimizer(func) + func_optimizer.test_files = TestFiles( + test_files=[ + TestFile( + instrumented_behavior_file_path=test_path, + test_type=test_type, + original_file_path=test_path, + benchmarking_file_path=test_path_perf, + ) + ] + ) + + test_results, _ = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + + assert test_results is not None + assert test_results.test_results is not None + assert len(test_results.test_results) >= 3 # 3 function calls for edge cases + + results_list = test_results.test_results + function_calls = [r for r in results_list if r.id.function_getting_tested == "async_sorter"] + assert len(function_calls) == 3 + + # Verify all calls passed + for call in function_calls: + assert call.did_pass + assert call.runtime is None or call.runtime >= 0 + + empty_call = function_calls[0] + single_call = function_calls[1] + sorted_call = function_calls[2] + + assert empty_call.stdout == "codeflash stdout: Async sorting list\nresult: []\n" + assert single_call.stdout == "codeflash stdout: Async sorting list\nresult: [42]\n" + assert sorted_call.stdout == "codeflash stdout: Async sorting list\nresult: [1, 2, 3, 4]\n" + + finally: + fto_path.write_text(original_code, "utf-8") + if test_path.exists(): + test_path.unlink() + if test_path_perf.exists(): + test_path_perf.unlink() + + +def test_sync_function_behavior_in_async_test_environment() -> None: + sync_sorter_code = """def sync_sorter(lst): + \"\"\"Synchronous bubble sort for comparison.\"\"\" + print("codeflash stdout: Sync sorting list") + n = len(lst) + for i in range(n): + for j in range(0, n - i - 1): + if lst[j] > lst[j + 1]: + lst[j], lst[j + 1] = lst[j + 1], lst[j] + result = lst.copy() + print(f"result: {result}") + return result +""" + + test_code = """from code_to_optimize.sync_bubble_sort import sync_sorter + + +def test_sync_sort(): + input = [5, 4, 3, 2, 1, 0] + output = sync_sorter(input) + assert output == [0, 1, 2, 3, 4, 5] + + input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0] + output = sync_sorter(input) + assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]""" + + test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_sync_in_async_temp.py").resolve() + test_path_perf = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_sync_in_async_perf_temp.py").resolve() + sync_fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/sync_bubble_sort.py").resolve() + + try: + with sync_fto_path.open("w") as f: + f.write(sync_sorter_code) + + with test_path.open("w") as f: + f.write(test_code) + + tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve() + project_root_path = (Path(__file__).parent / "..").resolve() + + func = FunctionToOptimize(function_name="sync_sorter", parents=[], file_path=Path(sync_fto_path), is_async=False) + + original_cwd = os.getcwd() + run_cwd = project_root_path + os.chdir(run_cwd) + + success, instrumented_test = inject_profiling_into_existing_test( + test_path, + [CodePosition(6, 13), CodePosition(10, 13)], # Lines where sync_sorter is called + func, + project_root_path, + "pytest", + mode=TestingMode.BEHAVIOR, + ) + os.chdir(original_cwd) + + assert success + assert instrumented_test is not None + + with test_path.open("w") as f: + f.write(instrumented_test) + + instrument_codeflash_capture(func, {}, tests_root) + + opt = Optimizer( + Namespace( + project_root=project_root_path, + disable_telemetry=True, + tests_root=tests_root, + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=project_root_path, + ) + ) + + test_env = os.environ.copy() + test_env["CODEFLASH_TEST_ITERATION"] = "0" + test_env["CODEFLASH_LOOP_INDEX"] = "1" + test_env["CODEFLASH_TEST_MODULE"] = "code_to_optimize.tests.pytest.test_sync_in_async_temp" + test_env["CODEFLASH_TEST_CLASS"] = "" + test_env["CODEFLASH_TEST_FUNCTION"] = "test_sync_sort" + test_env["CODEFLASH_CURRENT_LINE_ID"] = "0" + test_type = TestType.EXISTING_UNIT_TEST + + func_optimizer = opt.create_function_optimizer(func) + func_optimizer.test_files = TestFiles( + test_files=[ + TestFile( + instrumented_behavior_file_path=test_path, + test_type=test_type, + original_file_path=test_path, + benchmarking_file_path=test_path_perf, + ) + ] + ) + + test_results, _ = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + + assert test_results is not None + assert test_results.test_results is not None + + results_list = test_results.test_results + assert results_list[0].id.function_getting_tested == "sync_sorter" + assert results_list[0].id.iteration_id == "1_0" + assert results_list[0].id.test_class_name is None + assert results_list[0].id.test_function_name == "test_sync_sort" + assert results_list[0].did_pass + assert results_list[0].runtime > 0 + + expected_stdout = "codeflash stdout: Sync sorting list\nresult: [0, 1, 2, 3, 4, 5]\n" + assert expected_stdout == results_list[0].stdout + + if len(results_list) > 1: + assert results_list[1].id.function_getting_tested == "sync_sorter" + assert results_list[1].id.iteration_id == "4_0" + assert results_list[1].id.test_function_name == "test_sync_sort" + assert results_list[1].did_pass + + expected_stdout2 = "codeflash stdout: Sync sorting list\nresult: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]\n" + assert expected_stdout2 == results_list[1].stdout + + finally: + if sync_fto_path.exists(): + sync_fto_path.unlink() + if test_path.exists(): + test_path.unlink() + if test_path_perf.exists(): + test_path_perf.unlink() + + +def test_mixed_async_sync_function_calls() -> None: + mixed_module_code = """import asyncio +from typing import List, Union + + +def sync_quick_sort(lst: List[Union[int, float]]) -> List[Union[int, float]]: + \"\"\"Synchronous quick sort.\"\"\" + print("codeflash stdout: Sync quick sort") + if len(lst) <= 1: + return lst.copy() + pivot = lst[len(lst) // 2] + left = [x for x in lst if x < pivot] + middle = [x for x in lst if x == pivot] + right = [x for x in lst if x > pivot] + result = sync_quick_sort(left) + middle + sync_quick_sort(right) + print(f"result: {result}") + return result + + +async def async_merge_sort(lst: List[Union[int, float]]) -> List[Union[int, float]]: + \"\"\"Asynchronous merge sort.\"\"\" + print("codeflash stdout: Async merge sort") + await asyncio.sleep(0.001) # Small delay + + if len(lst) <= 1: + return lst.copy() + + mid = len(lst) // 2 + left = await async_merge_sort(lst[:mid]) + right = await async_merge_sort(lst[mid:]) + + # Merge + result = [] + i = j = 0 + while i < len(left) and j < len(right): + if left[i] <= right[j]: + result.append(left[i]) + i += 1 + else: + result.append(right[j]) + j += 1 + result.extend(left[i:]) + result.extend(right[j:]) + + print(f"result: {result}") + return result + +""" + + test_code = """import asyncio +import pytest +from code_to_optimize.mixed_sort import sync_quick_sort, async_merge_sort + + +@pytest.mark.asyncio +async def test_mixed_sorting(): + # Test sync function + sync_input = [3, 1, 4, 1, 5] + sync_output = sync_quick_sort(sync_input) + assert sync_output == [1, 1, 3, 4, 5] + + # Test async function + async_input = [9, 2, 6, 5, 3] + async_output = await async_merge_sort(async_input) + assert async_output == [2, 3, 5, 6, 9]""" + + test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_mixed_sort_temp.py").resolve() + test_path_perf = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_mixed_sort_perf_temp.py").resolve() + mixed_fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/mixed_sort.py").resolve() + + try: + with mixed_fto_path.open("w") as f: + f.write(mixed_module_code) + + with test_path.open("w") as f: + f.write(test_code) + + tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve() + project_root_path = (Path(__file__).parent / "..").resolve() + + async_func = FunctionToOptimize(function_name="async_merge_sort", parents=[], file_path=Path(mixed_fto_path), is_async=True) + + source_success, instrumented_source = instrument_source_module_with_async_decorators( + mixed_fto_path, async_func, TestingMode.BEHAVIOR + ) + + assert source_success + assert instrumented_source is not None + assert "@codeflash_behavior_async" in instrumented_source + assert "async def async_merge_sort" in instrumented_source + assert "def sync_quick_sort" in instrumented_source # Should preserve sync function + + mixed_fto_path.write_text(instrumented_source, "utf-8") + instrument_codeflash_capture(async_func, {}, tests_root) + + opt = Optimizer( + Namespace( + project_root=project_root_path, + disable_telemetry=True, + tests_root=tests_root, + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=project_root_path, + ) + ) + + test_env = os.environ.copy() + test_env["CODEFLASH_TEST_ITERATION"] = "0" + test_env["CODEFLASH_LOOP_INDEX"] = "1" + test_env["CODEFLASH_TEST_MODULE"] = "code_to_optimize.tests.pytest.test_mixed_sort_temp" + test_env["CODEFLASH_TEST_CLASS"] = "" + test_env["CODEFLASH_TEST_FUNCTION"] = "test_mixed_sorting" + test_env["CODEFLASH_CURRENT_LINE_ID"] = "0" + test_type = TestType.EXISTING_UNIT_TEST + + func_optimizer = opt.create_function_optimizer(async_func) + func_optimizer.test_files = TestFiles( + test_files=[ + TestFile( + instrumented_behavior_file_path=test_path, + test_type=test_type, + original_file_path=test_path, + benchmarking_file_path=test_path_perf, + ) + ] + ) + + test_results, _ = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + + assert test_results is not None + assert test_results.test_results is not None + + results_list = test_results.test_results + async_calls = [r for r in results_list if r.id.function_getting_tested == "async_merge_sort"] + assert len(async_calls) >= 1 + + for call in async_calls: + assert call.did_pass + assert call.runtime is None or call.runtime >= 0 + assert "codeflash stdout: Async merge sort" in call.stdout + + finally: + if mixed_fto_path.exists(): + mixed_fto_path.unlink() + if test_path.exists(): + test_path.unlink() + if test_path_perf.exists(): + test_path_perf.unlink() \ No newline at end of file diff --git a/tests/test_async_wrapper_sqlite_validation.py b/tests/test_async_wrapper_sqlite_validation.py new file mode 100644 index 000000000..5cf7252f6 --- /dev/null +++ b/tests/test_async_wrapper_sqlite_validation.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +import asyncio +import os +import sqlite3 +import tempfile +from pathlib import Path + +import pytest +import dill as pickle + +from codeflash.code_utils.codeflash_wrap_decorator import ( + codeflash_behavior_async, + codeflash_performance_async, +) +from codeflash.verification.codeflash_capture import VerificationType + + +class TestAsyncWrapperSQLiteValidation: + + @pytest.fixture + def test_env_setup(self, request): + original_env = {} + test_env = { + "CODEFLASH_LOOP_INDEX": "1", + "CODEFLASH_TEST_ITERATION": "0", + "CODEFLASH_TEST_MODULE": __name__, + "CODEFLASH_TEST_CLASS": "TestAsyncWrapperSQLiteValidation", + "CODEFLASH_TEST_FUNCTION": request.node.name, + "CODEFLASH_CURRENT_LINE_ID": "test_unit", + } + + for key, value in test_env.items(): + original_env[key] = os.environ.get(key) + os.environ[key] = value + + yield test_env + + for key, original_value in original_env.items(): + if original_value is None: + os.environ.pop(key, None) + else: + os.environ[key] = original_value + + @pytest.fixture + def temp_db_path(self, test_env_setup): + iteration = test_env_setup["CODEFLASH_TEST_ITERATION"] + from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file + db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite")) + + yield db_path + + if db_path.exists(): + db_path.unlink() + + @pytest.mark.asyncio + async def test_behavior_async_basic_function(self, test_env_setup, temp_db_path): + + @codeflash_behavior_async + async def simple_async_add(a: int, b: int) -> int: + await asyncio.sleep(0.001) + return a + b + + os.environ['CODEFLASH_CURRENT_LINE_ID'] = 'simple_async_add_59' + result = await simple_async_add(5, 3) + + assert result == 8 + + assert temp_db_path.exists() + + con = sqlite3.connect(temp_db_path) + cur = con.cursor() + + cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test_results'") + assert cur.fetchone() is not None + + cur.execute("SELECT * FROM test_results") + rows = cur.fetchall() + + assert len(rows) == 1 + row = rows[0] + + (test_module_path, test_class_name, test_function_name, function_getting_tested, + loop_index, iteration_id, runtime, return_value_blob, verification_type) = row + + assert test_module_path == __name__ + assert test_class_name == "TestAsyncWrapperSQLiteValidation" + assert test_function_name == "test_behavior_async_basic_function" + assert function_getting_tested == "simple_async_add" + assert loop_index == 1 + # Line ID will be the actual line number from the source code, not a simple counter + assert iteration_id.startswith("simple_async_add_") and iteration_id.endswith("_0") + assert runtime > 0 + assert verification_type == VerificationType.FUNCTION_CALL.value + + unpickled_data = pickle.loads(return_value_blob) + args, kwargs, return_val = unpickled_data + + assert args == (5, 3) + assert kwargs == {} + assert return_val == 8 + + con.close() + + @pytest.mark.asyncio + async def test_behavior_async_exception_handling(self, test_env_setup, temp_db_path): + + @codeflash_behavior_async + async def async_divide(a: int, b: int) -> float: + await asyncio.sleep(0.001) + if b == 0: + raise ValueError("Cannot divide by zero") + return a / b + + result = await async_divide(10, 2) + assert result == 5.0 + + with pytest.raises(ValueError, match="Cannot divide by zero"): + await async_divide(10, 0) + + con = sqlite3.connect(temp_db_path) + cur = con.cursor() + cur.execute("SELECT * FROM test_results ORDER BY iteration_id") + rows = cur.fetchall() + + assert len(rows) == 2 + + success_row = rows[0] + success_data = pickle.loads(success_row[7]) # return_value_blob + args, kwargs, return_val = success_data + assert args == (10, 2) + assert return_val == 5.0 + + # Check exception record + exception_row = rows[1] + exception_data = pickle.loads(exception_row[7]) # return_value_blob + assert isinstance(exception_data, ValueError) + assert str(exception_data) == "Cannot divide by zero" + + con.close() + + @pytest.mark.asyncio + async def test_performance_async_no_database_storage(self, test_env_setup, temp_db_path, capsys): + """Test performance async decorator doesn't store to database.""" + + @codeflash_performance_async + async def async_multiply(a: int, b: int) -> int: + """Async function for performance testing.""" + await asyncio.sleep(0.002) + return a * b + + result = await async_multiply(4, 7) + + assert result == 28 + + assert not temp_db_path.exists() + + captured = capsys.readouterr() + output_lines = captured.out.strip().split('\n') + + assert len([line for line in output_lines if "!$######" in line]) == 1 + assert len([line for line in output_lines if "!######" in line and "######!" in line]) == 1 + + closing_tag = [line for line in output_lines if "!######" in line and "######!" in line][0] + assert "async_multiply" in closing_tag + + timing_part = closing_tag.split(":")[-1].replace("######!", "") + timing_value = int(timing_part) + assert timing_value > 0 # Should have positive timing + + @pytest.mark.asyncio + async def test_multiple_calls_indexing(self, test_env_setup, temp_db_path): + + @codeflash_behavior_async + async def async_increment(value: int) -> int: + await asyncio.sleep(0.001) + return value + 1 + + # Call the function multiple times + results = [] + for i in range(3): + result = await async_increment(i) + results.append(result) + + assert results == [1, 2, 3] + + con = sqlite3.connect(temp_db_path) + cur = con.cursor() + cur.execute("SELECT iteration_id, return_value FROM test_results ORDER BY iteration_id") + rows = cur.fetchall() + + assert len(rows) == 3 + + actual_ids = [row[0] for row in rows] + assert len(actual_ids) == 3 + + base_pattern = actual_ids[0].rsplit('_', 1)[0] # e.g., "async_increment_199" + expected_pattern = [f"{base_pattern}_{i}" for i in range(3)] + assert actual_ids == expected_pattern + + for i, (_, return_value_blob) in enumerate(rows): + args, kwargs, return_val = pickle.loads(return_value_blob) + assert args == (i,) + assert return_val == i + 1 + + con.close() + + @pytest.mark.asyncio + async def test_complex_async_function_with_kwargs(self, test_env_setup, temp_db_path): + + @codeflash_behavior_async + async def complex_async_func( + pos_arg: str, + *args: int, + keyword_arg: str = "default", + **kwargs: str + ) -> dict: + await asyncio.sleep(0.001) + return { + "pos_arg": pos_arg, + "args": args, + "keyword_arg": keyword_arg, + "kwargs": kwargs, + } + + result = await complex_async_func( + "hello", + 1, 2, 3, + keyword_arg="custom", + extra1="value1", + extra2="value2" + ) + + expected_result = { + "pos_arg": "hello", + "args": (1, 2, 3), + "keyword_arg": "custom", + "kwargs": {"extra1": "value1", "extra2": "value2"} + } + + assert result == expected_result + + con = sqlite3.connect(temp_db_path) + cur = con.cursor() + cur.execute("SELECT return_value FROM test_results") + row = cur.fetchone() + + stored_args, stored_kwargs, stored_result = pickle.loads(row[0]) + + assert stored_args == ("hello", 1, 2, 3) + assert stored_kwargs == {"keyword_arg": "custom", "extra1": "value1", "extra2": "value2"} + assert stored_result == expected_result + + con.close() + + @pytest.mark.asyncio + async def test_database_schema_validation(self, test_env_setup, temp_db_path): + + @codeflash_behavior_async + async def schema_test_func() -> str: + return "schema_test" + + await schema_test_func() + + con = sqlite3.connect(temp_db_path) + cur = con.cursor() + + cur.execute("PRAGMA table_info(test_results)") + columns = cur.fetchall() + + expected_columns = [ + (0, 'test_module_path', 'TEXT', 0, None, 0), + (1, 'test_class_name', 'TEXT', 0, None, 0), + (2, 'test_function_name', 'TEXT', 0, None, 0), + (3, 'function_getting_tested', 'TEXT', 0, None, 0), + (4, 'loop_index', 'INTEGER', 0, None, 0), + (5, 'iteration_id', 'TEXT', 0, None, 0), + (6, 'runtime', 'INTEGER', 0, None, 0), + (7, 'return_value', 'BLOB', 0, None, 0), + (8, 'verification_type', 'TEXT', 0, None, 0) + ] + + assert columns == expected_columns + con.close() + diff --git a/tests/test_instrument_async_tests.py b/tests/test_instrument_async_tests.py new file mode 100644 index 000000000..1149f42f2 --- /dev/null +++ b/tests/test_instrument_async_tests.py @@ -0,0 +1,793 @@ +import tempfile +from pathlib import Path +import uuid +import os + +import pytest + +from codeflash.code_utils.instrument_existing_tests import ( + add_async_decorator_to_function, + inject_profiling_into_existing_test, +) +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.models.models import CodePosition, TestingMode + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for test files.""" + with tempfile.TemporaryDirectory() as temp: + yield Path(temp) + + +# @pytest.fixture +# def unique_test_iteration(): +# """Provide a unique test iteration ID and clean up database after test.""" +# # Generate unique iteration ID +# iteration_id = str(uuid.uuid4())[:8] + +# # Store original environment variable +# original_iteration = os.environ.get("CODEFLASH_TEST_ITERATION") + +# # Set unique iteration for this test +# os.environ["CODEFLASH_TEST_ITERATION"] = iteration_id + +# try: +# yield iteration_id +# finally: +# # Cleanup: restore original environment and delete database file +# if original_iteration is not None: +# os.environ["CODEFLASH_TEST_ITERATION"] = original_iteration +# elif "CODEFLASH_TEST_ITERATION" in os.environ: +# del os.environ["CODEFLASH_TEST_ITERATION"] + +# # Clean up database file +# try: +# from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file + +# db_path = get_run_tmp_file(Path(f"test_return_values_{iteration_id}.sqlite")) +# if db_path.exists(): +# db_path.unlink() +# except Exception: +# pass # Ignore cleanup errors + + +def test_async_decorator_application_behavior_mode(): + async_function_code = ''' +import asyncio + +async def async_function(x: int, y: int) -> int: + """Simple async function for testing.""" + await asyncio.sleep(0.01) + return x * y +''' + + expected_decorated_code = ''' +import asyncio + +from codeflash.code_utils.codeflash_wrap_decorator import \\ + codeflash_behavior_async + + +@codeflash_behavior_async +async def async_function(x: int, y: int) -> int: + """Simple async function for testing.""" + await asyncio.sleep(0.01) + return x * y +''' + + func = FunctionToOptimize( + function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True + ) + + modified_code, decorator_added = add_async_decorator_to_function(async_function_code, func, TestingMode.BEHAVIOR) + + assert decorator_added + assert modified_code.strip() == expected_decorated_code.strip() + + +def test_async_decorator_application_performance_mode(): + async_function_code = ''' +import asyncio + +async def async_function(x: int, y: int) -> int: + """Simple async function for testing.""" + await asyncio.sleep(0.01) + return x * y +''' + + expected_decorated_code = ''' +import asyncio + +from codeflash.code_utils.codeflash_wrap_decorator import \\ + codeflash_performance_async + + +@codeflash_performance_async +async def async_function(x: int, y: int) -> int: + """Simple async function for testing.""" + await asyncio.sleep(0.01) + return x * y +''' + + func = FunctionToOptimize( + function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True + ) + + modified_code, decorator_added = add_async_decorator_to_function(async_function_code, func, TestingMode.PERFORMANCE) + + assert decorator_added + assert modified_code.strip() == expected_decorated_code.strip() + + +def test_async_class_method_decorator_application(): + async_class_code = ''' +import asyncio + +class Calculator: + """Test class with async methods.""" + + async def async_method(self, a: int, b: int) -> int: + """Async method in class.""" + await asyncio.sleep(0.005) + return a ** b + + def sync_method(self, a: int, b: int) -> int: + """Sync method in class.""" + return a - b +''' + + expected_decorated_code = ''' +import asyncio + +from codeflash.code_utils.codeflash_wrap_decorator import \\ + codeflash_behavior_async + + +class Calculator: + """Test class with async methods.""" + + @codeflash_behavior_async + async def async_method(self, a: int, b: int) -> int: + """Async method in class.""" + await asyncio.sleep(0.005) + return a ** b + + def sync_method(self, a: int, b: int) -> int: + """Sync method in class.""" + return a - b +''' + + func = FunctionToOptimize( + function_name="async_method", + file_path=Path("test_async.py"), + parents=[{"name": "Calculator", "type": "ClassDef"}], + is_async=True, + ) + + modified_code, decorator_added = add_async_decorator_to_function(async_class_code, func, TestingMode.BEHAVIOR) + + assert decorator_added + assert modified_code.strip() == expected_decorated_code.strip() + + +def test_async_decorator_no_duplicate_application(): + already_decorated_code = ''' +from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async +import asyncio + +@codeflash_behavior_async +async def async_function(x: int, y: int) -> int: + """Already decorated async function.""" + await asyncio.sleep(0.01) + return x * y +''' + + expected_reformatted_code = ''' +import asyncio + +from codeflash.code_utils.codeflash_wrap_decorator import \\ + codeflash_behavior_async + + +@codeflash_behavior_async +async def async_function(x: int, y: int) -> int: + """Already decorated async function.""" + await asyncio.sleep(0.01) + return x * y +''' + + func = FunctionToOptimize( + function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True + ) + + modified_code, decorator_added = add_async_decorator_to_function(already_decorated_code, func, TestingMode.BEHAVIOR) + + assert not decorator_added + assert modified_code.strip() == expected_reformatted_code.strip() + + +def test_inject_profiling_async_function_behavior_mode(temp_dir): + source_module_code = ''' +import asyncio + +async def async_function(x: int, y: int) -> int: + """Simple async function for testing.""" + await asyncio.sleep(0.01) + return x * y +''' + + source_file = temp_dir / "my_module.py" + source_file.write_text(source_module_code) + + async_test_code = ''' +import asyncio +import pytest +from my_module import async_function + +@pytest.mark.asyncio +async def test_async_function(): + """Test async function behavior.""" + result = await async_function(5, 3) + assert result == 15 + + result2 = await async_function(2, 4) + assert result2 == 8 +''' + + test_file = temp_dir / "test_async.py" + test_file.write_text(async_test_code) + + func = FunctionToOptimize(function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True) + + # First instrument the source module + from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators + + source_success, instrumented_source = instrument_source_module_with_async_decorators( + source_file, func, TestingMode.BEHAVIOR + ) + + assert source_success is True + assert instrumented_source is not None + assert "@codeflash_behavior_async" in instrumented_source + assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source + assert "codeflash_behavior_async" in instrumented_source + + source_file.write_text(instrumented_source) + + success, instrumented_test_code = inject_profiling_into_existing_test( + test_file, [CodePosition(8, 18), CodePosition(11, 19)], func, temp_dir, "pytest", mode=TestingMode.BEHAVIOR + ) + + # For async functions, once source is decorated, test injection should fail + # This is expected behavior - async instrumentation happens at the decorator level + assert success is False + assert instrumented_test_code is None + + +def test_inject_profiling_async_function_performance_mode(temp_dir): + source_module_code = ''' +import asyncio + +async def async_function(x: int, y: int) -> int: + """Simple async function for testing.""" + await asyncio.sleep(0.01) + return x * y +''' + + source_file = temp_dir / "my_module.py" + source_file.write_text(source_module_code) + + # Create the test file + async_test_code = ''' +import asyncio +import pytest +from my_module import async_function + +@pytest.mark.asyncio +async def test_async_function(): + """Test async function performance.""" + result = await async_function(5, 3) + assert result == 15 +''' + + test_file = temp_dir / "test_async.py" + test_file.write_text(async_test_code) + + func = FunctionToOptimize(function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True) + + # First instrument the source module + from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators + + source_success, instrumented_source = instrument_source_module_with_async_decorators( + source_file, func, TestingMode.PERFORMANCE + ) + + assert source_success is True + assert instrumented_source is not None + assert "@codeflash_performance_async" in instrumented_source + # Check for the import with line continuation formatting + assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source + assert "codeflash_performance_async" in instrumented_source + + # Write the instrumented source back + source_file.write_text(instrumented_source) + + # Now test the full pipeline with source module path + success, instrumented_test_code = inject_profiling_into_existing_test( + test_file, [CodePosition(8, 18)], func, temp_dir, "pytest", mode=TestingMode.PERFORMANCE + ) + + # For async functions, once source is decorated, test injection should fail + # This is expected behavior - async instrumentation happens at the decorator level + assert success is False + assert instrumented_test_code is None + + +def test_mixed_sync_async_instrumentation(temp_dir): + source_module_code = ''' +import asyncio + +def sync_function(x: int, y: int) -> int: + """Regular sync function.""" + return x * y + +async def async_function(x: int, y: int) -> int: + """Simple async function.""" + await asyncio.sleep(0.01) + return x * y +''' + + source_file = temp_dir / "my_module.py" + source_file.write_text(source_module_code) + + mixed_test_code = ''' +import asyncio +import pytest +from my_module import sync_function, async_function + +@pytest.mark.asyncio +async def test_mixed_functions(): + """Test both sync and async functions.""" + sync_result = sync_function(10, 5) + assert sync_result == 50 + + async_result = await async_function(3, 4) + assert async_result == 12 +''' + + test_file = temp_dir / "test_mixed.py" + test_file.write_text(mixed_test_code) + + async_func = FunctionToOptimize( + function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True + ) + + from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators + + source_success, instrumented_source = instrument_source_module_with_async_decorators( + source_file, async_func, TestingMode.BEHAVIOR + ) + + assert source_success + assert instrumented_source is not None + assert "@codeflash_behavior_async" in instrumented_source + assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source + assert "codeflash_behavior_async" in instrumented_source + # Sync function should remain unchanged + assert "def sync_function(x: int, y: int) -> int:" in instrumented_source + + # Write instrumented source back + source_file.write_text(instrumented_source) + + success, instrumented_test_code = inject_profiling_into_existing_test( + test_file, + [CodePosition(8, 18), CodePosition(11, 19)], + async_func, + temp_dir, + "pytest", + mode=TestingMode.BEHAVIOR, + ) + + # Async functions should not be instrumented at the test level + assert not success + assert instrumented_test_code is None + + +def test_async_function_qualified_name_handling(): + nested_async_code = ''' +import asyncio + +class OuterClass: + class InnerClass: + async def nested_async_method(self, x: int) -> int: + """Nested async method.""" + await asyncio.sleep(0.001) + return x * 2 +''' + + func = FunctionToOptimize( + function_name="nested_async_method", + file_path=Path("test_nested.py"), + parents=[{"name": "OuterClass", "type": "ClassDef"}, {"name": "InnerClass", "type": "ClassDef"}], + is_async=True, + ) + + modified_code, decorator_added = add_async_decorator_to_function(nested_async_code, func, TestingMode.BEHAVIOR) + + expected_output = ( + """import asyncio + +from codeflash.code_utils.codeflash_wrap_decorator import \\ + codeflash_behavior_async + + +class OuterClass: + class InnerClass: + @codeflash_behavior_async + async def nested_async_method(self, x: int) -> int: + \"\"\"Nested async method.\"\"\" + await asyncio.sleep(0.001) + return x * 2 +""" + ) + + assert modified_code.strip() == expected_output.strip() + + +def test_async_decorator_with_existing_decorators(): + """Test async decorator application when function already has other decorators.""" + decorated_async_code = ''' +import asyncio +from functools import wraps + +def my_decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + return await func(*args, **kwargs) + return wrapper + +@my_decorator +async def async_function(x: int, y: int) -> int: + """Async function with existing decorator.""" + await asyncio.sleep(0.01) + return x * y +''' + + func = FunctionToOptimize( + function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True + ) + + modified_code, decorator_added = add_async_decorator_to_function(decorated_async_code, func, TestingMode.BEHAVIOR) + + assert decorator_added + # Should add codeflash decorator above existing decorators + assert "@codeflash_behavior_async" in modified_code + assert "@my_decorator" in modified_code + # Codeflash decorator should come first + codeflash_pos = modified_code.find("@codeflash_behavior_async") + my_decorator_pos = modified_code.find("@my_decorator") + assert codeflash_pos < my_decorator_pos + + +def test_sync_function_not_affected_by_async_logic(): + sync_function_code = ''' +def sync_function(x: int, y: int) -> int: + """Regular sync function.""" + return x + y +''' + + sync_func = FunctionToOptimize( + function_name="sync_function", + file_path=Path("test_sync.py"), + parents=[], + is_async=False, + ) + + modified_code, decorator_added = add_async_decorator_to_function( + sync_function_code, sync_func, TestingMode.BEHAVIOR + ) + + assert not decorator_added + assert modified_code == sync_function_code + +def test_inject_profiling_async_multiple_calls_same_test(temp_dir): + """Test that multiple async function calls within the same test function get correctly numbered 0, 1, 2, etc.""" + source_module_code = ''' +import asyncio + +async def async_sorter(items): + """Simple async sorter for testing.""" + await asyncio.sleep(0.001) + return sorted(items) +''' + + source_file = temp_dir / "async_sorter.py" + source_file.write_text(source_module_code) + + test_code_multiple_calls = """ +import asyncio +import pytest +from async_sorter import async_sorter + +@pytest.mark.asyncio +async def test_single_call(): + result = await async_sorter([42]) + assert result == [42] + +@pytest.mark.asyncio +async def test_multiple_calls(): + result1 = await async_sorter([3, 1, 2]) + result2 = await async_sorter([5, 4]) + result3 = await async_sorter([9, 8, 7, 6]) + assert result1 == [1, 2, 3] + assert result2 == [4, 5] + assert result3 == [6, 7, 8, 9] +""" + + test_file = temp_dir / "test_async_sorter.py" + test_file.write_text(test_code_multiple_calls) + + func = FunctionToOptimize( + function_name="async_sorter", parents=[], file_path=Path("async_sorter.py"), is_async=True + ) + + # First instrument the source module with async decorators + from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators + + source_success, instrumented_source = instrument_source_module_with_async_decorators( + source_file, func, TestingMode.BEHAVIOR + ) + + assert source_success + assert instrumented_source is not None + assert "@codeflash_behavior_async" in instrumented_source + + source_file.write_text(instrumented_source) + + import ast + + tree = ast.parse(test_code_multiple_calls) + call_positions = [] + for node in ast.walk(tree): + if isinstance(node, ast.Await) and isinstance(node.value, ast.Call): + if (hasattr(node.value.func, "id") and node.value.func.id == "async_sorter") or ( + hasattr(node.value.func, "attr") and node.value.func.attr == "async_sorter" + ): + call_positions.append(CodePosition(node.lineno, node.col_offset)) + + assert len(call_positions) == 4 + + success, instrumented_test_code = inject_profiling_into_existing_test( + test_file, call_positions, func, temp_dir, "pytest", mode=TestingMode.BEHAVIOR + ) + + assert success + assert instrumented_test_code is not None + + assert "os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'" in instrumented_test_code + + # Count occurrences of each line_id to verify numbering + line_id_0_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'") + line_id_1_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '1'") + line_id_2_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '2'") + + + assert line_id_0_count == 2, f"Expected 2 occurrences of line_id '0', got {line_id_0_count}" + assert line_id_1_count == 1, f"Expected 1 occurrence of line_id '1', got {line_id_1_count}" + assert line_id_2_count == 1, f"Expected 1 occurrence of line_id '2', got {line_id_2_count}" + + + +def test_async_behavior_decorator_return_values_and_test_ids(): + """Test that async behavior decorator correctly captures return values, test IDs, and stores data in database.""" + import asyncio + import os + import sqlite3 + from pathlib import Path + + import dill as pickle + + from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async + + @codeflash_behavior_async + async def test_async_multiply(x: int, y: int) -> int: + """Simple async function for testing.""" + await asyncio.sleep(0.001) # Small delay to simulate async work + return x * y + + test_env = { + "CODEFLASH_TEST_MODULE": "test_module", + "CODEFLASH_TEST_CLASS": None, + "CODEFLASH_TEST_FUNCTION": "test_async_multiply_function", + "CODEFLASH_CURRENT_LINE_ID": "0", + "CODEFLASH_LOOP_INDEX": "1", + "CODEFLASH_TEST_ITERATION": "2", + } + + original_env = {k: os.environ.get(k) for k in test_env} + for k, v in test_env.items(): + if v is not None: + os.environ[k] = v + elif k in os.environ: + del os.environ[k] + + try: + result = asyncio.run(test_async_multiply(6, 7)) + + assert result == 42, f"Expected return value 42, got {result}" + + from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file + + db_path = get_run_tmp_file(Path(f"test_return_values_2.sqlite")) + + # Verify database exists and has data + assert db_path.exists(), f"Database file not created at {db_path}" + + # Read and verify database contents + con = sqlite3.connect(db_path) + cur = con.cursor() + + cur.execute("SELECT * FROM test_results") + rows = cur.fetchall() + + assert len(rows) == 1, f"Expected 1 database row, got {len(rows)}" + + row = rows[0] + ( + test_module, + test_class, + test_function, + function_name, + loop_index, + iteration_id, + runtime, + return_value_blob, + verification_type, + ) = row + + assert test_module == "test_module", f"Expected test_module 'test_module', got '{test_module}'" + assert test_class is None, f"Expected test_class None, got '{test_class}'" + assert test_function == "test_async_multiply_function", ( + f"Expected test_function 'test_async_multiply_function', got '{test_function}'" + ) + assert function_name == "test_async_multiply", ( + f"Expected function_name 'test_async_multiply', got '{function_name}'" + ) + assert loop_index == 1, f"Expected loop_index 1, got {loop_index}" + assert iteration_id == "0_0", f"Expected iteration_id '0_0', got '{iteration_id}'" + assert verification_type == "function_call", ( + f"Expected verification_type 'function_call', got '{verification_type}'" + ) + unpickled_data = pickle.loads(return_value_blob) + args, kwargs, actual_return_value = unpickled_data + + assert args == (6, 7), f"Expected args (6, 7), got {args}" + assert kwargs == {}, f"Expected empty kwargs, got {kwargs}" + + assert actual_return_value == 42, f"Expected stored return value 42, got {actual_return_value}" + + con.close() + + finally: + for k, v in original_env.items(): + if v is not None: + os.environ[k] = v + elif k in os.environ: + del os.environ[k] + + +def test_async_decorator_comprehensive_return_values_and_test_ids(): + import asyncio + import os + import sqlite3 + from pathlib import Path + + import dill as pickle + + from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async, get_run_tmp_file + + @codeflash_behavior_async + async def async_multiply_add(x: int, y: int, z: int = 1) -> int: + """Async function that multiplies x*y then adds z.""" + await asyncio.sleep(0.001) + result = (x * y) + z + return result + + test_env = { + "CODEFLASH_TEST_MODULE": "test_comprehensive_module", + "CODEFLASH_TEST_CLASS": "AsyncTestClass", + "CODEFLASH_TEST_FUNCTION": "test_comprehensive_async_function", + "CODEFLASH_CURRENT_LINE_ID": "3", + "CODEFLASH_LOOP_INDEX": "2", + "CODEFLASH_TEST_ITERATION": "3", + } + + original_env = {k: os.environ.get(k) for k in test_env} + for k, v in test_env.items(): + if v is not None: + os.environ[k] = v + elif k in os.environ: + del os.environ[k] + + try: + test_cases = [ + {"args": (5, 3), "kwargs": {}, "expected": 16}, # (5 * 3) + 1 = 16 + {"args": (2, 4), "kwargs": {"z": 10}, "expected": 18}, # (2 * 4) + 10 = 18 + {"args": (7, 6), "kwargs": {}, "expected": 43}, # (7 * 6) + 1 = 43 + ] + + results = [] + for test_case in test_cases: + result = asyncio.run(async_multiply_add(*test_case["args"], **test_case["kwargs"])) + results.append(result) + + # Verify each return value is exactly correct + assert result == test_case["expected"], ( + f"Expected {test_case['expected']}, got {result} for args {test_case['args']}, kwargs {test_case['kwargs']}" + ) + + db_path = get_run_tmp_file(Path(f"test_return_values_3.sqlite")) + assert db_path.exists(), f"Database not created at {db_path}" + + con = sqlite3.connect(db_path) + cur = con.cursor() + + cur.execute( + "SELECT test_module_path, test_class_name, test_function_name, function_getting_tested, loop_index, iteration_id, runtime, return_value, verification_type FROM test_results ORDER BY rowid" + ) + rows = cur.fetchall() + + assert len(rows) == 3, f"Expected 3 database rows, got {len(rows)}" + + for i, ( + test_module, + test_class, + test_function, + function_name, + loop_index, + iteration_id, + runtime, + return_value_blob, + verification_type, + ) in enumerate(rows): + assert test_module == "test_comprehensive_module", ( + f"Row {i}: Expected test_module 'test_comprehensive_module', got '{test_module}'" + ) + assert test_class == "AsyncTestClass", f"Row {i}: Expected test_class 'AsyncTestClass', got '{test_class}'" + assert test_function == "test_comprehensive_async_function", ( + f"Row {i}: Expected test_function 'test_comprehensive_async_function', got '{test_function}'" + ) + assert function_name == "async_multiply_add", ( + f"Row {i}: Expected function_name 'async_multiply_add', got '{function_name}'" + ) + assert loop_index == 2, f"Row {i}: Expected loop_index 2, got {loop_index}" + assert verification_type == "function_call", ( + f"Row {i}: Expected verification_type 'function_call', got '{verification_type}'" + ) + + expected_iteration_id = f"3_{i}" + assert iteration_id == expected_iteration_id, ( + f"Row {i}: Expected iteration_id '{expected_iteration_id}', got '{iteration_id}'" + ) + + + args, kwargs, actual_return_value = pickle.loads(return_value_blob) + expected_args = test_cases[i]["args"] + expected_kwargs = test_cases[i]["kwargs"] + expected_return = test_cases[i]["expected"] + + assert args == expected_args, f"Row {i}: Expected args {expected_args}, got {args}" + assert kwargs == expected_kwargs, f"Row {i}: Expected kwargs {expected_kwargs}, got {kwargs}" + assert actual_return_value == expected_return, ( + f"Row {i}: Expected return value {expected_return}, got {actual_return_value}" + ) + + con.close() + + finally: + for k, v in original_env.items(): + if v is not None: + os.environ[k] = v + elif k in os.environ: + del os.environ[k]