Merge branch 'main' into part-1-windows-fixes

Kevin Turcios 2025-09-29 14:46:25 -07:00
commit f978a406bb
45 changed files with 5834 additions and 838 deletions

.github/workflows/e2e-async.yaml (new file, 69 lines)

@@ -0,0 +1,69 @@
name: E2E - Async
on:
  pull_request:
    paths:
      - '**' # Trigger for all paths
  workflow_dispatch:
jobs:
  async-optimization:
    # Dynamically determine if environment is needed only when workflow files change and contributor is external
    environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }}
    runs-on: ubuntu-latest
    env:
      CODEFLASH_AIS_SERVER: prod
      POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }}
      CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
      COLUMNS: 110
      MAX_RETRIES: 3
      RETRY_DELAY: 5
      EXPECTED_IMPROVEMENT_PCT: 10
      CODEFLASH_END_TO_END: 1
    steps:
      - name: 🛎️ Checkout
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.pull_request.head.ref }}
          repository: ${{ github.event.pull_request.head.repo.full_name }}
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}
      - name: Validate PR
        run: |
          # Check for any workflow changes
          if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" | grep -q "^.github/workflows/"; then
            echo "⚠️ Workflow changes detected."
            # Get the PR author
            AUTHOR="${{ github.event.pull_request.user.login }}"
            echo "PR Author: $AUTHOR"
            # Allowlist check
            if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
              echo "✅ Authorized user ($AUTHOR). Proceeding."
            elif [[ "${{ github.event.pull_request.state }}" == "open" ]]; then
              echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
            else
              echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
              exit 1
            fi
          else
            echo "✅ No workflow file changes detected. Proceeding."
          fi
      - name: Set up Python 3.11 for CLI
        uses: astral-sh/setup-uv@v5
        with:
          python-version: 3.11.6
      - name: Install dependencies (CLI)
        run: |
          uv sync
      - name: Run Codeflash to optimize async code
        id: optimize_async_code
        run: |
          uv run python tests/scripts/end_to_end_test_async.py
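The "Validate PR" step above gates workflow edits behind an author allowlist. A minimal Python sketch of the same rule, with hypothetical inputs (the real check runs in bash against the GitHub event payload):

ALLOWLIST = {"misrasaurabh1", "KRRT7"}

def may_modify_workflows(author: str, changed_files: list[str], pr_open: bool) -> bool:
    # Only gate PRs that actually touch workflow definitions.
    if not any(f.startswith(".github/workflows/") for f in changed_files):
        return True
    # Allowlisted authors pass outright; otherwise an open PR is assumed protected.
    return author in ALLOWLIST or pr_open

print(may_modify_workflows("someone", [".github/workflows/e2e-async.yaml"], pr_open=True))  # True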


@@ -20,7 +20,7 @@ jobs:
       COLUMNS: 110
       MAX_RETRIES: 3
       RETRY_DELAY: 5
-      EXPECTED_IMPROVEMENT_PCT: 300
+      EXPECTED_IMPROVEMENT_PCT: 70
       CODEFLASH_END_TO_END: 1
     steps:
       - name: 🛎️ Checkout


@@ -20,7 +20,7 @@ jobs:
       COLUMNS: 110
       MAX_RETRIES: 3
       RETRY_DELAY: 5
-      EXPECTED_IMPROVEMENT_PCT: 300
+      EXPECTED_IMPROVEMENT_PCT: 40
       CODEFLASH_END_TO_END: 1
     steps:
       - name: 🛎️ Checkout


@@ -0,0 +1,43 @@
import asyncio
from typing import List, Union


async def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:
    """
    Async bubble sort implementation for testing.
    """
    print("codeflash stdout: Async sorting list")
    await asyncio.sleep(0.01)
    n = len(lst)
    for i in range(n):
        for j in range(0, n - i - 1):
            if lst[j] > lst[j + 1]:
                lst[j], lst[j + 1] = lst[j + 1], lst[j]
    result = lst.copy()
    print(f"result: {result}")
    return result


class AsyncBubbleSorter:
    """Class with async sorting method for testing."""

    async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:
        """
        Async bubble sort implementation within a class.
        """
        print("codeflash stdout: AsyncBubbleSorter.sorter() called")
        # Add some async delay
        await asyncio.sleep(0.005)
        n = len(lst)
        for i in range(n):
            for j in range(0, n - i - 1):
                if lst[j] > lst[j + 1]:
                    lst[j], lst[j + 1] = lst[j + 1], lst[j]
        result = lst.copy()
        return result
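A usage sketch for the fixture above (assuming the module is importable as written); both entry points sort in place and return a copy:

import asyncio

async def main() -> None:
    print(await async_sorter([3, 1, 2]))             # [1, 2, 3]
    print(await AsyncBubbleSorter().sorter([5, 4]))  # [4, 5]

asyncio.run(main())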


@@ -0,0 +1,16 @@
import time
import asyncio


async def retry_with_backoff(func, max_retries=3):
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")
    last_exception = None
    for attempt in range(max_retries):
        try:
            return await func()
        except Exception as e:
            last_exception = e
            if attempt < max_retries - 1:
                time.sleep(0.0001 * attempt)
    raise last_exception
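Usage sketch: a coroutine that fails twice and succeeds on the third attempt (note the fixture deliberately blocks with time.sleep, leaving an easy win for the optimizer):

import asyncio

async def main() -> None:
    attempts = {"n": 0}

    async def flaky() -> str:
        attempts["n"] += 1
        if attempts["n"] < 3:
            raise RuntimeError("transient failure")
        return "ok"

    print(await retry_with_backoff(flaky, max_retries=3))  # ok

asyncio.run(main())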


@@ -0,0 +1,6 @@
[tool.codeflash]
disable-telemetry = true
formatter-cmds = ["ruff check --exit-zero --fix $file", "ruff format $file"]
module-root = "."
test-framework = "pytest"
tests-root = "tests"


@@ -102,6 +102,8 @@ class AiServiceClient:
         trace_id: str,
         num_candidates: int = 10,
         experiment_metadata: ExperimentMetadata | None = None,
+        *,
+        is_async: bool = False,
     ) -> list[OptimizedCandidate]:
         """Optimize the given python code for performance by making a request to the Django endpoint.
@@ -133,6 +135,7 @@ class AiServiceClient:
             "repo_owner": git_repo_owner,
             "repo_name": git_repo_name,
             "n_candidates": N_CANDIDATES_EFFECTIVE,
+            "is_async": is_async,
         }
         logger.info("!lsp|Generating optimized candidates…")
@@ -299,6 +302,9 @@ class AiServiceClient:
         annotated_tests: str,
         optimization_id: str,
         original_explanation: str,
+        original_throughput: str | None = None,
+        optimized_throughput: str | None = None,
+        throughput_improvement: str | None = None,
     ) -> str:
         """Optimize the given python code for performance by making a request to the Django endpoint.
@@ -315,6 +321,9 @@ class AiServiceClient:
         - annotated_tests: str - test functions annotated with runtime
         - optimization_id: str - unique id of opt candidate
         - original_explanation: str - original_explanation generated for the opt candidate
+        - original_throughput: str | None - throughput for the baseline code (operations per second)
+        - optimized_throughput: str | None - throughput for the optimized code (operations per second)
+        - throughput_improvement: str | None - throughput improvement percentage
         Returns
         -------
@@ -334,6 +343,9 @@ class AiServiceClient:
             "optimization_id": optimization_id,
             "original_explanation": original_explanation,
             "dependency_code": dependency_code,
+            "original_throughput": original_throughput,
+            "optimized_throughput": optimized_throughput,
+            "throughput_improvement": throughput_improvement,
         }
         logger.info("loading|Generating explanation")
         console.rule()
@@ -488,6 +500,7 @@ class AiServiceClient:
             "test_index": test_index,
             "python_version": platform.python_version(),
             "codeflash_version": codeflash_version,
+            "is_async": function_to_optimize.is_async,
         }
         try:
             response = self.make_ai_service_request("/testgen", payload=payload, timeout=600)


@@ -1,3 +1,4 @@
+import importlib.util
 import logging
 import sys
 from argparse import SUPPRESS, ArgumentParser, Namespace
@@ -96,6 +97,12 @@ def parse_args() -> Namespace:
     )
     parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
     parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization")
+    parser.add_argument(
+        "--async",
+        default=False,
+        action="store_true",
+        help="Enable optimization of async functions. By default, async functions are excluded from optimization.",
+    )
     args, unknown_args = parser.parse_known_args()
     sys.argv[:] = [sys.argv[0], *unknown_args]
@@ -139,6 +146,14 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
     if env_utils.is_ci():
         args.no_pr = True

+    if getattr(args, "async", False) and importlib.util.find_spec("pytest_asyncio") is None:
+        logger.warning(
+            "Warning: The --async flag requires pytest-asyncio to be installed.\n"
+            "Please install it using:\n"
+            ' pip install "codeflash[asyncio]"'
+        )
+        raise SystemExit(1)
+
     return args


@@ -272,6 +272,8 @@ class DottedImportCollector(cst.CSTVisitor):
             if child.module is None:
                 continue
             module = self.get_full_dotted_name(child.module)
+            if isinstance(child.names, cst.ImportStar):
+                continue
             for alias in child.names:
                 if isinstance(alias, cst.ImportAlias):
                     name = alias.name.value
@@ -414,6 +416,73 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
     return transformed_module.code


+def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
+    try:
+        module_path = module_name.replace(".", "/")
+        possible_paths = [project_root / f"{module_path}.py", project_root / f"{module_path}/__init__.py"]
+        module_file = None
+        for path in possible_paths:
+            if path.exists():
+                module_file = path
+                break
+        if module_file is None:
+            logger.warning(f"Could not find module file for {module_name}, skipping star import resolution")
+            return set()
+        with module_file.open(encoding="utf8") as f:
+            module_code = f.read()
+        tree = ast.parse(module_code)
+        all_names = None
+        for node in ast.walk(tree):
+            if (
+                isinstance(node, ast.Assign)
+                and len(node.targets) == 1
+                and isinstance(node.targets[0], ast.Name)
+                and node.targets[0].id == "__all__"
+            ):
+                if isinstance(node.value, (ast.List, ast.Tuple)):
+                    all_names = []
+                    for elt in node.value.elts:
+                        if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
+                            all_names.append(elt.value)
+                        elif isinstance(elt, ast.Str):  # Python < 3.8 compatibility
+                            all_names.append(elt.s)
+                break
+        if all_names is not None:
+            return set(all_names)
+        public_names = set()
+        for node in tree.body:
+            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
+                if not node.name.startswith("_"):
+                    public_names.add(node.name)
+            elif isinstance(node, ast.Assign):
+                for target in node.targets:
+                    if isinstance(target, ast.Name) and not target.id.startswith("_"):
+                        public_names.add(target.id)
+            elif isinstance(node, ast.AnnAssign):
+                if isinstance(node.target, ast.Name) and not node.target.id.startswith("_"):
+                    public_names.add(node.target.id)
+            elif isinstance(node, ast.Import) or (
+                isinstance(node, ast.ImportFrom) and not any(alias.name == "*" for alias in node.names)
+            ):
+                for alias in node.names:
+                    name = alias.asname or alias.name
+                    if not name.startswith("_"):
+                        public_names.add(name)
+        return public_names  # noqa: TRY300
+    except Exception as e:
+        logger.warning(f"Error resolving star import for {module_name}: {e}")
+        return set()
+
+
 def add_needed_imports_from_module(
     src_module_code: str,
     dst_module_code: str,
@@ -468,9 +537,23 @@ def add_needed_imports_from_module(
                 f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod  # avoid circular deps
             ):
                 continue  # Skip adding imports for helper functions already in the context
-            if f"{mod}.{obj}" not in dotted_import_collector.imports:
-                AddImportsVisitor.add_needed_import(dst_context, mod, obj)
-                RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
+            # Handle star imports by resolving them to actual symbol names
+            if obj == "*":
+                resolved_symbols = resolve_star_import(mod, project_root)
+                logger.debug(f"Resolved star import from {mod}: {resolved_symbols}")
+                for symbol in resolved_symbols:
+                    if (
+                        f"{mod}.{symbol}" not in helper_functions_fqn
+                        and f"{mod}.{symbol}" not in dotted_import_collector.imports
+                    ):
+                        AddImportsVisitor.add_needed_import(dst_context, mod, symbol)
+                        RemoveImportsVisitor.remove_unused_import(dst_context, mod, symbol)
+            else:
+                if f"{mod}.{obj}" not in dotted_import_collector.imports:
+                    AddImportsVisitor.add_needed_import(dst_context, mod, obj)
+                    RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
     except Exception as e:
         logger.exception(f"Error adding imports to destination module code: {e}")
     return dst_module_code
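To illustrate the fallback path of resolve_star_import (used when a module defines no __all__), a self-contained sketch that collects public top-level definitions the same way:

import ast

code = "def a(): ...\ndef _hidden(): ...\nclass B: ...\n"
public = {
    n.name
    for n in ast.parse(code).body
    if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) and not n.name.startswith("_")
}
print(public)  # {'a', 'B'} (set order may vary)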


@@ -269,14 +269,6 @@ def validate_python_code(code: str) -> str:
     return code


-def has_any_async_functions(code: str) -> bool:
-    try:
-        module = ast.parse(code)
-    except SyntaxError:
-        return False
-    return any(isinstance(node, ast.AsyncFunctionDef) for node in ast.walk(module))


 def cleanup_paths(paths: list[Path]) -> None:
     for path in paths:
         if path and path.exists():


@@ -0,0 +1,167 @@
from __future__ import annotations

import asyncio
import gc
import os
import sqlite3
from enum import Enum
from functools import wraps
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Callable, TypeVar

import dill as pickle


class VerificationType(str, Enum):  # moved from codeflash/verification/codeflash_capture.py
    FUNCTION_CALL = (
        "function_call"  # Correctness verification for a test function, checks input values and output values)
    )
    INIT_STATE_FTO = "init_state_fto"  # Correctness verification for fto class instance attributes after init
    INIT_STATE_HELPER = "init_state_helper"  # Correctness verification for helper class instance attributes after init


F = TypeVar("F", bound=Callable[..., Any])


def get_run_tmp_file(file_path: Path) -> Path:  # moved from codeflash/code_utils/code_utils.py
    if not hasattr(get_run_tmp_file, "tmpdir"):
        get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
    return Path(get_run_tmp_file.tmpdir.name) / file_path


def extract_test_context_from_env() -> tuple[str, str | None, str]:
    test_module = os.environ["CODEFLASH_TEST_MODULE"]
    test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
    test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
    if test_module and test_function:
        return (test_module, test_class if test_class else None, test_function)
    raise RuntimeError(
        "Test context environment variables not set - ensure tests are run through codeflash test runner"
    )


def codeflash_behavior_async(func: F) -> F:
    @wraps(func)
    async def async_wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
        loop = asyncio.get_running_loop()
        function_name = func.__name__
        line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
        loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
        test_module_name, test_class_name, test_name = extract_test_context_from_env()
        test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
        if not hasattr(async_wrapper, "index"):
            async_wrapper.index = {}
        if test_id in async_wrapper.index:
            async_wrapper.index[test_id] += 1
        else:
            async_wrapper.index[test_id] = 0
        codeflash_test_index = async_wrapper.index[test_id]
        invocation_id = f"{line_id}_{codeflash_test_index}"
        test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
        print(f"!$######{test_stdout_tag}######$!")
        iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
        db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
        codeflash_con = sqlite3.connect(db_path)
        codeflash_cur = codeflash_con.cursor()
        codeflash_cur.execute(
            "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, "
            "test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, "
            "runtime INTEGER, return_value BLOB, verification_type TEXT)"
        )
        exception = None
        counter = loop.time()
        gc.disable()
        try:
            ret = func(*args, **kwargs)  # coroutine creation has some overhead, though it is very small
            counter = loop.time()
            return_value = await ret  # let's measure the actual execution time of the code
            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
        except Exception as e:
            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
            exception = e
        finally:
            gc.enable()
        print(f"!######{test_stdout_tag}######!")
        pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps((args, kwargs, return_value))
        codeflash_cur.execute(
            "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                test_module_name,
                test_class_name,
                test_name,
                function_name,
                loop_index,
                invocation_id,
                codeflash_duration,
                pickled_return_value,
                VerificationType.FUNCTION_CALL.value,
            ),
        )
        codeflash_con.commit()
        codeflash_con.close()
        if exception:
            raise exception
        return return_value

    return async_wrapper


def codeflash_performance_async(func: F) -> F:
    @wraps(func)
    async def async_wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
        loop = asyncio.get_running_loop()
        function_name = func.__name__
        line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
        loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
        test_module_name, test_class_name, test_name = extract_test_context_from_env()
        test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
        if not hasattr(async_wrapper, "index"):
            async_wrapper.index = {}
        if test_id in async_wrapper.index:
            async_wrapper.index[test_id] += 1
        else:
            async_wrapper.index[test_id] = 0
        codeflash_test_index = async_wrapper.index[test_id]
        invocation_id = f"{line_id}_{codeflash_test_index}"
        test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
        print(f"!$######{test_stdout_tag}######$!")
        exception = None
        counter = loop.time()
        gc.disable()
        try:
            ret = func(*args, **kwargs)
            counter = loop.time()
            return_value = await ret
            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
        except Exception as e:
            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
            exception = e
        finally:
            gc.enable()
        print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
        if exception:
            raise exception
        return return_value

    return async_wrapper
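A usage sketch for the performance wrapper above; the identifying environment variables are normally set by the codeflash test runner, so a manual run must provide them:

import asyncio
import os

os.environ.update({
    "CODEFLASH_TEST_MODULE": "tests.test_demo",   # hypothetical test context
    "CODEFLASH_TEST_FUNCTION": "test_demo",
    "CODEFLASH_CURRENT_LINE_ID": "0",
    "CODEFLASH_LOOP_INDEX": "1",
})

@codeflash_performance_async
async def work() -> int:
    await asyncio.sleep(0)
    return 42

print(asyncio.run(work()))  # timing tags on stdout, then 42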


@@ -3,6 +3,7 @@ INDIVIDUAL_TESTCASE_TIMEOUT = 15
 MAX_FUNCTION_TEST_SECONDS = 60
 N_CANDIDATES = 5
 MIN_IMPROVEMENT_THRESHOLD = 0.05
+MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10  # 10% minimum improvement for async throughput
 MAX_TEST_FUNCTION_RUNS = 50
 MAX_CUMULATIVE_TEST_RUNTIME_NANOSECONDS = 100e6  # 100ms
 N_TESTS_TO_GENERATE = 2
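Arithmetic sketch of how a relative threshold like this is applied (the actual gate lives with the critics; this is just the arithmetic): 1000 -> 1080 ops/s is an 8% gain and fails a 0.10 threshold, while 1000 -> 1150 passes:

MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10

def passes(original_ops: int, optimized_ops: int) -> bool:
    return (optimized_ops - original_ops) / original_ops > MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD

print(passes(1000, 1080), passes(1000, 1150))  # False True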


@@ -14,7 +14,9 @@ def extract_dependent_function(main_function: str, code_context: CodeOptimizatio
     """Extract the single dependent function from the code context excluding the main function."""
     ast_tree = ast.parse(code_context.testgen_context_code)
-    dependent_functions = {node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)}
+    dependent_functions = {
+        node.name for node in ast_tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
+    }
     if main_function in dependent_functions:
         dependent_functions.discard(main_function)


@@ -32,9 +32,11 @@ class CommentMapper(ast.NodeVisitor):
     def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
         self.context_stack.append(node.name)
-        for inner_node in ast.walk(node):
+        for inner_node in node.body:
             if isinstance(inner_node, ast.FunctionDef):
                 self.visit_FunctionDef(inner_node)
+            elif isinstance(inner_node, ast.AsyncFunctionDef):
+                self.visit_AsyncFunctionDef(inner_node)
         self.context_stack.pop()
         return node
@@ -50,6 +52,14 @@ class CommentMapper(ast.NodeVisitor):
         return f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"

     def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
+        self._process_function_def_common(node)
+        return node
+
+    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
+        self._process_function_def_common(node)
+        return node
+
+    def _process_function_def_common(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
         self.context_stack.append(node.name)
         i = len(node.body) - 1
         test_qualified_name = ".".join(self.context_stack)
@@ -60,8 +70,9 @@ class CommentMapper(ast.NodeVisitor):
             j = len(line_node.body) - 1
             while j >= 0:
                 compound_line_node: ast.stmt = line_node.body[j]
-                internal_node: ast.AST
-                for internal_node in ast.walk(compound_line_node):
+                nodes_to_check = [compound_line_node]
+                nodes_to_check.extend(getattr(compound_line_node, "body", []))
+                for internal_node in nodes_to_check:
                     if isinstance(internal_node, (ast.stmt, ast.Assign)):
                         inv_id = str(i) + "_" + str(j)
                         match_key = key + "#" + inv_id
@@ -75,7 +86,6 @@ class CommentMapper(ast.NodeVisitor):
                 self.results[line_node.lineno] = self.get_comment(match_key)
             i -= 1
         self.context_stack.pop()
-        return node


 def get_fn_call_linenos(
@@ -197,23 +207,41 @@ def add_runtime_comments_to_generated_tests(
 def remove_functions_from_generated_tests(
     generated_tests: GeneratedTestsList, test_functions_to_remove: list[str]
 ) -> GeneratedTestsList:
+    # Pre-compile patterns for all function names to remove
+    function_patterns = _compile_function_patterns(test_functions_to_remove)
     new_generated_tests = []
     for generated_test in generated_tests.generated_tests:
-        for test_function in test_functions_to_remove:
-            function_pattern = re.compile(
-                rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\ndef\s|$)",
-                re.DOTALL,
-            )
-            match = function_pattern.search(generated_test.generated_original_test_source)
-            if match is None or "@pytest.mark.parametrize" in match.group(0):
-                continue
-            generated_test.generated_original_test_source = function_pattern.sub(
-                "", generated_test.generated_original_test_source
-            )
+        source = generated_test.generated_original_test_source
+        # Apply all patterns, avoiding redundant .search()/.sub() passes over the source
+        for pattern in function_patterns:
+            for match in pattern.finditer(source):
+                # Skip parametrized tests; only the matched function's code is targeted
+                if "@pytest.mark.parametrize" in match.group(0):
+                    continue
+                # Remove the matched function via its start/end indices for efficiency
+                start, end = match.span()
+                source = source[:start] + source[end:]
+                # Only one match is expected per function name, so stop after the first removal
+                break
+        generated_test.generated_original_test_source = source
         new_generated_tests.append(generated_test)
     return GeneratedTestsList(generated_tests=new_generated_tests)
+
+
+# Pre-compile all function removal regexes upfront for efficiency.
+def _compile_function_patterns(test_functions_to_remove: list[str]) -> list[re.Pattern[str]]:
+    return [
+        re.compile(
+            rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?(async\s+)?def\s+{re.escape(func)}\(.*?\):.*?(?=\n(async\s+)?def\s|$)",
+            re.DOTALL,
+        )
+        for func in test_functions_to_remove
+    ]
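A quick check that the widened pattern now also matches async test functions (hypothetical function name test_foo):

import re

pattern = re.compile(
    r"(@pytest\.mark\.parametrize\(.*?\)\s*)?(async\s+)?def\s+test_foo\(.*?\):.*?(?=\n(async\s+)?def\s|$)",
    re.DOTALL,
)
src = "async def test_foo():\n    assert True\n\ndef test_bar():\n    assert True\n"
match = pattern.search(src)
print(match.group(0).strip())  # the async test_foo body, up to the next def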


@@ -6,6 +6,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING

 import isort
+import libcst as cst

 from codeflash.cli_cmds.console import logger
 from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
@@ -77,6 +78,10 @@ class InjectPerfOnly(ast.NodeTransformer):
             call_node = node
             if isinstance(node.func, ast.Name):
                 function_name = node.func.id
+                if self.function_object.is_async:
+                    return [test_node]
                 node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
                 node.args = [
                     ast.Name(id=function_name, ctx=ast.Load()),
@@ -98,6 +103,9 @@ class InjectPerfOnly(ast.NodeTransformer):
             if isinstance(node.func, ast.Attribute):
                 function_to_test = node.func.attr
                 if function_to_test == self.function_object.function_name:
+                    if self.function_object.is_async:
+                        return [test_node]
                     function_name = ast.unparse(node.func)
                     node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
                     node.args = [
@@ -289,6 +297,168 @@ class InjectPerfOnly(ast.NodeTransformer):
         return node
+
+
+class AsyncCallInstrumenter(ast.NodeTransformer):
+    def __init__(
+        self,
+        function: FunctionToOptimize,
+        module_path: str,
+        test_framework: str,
+        call_positions: list[CodePosition],
+        mode: TestingMode = TestingMode.BEHAVIOR,
+    ) -> None:
+        self.mode = mode
+        self.function_object = function
+        self.class_name = None
+        self.only_function_name = function.function_name
+        self.module_path = module_path
+        self.test_framework = test_framework
+        self.call_positions = call_positions
+        self.did_instrument = False
+        # Track function call count per test function
+        self.async_call_counter: dict[str, int] = {}
+        if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
+            self.class_name = function.top_level_parent_name
+
+    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
+        # Add timeout decorator for unittest test classes if needed
+        if self.test_framework == "unittest":
+            timeout_decorator = ast.Call(
+                func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
+                args=[ast.Constant(value=15)],
+                keywords=[],
+            )
+            for item in node.body:
+                if (
+                    isinstance(item, ast.FunctionDef)
+                    and item.name.startswith("test_")
+                    and not any(
+                        isinstance(d, ast.Call)
+                        and isinstance(d.func, ast.Name)
+                        and d.func.id == "timeout_decorator.timeout"
+                        for d in item.decorator_list
+                    )
+                ):
+                    item.decorator_list.append(timeout_decorator)
+        return self.generic_visit(node)
+
+    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
+        if not node.name.startswith("test_"):
+            return node
+        return self._process_test_function(node)
+
+    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
+        # Only process test functions
+        if not node.name.startswith("test_"):
+            return node
+        return self._process_test_function(node)
+
+    def _process_test_function(
+        self, node: ast.AsyncFunctionDef | ast.FunctionDef
+    ) -> ast.AsyncFunctionDef | ast.FunctionDef:
+        # Optimize the search for decorator presence
+        if self.test_framework == "unittest":
+            found_timeout = False
+            for d in node.decorator_list:
+                # Avoid isinstance(d.func, ast.Name) if d is not ast.Call
+                if isinstance(d, ast.Call):
+                    f = d.func
+                    # Avoid attribute lookup if f is not ast.Name
+                    if isinstance(f, ast.Name) and f.id == "timeout_decorator.timeout":
+                        found_timeout = True
+                        break
+            if not found_timeout:
+                timeout_decorator = ast.Call(
+                    func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
+                    args=[ast.Constant(value=15)],
+                    keywords=[],
+                )
+                node.decorator_list.append(timeout_decorator)
+        # Initialize counter for this test function
+        if node.name not in self.async_call_counter:
+            self.async_call_counter[node.name] = 0
+        new_body = []
+        # Optimize ast.walk calls inside _instrument_statement, by scanning only relevant nodes
+        for _i, stmt in enumerate(node.body):
+            transformed_stmt, added_env_assignment = self._optimized_instrument_statement(stmt)
+            if added_env_assignment:
+                current_call_index = self.async_call_counter[node.name]
+                self.async_call_counter[node.name] += 1
+                env_assignment = ast.Assign(
+                    targets=[
+                        ast.Subscript(
+                            value=ast.Attribute(
+                                value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
+                            ),
+                            slice=ast.Constant(value="CODEFLASH_CURRENT_LINE_ID"),
+                            ctx=ast.Store(),
+                        )
+                    ],
+                    value=ast.Constant(value=f"{current_call_index}"),
+                    lineno=stmt.lineno if hasattr(stmt, "lineno") else 1,
+                )
+                new_body.append(env_assignment)
+                self.did_instrument = True
+            new_body.append(transformed_stmt)
+        node.body = new_body
+        return node
+
+    def _instrument_statement(self, stmt: ast.stmt, _node_name: str) -> tuple[ast.stmt, bool]:
+        for node in ast.walk(stmt):
+            if (
+                isinstance(node, ast.Await)
+                and isinstance(node.value, ast.Call)
+                and self._is_target_call(node.value)
+                and self._call_in_positions(node.value)
+            ):
+                # Check if this call is in one of our target positions
+                return stmt, True  # Return original statement but signal we added env var
+        return stmt, False
+
+    def _is_target_call(self, call_node: ast.Call) -> bool:
+        """Check if this call node is calling our target async function."""
+        if isinstance(call_node.func, ast.Name):
+            return call_node.func.id == self.function_object.function_name
+        if isinstance(call_node.func, ast.Attribute):
+            return call_node.func.attr == self.function_object.function_name
+        return False
+
+    def _call_in_positions(self, call_node: ast.Call) -> bool:
+        if not hasattr(call_node, "lineno") or not hasattr(call_node, "col_offset"):
+            return False
+        return node_in_call_position(call_node, self.call_positions)
+
+    # Optimized version: only walk child nodes for Await
+    def _optimized_instrument_statement(self, stmt: ast.stmt) -> tuple[ast.stmt, bool]:
+        # Stack-based DFS, manual for relevant Await nodes
+        stack = [stmt]
+        while stack:
+            node = stack.pop()
+            # Favor direct ast.Await detection
+            if isinstance(node, ast.Await):
+                val = node.value
+                if isinstance(val, ast.Call) and self._is_target_call(val) and self._call_in_positions(val):
+                    return stmt, True
+            # Use _fields instead of ast.walk for less allocations
+            for fname in getattr(node, "_fields", ()):
+                child = getattr(node, fname, None)
+                if isinstance(child, list):
+                    stack.extend(child)
+                elif isinstance(child, ast.AST):
+                    stack.append(child)
+        return stmt, False
+
+
 class FunctionImportedAsVisitor(ast.NodeVisitor):
     """Checks if a function has been imported as an alias. We only care about the alias then.
@@ -316,6 +486,7 @@ class FunctionImportedAsVisitor(ast.NodeVisitor):
                 file_path=self.function.file_path,
                 starting_line=self.function.starting_line,
                 ending_line=self.function.ending_line,
+                is_async=self.function.is_async,
             )
         else:
             self.imported_as = FunctionToOptimize(
@@ -324,9 +495,69 @@ class FunctionImportedAsVisitor(ast.NodeVisitor):
                 file_path=self.function.file_path,
                 starting_line=self.function.starting_line,
                 ending_line=self.function.ending_line,
+                is_async=self.function.is_async,
             )


+def instrument_source_module_with_async_decorators(
+    source_path: Path, function_to_optimize: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
+) -> tuple[bool, str | None]:
+    if not function_to_optimize.is_async:
+        return False, None
+    try:
+        with source_path.open(encoding="utf8") as f:
+            source_code = f.read()
+        modified_code, decorator_added = add_async_decorator_to_function(source_code, function_to_optimize, mode)
+        if decorator_added:
+            return True, modified_code
+    except Exception:
+        return False, None
+    else:
+        return False, None
+
+
+def inject_async_profiling_into_existing_test(
+    test_path: Path,
+    call_positions: list[CodePosition],
+    function_to_optimize: FunctionToOptimize,
+    tests_project_root: Path,
+    test_framework: str,
+    mode: TestingMode = TestingMode.BEHAVIOR,
+) -> tuple[bool, str | None]:
+    """Inject profiling for async function calls by setting environment variables before each call."""
+    with test_path.open(encoding="utf8") as f:
+        test_code = f.read()
+    try:
+        tree = ast.parse(test_code)
+    except SyntaxError:
+        logger.exception(f"Syntax error in code in file - {test_path}")
+        return False, None
+    # TODO: Pass the full name of function here, otherwise we can run into namespace clashes
+    test_module_path = module_name_from_file_path(test_path, tests_project_root)
+    import_visitor = FunctionImportedAsVisitor(function_to_optimize)
+    import_visitor.visit(tree)
+    func = import_visitor.imported_as
+    async_instrumenter = AsyncCallInstrumenter(func, test_module_path, test_framework, call_positions, mode=mode)
+    tree = async_instrumenter.visit(tree)
+    if not async_instrumenter.did_instrument:
+        return False, None
+    # Add necessary imports
+    new_imports = [ast.Import(names=[ast.alias(name="os")])]
+    if test_framework == "unittest":
+        new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
+    tree.body = [*new_imports, *tree.body]
+    return True, isort.code(ast.unparse(tree), float_to_top=True)
+
+
 def inject_profiling_into_existing_test(
     test_path: Path,
     call_positions: list[CodePosition],
@@ -335,6 +566,11 @@ def inject_profiling_into_existing_test(
     test_framework: str,
     mode: TestingMode = TestingMode.BEHAVIOR,
 ) -> tuple[bool, str | None]:
+    if function_to_optimize.is_async:
+        return inject_async_profiling_into_existing_test(
+            test_path, call_positions, function_to_optimize, tests_project_root, test_framework, mode
+        )
     with test_path.open(encoding="utf8") as f:
         test_code = f.read()
     try:
@@ -342,7 +578,7 @@ def inject_profiling_into_existing_test(
     except SyntaxError:
         logger.exception(f"Syntax error in code in file - {test_path}")
         return False, None
     # TODO: Pass the full name of function here, otherwise we can run into namespace clashes
     test_module_path = module_name_from_file_path(test_path, tests_project_root)
     import_visitor = FunctionImportedAsVisitor(function_to_optimize)
     import_visitor.visit(tree)
@@ -360,7 +596,9 @@ def inject_profiling_into_existing_test(
     )
     if test_framework == "unittest" and platform.system() != "Windows":
         new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
-    tree.body = [*new_imports, create_wrapper_function(mode), *tree.body]
+    additional_functions = [create_wrapper_function(mode)]
+    tree.body = [*new_imports, *additional_functions, *tree.body]
     return True, isort.code(ast.unparse(tree), float_to_top=True)
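Before/after sketch of what the async instrumentation path effectively does to a test body: an os.environ assignment carrying the call's line id is inserted ahead of each statement that awaits the target function (snippets are hypothetical; ast.parse just confirms both are valid):

import ast

before = '''
async def test_sorter():
    result = await async_sorter([3, 1, 2])
'''
after = '''
import os

async def test_sorter():
    os.environ["CODEFLASH_CURRENT_LINE_ID"] = "0"
    result = await async_sorter([3, 1, 2])
'''
ast.parse(before)
ast.parse(after)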
@@ -741,3 +979,162 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
         decorator_list=[],
         returns=None,
     )
+
+
+class AsyncDecoratorAdder(cst.CSTTransformer):
+    """Transformer that adds async decorator to async function definitions."""
+
+    def __init__(self, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR) -> None:
+        """Initialize the transformer.
+
+        Args:
+        ----
+            function: The FunctionToOptimize object representing the target async function.
+            mode: The testing mode to determine which decorator to apply.
+
+        """
+        super().__init__()
+        self.function = function
+        self.mode = mode
+        self.qualified_name_parts = function.qualified_name.split(".")
+        self.context_stack = []
+        self.added_decorator = False
+        # Choose decorator based on mode
+        self.decorator_name = (
+            "codeflash_behavior_async" if mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
+        )
+
+    def visit_ClassDef(self, node: cst.ClassDef) -> None:
+        # Track when we enter a class
+        self.context_stack.append(node.name.value)
+
+    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:  # noqa: ARG002
+        # Pop the context when we leave a class
+        self.context_stack.pop()
+        return updated_node
+
+    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
+        # Track when we enter a function
+        self.context_stack.append(node.name.value)
+
+    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
+        # Check if this is an async function and matches our target
+        if original_node.asynchronous is not None and self.context_stack == self.qualified_name_parts:
+            # Check if the decorator is already present
+            has_decorator = any(
+                self._is_target_decorator(decorator.decorator) for decorator in original_node.decorators
+            )
+            # Only add the decorator if it's not already there
+            if not has_decorator:
+                new_decorator = cst.Decorator(decorator=cst.Name(value=self.decorator_name))
+                # Add our new decorator to the existing decorators
+                updated_decorators = [new_decorator, *list(updated_node.decorators)]
+                updated_node = updated_node.with_changes(decorators=tuple(updated_decorators))
+                self.added_decorator = True
+        # Pop the context when we leave a function
+        self.context_stack.pop()
+        return updated_node
+
+    def _is_target_decorator(self, decorator_node: cst.Name | cst.Attribute | cst.Call) -> bool:
+        """Check if a decorator matches our target decorator name."""
+        if isinstance(decorator_node, cst.Name):
+            return decorator_node.value in {
+                "codeflash_trace_async",
+                "codeflash_behavior_async",
+                "codeflash_performance_async",
+            }
+        if isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name):
+            return decorator_node.func.value in {
+                "codeflash_trace_async",
+                "codeflash_behavior_async",
+                "codeflash_performance_async",
+            }
+        return False
+
+
+class AsyncDecoratorImportAdder(cst.CSTTransformer):
+    """Transformer that adds the import for async decorators."""
+
+    def __init__(self, mode: TestingMode = TestingMode.BEHAVIOR) -> None:
+        self.mode = mode
+        self.has_import = False
+
+    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
+        # Check if the async decorator import is already present
+        if (
+            isinstance(node.module, cst.Attribute)
+            and isinstance(node.module.value, cst.Attribute)
+            and isinstance(node.module.value.value, cst.Name)
+            and node.module.value.value.value == "codeflash"
+            and node.module.value.attr.value == "code_utils"
+            and node.module.attr.value == "codeflash_wrap_decorator"
+            and not isinstance(node.names, cst.ImportStar)
+        ):
+            decorator_name = (
+                "codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
+            )
+            for import_alias in node.names:
+                if import_alias.name.value == decorator_name:
+                    self.has_import = True
+
+    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:  # noqa: ARG002
+        # If the import is already there, don't add it again
+        if self.has_import:
+            return updated_node
+        # Choose import based on mode
+        decorator_name = (
+            "codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
+        )
+        # Parse the import statement into a CST node
+        import_node = cst.parse_statement(f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}")
+        # Add the import to the module's body
+        return updated_node.with_changes(body=[import_node, *list(updated_node.body)])
+
+
+def add_async_decorator_to_function(
+    source_code: str, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
+) -> tuple[str, bool]:
+    """Add async decorator to an async function definition.
+
+    Args:
+    ----
+        source_code: The source code to modify.
+        function: The FunctionToOptimize object representing the target async function.
+        mode: The testing mode to determine which decorator to apply.
+
+    Returns:
+    -------
+        Tuple of (modified_source_code, was_decorator_added).
+
+    """
+    if not function.is_async:
+        return source_code, False
+    try:
+        module = cst.parse_module(source_code)
+        # Add the decorator to the function
+        decorator_transformer = AsyncDecoratorAdder(function, mode)
+        module = module.visit(decorator_transformer)
+        # Add the import if decorator was added
+        if decorator_transformer.added_decorator:
+            import_transformer = AsyncDecoratorImportAdder(mode)
+            module = module.visit(import_transformer)
+        return isort.code(module.code, float_to_top=True), decorator_transformer.added_decorator
+    except Exception as e:
+        logger.exception(f"Error adding async decorator to function {function.qualified_name}: {e}")
+        return source_code, False
+
+
+def create_instrumented_source_module_path(source_path: Path, temp_dir: Path) -> Path:
+    instrumented_filename = f"instrumented_{source_path.name}"
+    return temp_dir / instrumented_filename
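A minimal, self-contained demonstration of the libcst pattern behind AsyncDecoratorAdder; this simplified sketch decorates every async def, whereas the real transformer also matches the function's qualified name and injects the corresponding import:

import libcst as cst

class AddDecorator(cst.CSTTransformer):
    def leave_FunctionDef(self, original_node, updated_node):
        # Only async functions get the decorator prepended.
        if original_node.asynchronous is not None:
            dec = cst.Decorator(decorator=cst.Name("codeflash_behavior_async"))
            return updated_node.with_changes(decorators=(dec, *updated_node.decorators))
        return updated_node

module = cst.parse_module("async def fetch():\n    return 1\n")
print(module.visit(AddDecorator()).code)
# @codeflash_behavior_async
# async def fetch():
#     return 1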


@@ -128,13 +128,19 @@ def get_first_top_level_object_def_ast(
 def get_first_top_level_function_or_method_ast(
     function_name: str, parents: list[FunctionParent], node: ast.AST
-) -> ast.FunctionDef | None:
+) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
     if not parents:
-        return get_first_top_level_object_def_ast(function_name, ast.FunctionDef, node)
+        result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, node)
+        if result is not None:
+            return result
+        return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, node)
     if parents[0].type == "ClassDef" and (
         class_node := get_first_top_level_object_def_ast(parents[0].name, ast.ClassDef, node)
     ):
-        return get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node)
+        result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node)
+        if result is not None:
+            return result
+        return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, class_node)
     return None


@@ -86,6 +86,7 @@ class FunctionVisitor(cst.CSTVisitor):
                     parents=list(reversed(ast_parents)),
                     starting_line=pos.start.line,
                     ending_line=pos.end.line,
+                    is_async=bool(node.asynchronous),
                 )
             )
@@ -103,6 +104,15 @@ class FunctionWithReturnStatement(ast.NodeVisitor):
                 FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
             )

+    def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
+        # Check if the async function has a return statement and add it to the list
+        if function_has_return_statement(node) and not function_is_a_property(node):
+            self.functions.append(
+                FunctionToOptimize(
+                    function_name=node.name, file_path=self.file_path, parents=self.ast_path[:], is_async=True
+                )
+            )
+
     def generic_visit(self, node: ast.AST) -> None:
         if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):
             self.ast_path.append(FunctionParent(node.name, node.__class__.__name__))
@@ -122,6 +132,7 @@ class FunctionToOptimize:
         parents: A list of parent scopes, which could be classes or functions.
         starting_line: The starting line number of the function in the file.
        ending_line: The ending line number of the function in the file.
+        is_async: Whether this function is defined as async.

     The qualified_name property provides the full name of the function, including
     any parent class or function names. The qualified_name_with_modules_from_root
@@ -134,6 +145,7 @@ class FunctionToOptimize:
     parents: list[FunctionParent]  # list[ClassDef | FunctionDef | AsyncFunctionDef]
     starting_line: Optional[int] = None
     ending_line: Optional[int] = None
+    is_async: bool = False

     @property
     def top_level_parent_name(self) -> str:
@@ -147,7 +159,11 @@ class FunctionToOptimize:
     @property
     def qualified_name(self) -> str:
-        return self.function_name if self.parents == [] else f"{self.parents[0].name}.{self.function_name}"
+        if not self.parents:
+            return self.function_name
+        # Join all parent names with dots to handle nested classes properly
+        parent_path = ".".join(parent.name for parent in self.parents)
+        return f"{parent_path}.{self.function_name}"

     def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
         return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
@@ -163,6 +179,8 @@ def get_functions_to_optimize(
     project_root: Path,
     module_root: Path,
     previous_checkpoint_functions: dict[str, dict[str, str]] | None = None,
+    *,
+    enable_async: bool = False,
 ) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
     assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, (
         "Only one of optimize_all, replay_test, or file should be provided"
@@ -216,7 +234,13 @@ def get_functions_to_optimize(
             ph("cli-optimizing-git-diff")
             functions = get_functions_within_git_diff(uncommitted_changes=False)
         filtered_modified_functions, functions_count = filter_functions(
-            functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions
+            functions,
+            test_cfg.tests_root,
+            ignore_paths,
+            project_root,
+            module_root,
+            previous_checkpoint_functions,
+            enable_async=enable_async,
         )
     logger.info(f"!lsp|Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
@@ -411,11 +435,27 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
                 )
             )

+    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
+        if self.class_name is None and node.name == self.function_name:
+            self.is_top_level = True
+            self.function_has_args = any(
+                (
+                    bool(node.args.args),
+                    bool(node.args.kwonlyargs),
+                    bool(node.args.kwarg),
+                    bool(node.args.posonlyargs),
+                    bool(node.args.vararg),
+                )
+            )
+
     def visit_ClassDef(self, node: ast.ClassDef) -> None:
         # iterate over the class methods
         if node.name == self.class_name:
             for body_node in node.body:
-                if isinstance(body_node, ast.FunctionDef) and body_node.name == self.function_name:
+                if (
+                    isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
+                    and body_node.name == self.function_name
+                ):
                     self.is_top_level = True
                     if any(
                         isinstance(decorator, ast.Name) and decorator.id == "classmethod"
@@ -433,7 +473,7 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
         # This way, if we don't have the class name, we can still find the static method
         for body_node in node.body:
             if (
-                isinstance(body_node, ast.FunctionDef)
+                isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
                 and body_node.name == self.function_name
                 and body_node.lineno in {self.line_no, self.line_no + 1}
                 and any(
@@ -535,7 +575,9 @@ def filter_functions(
     project_root: Path,
     module_root: Path,
     previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None,
-    disable_logs: bool = False,  # noqa: FBT001, FBT002
+    *,
+    disable_logs: bool = False,
+    enable_async: bool = False,
 ) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
     filtered_modified_functions: dict[str, list[FunctionToOptimize]] = {}
     blocklist_funcs = get_blocklisted_functions()
@@ -555,6 +597,7 @@ def filter_functions(
     submodule_ignored_paths_count: int = 0
     blocklist_funcs_removed_count: int = 0
     previous_checkpoint_functions_removed_count: int = 0
+    async_functions_removed_count: int = 0

     tests_root_str = str(tests_root)
     module_root_str = str(module_root)
@@ -610,6 +653,15 @@ def filter_functions(
                 functions_tmp.append(function)
             _functions = functions_tmp

+        if not enable_async:
+            functions_tmp = []
+            for function in _functions:
+                if function.is_async:
+                    async_functions_removed_count += 1
+                    continue
+                functions_tmp.append(function)
+            _functions = functions_tmp
+
         filtered_modified_functions[file_path] = _functions
         functions_count += len(_functions)
@@ -623,6 +675,7 @@ def filter_functions(
         "Files from ignored submodules": (submodule_ignored_paths_count, "bright_black"),
         "Blocklisted functions removed": (blocklist_funcs_removed_count, "bright_red"),
         "Functions skipped from checkpoint": (previous_checkpoint_functions_removed_count, "green"),
+        "Async functions removed": (async_functions_removed_count, "bright_magenta"),
     }
     tree = Tree(Text("Ignored functions and files", style="bold"))
     for label, (count, color) in log_info.items():
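The new gating pass reduces to a simple filter; a sketch with a stand-in record type:

from dataclasses import dataclass

@dataclass
class Fn:
    name: str
    is_async: bool

funcs = [Fn("a", is_async=False), Fn("b", is_async=True)]
enable_async = False
kept = [f for f in funcs if enable_async or not f.is_async]
print([f.name for f in kept])  # ['a']; the async "b" is counted as removed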


@@ -103,6 +103,7 @@ class BestOptimization(BaseModel):
     winning_benchmarking_test_results: TestResults
     winning_replay_benchmarking_test_results: Optional[TestResults] = None
     line_profiler_test_results: dict
+    async_throughput: Optional[int] = None


 @dataclass(frozen=True)
@@ -277,6 +278,7 @@ class OptimizedCandidateResult(BaseModel):
     replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None
     optimization_candidate_index: int
     total_candidate_timing: int
+    async_throughput: Optional[int] = None


 class GeneratedTests(BaseModel):
@@ -383,6 +385,7 @@ class OriginalCodeBaseline(BaseModel):
     line_profile_results: dict
     runtime: int
     coverage_results: Optional[CoverageData]
+    async_throughput: Optional[int] = None


 class CoverageStatus(Enum):
@@ -545,6 +548,7 @@ class TestResults(BaseModel):  # noqa: PLW1641
     # also we don't support deletion of test results elements - caution is advised
     test_results: list[FunctionTestInvocation] = []
     test_result_idx: dict[str, int] = {}
+    perf_stdout: Optional[str] = None

     def add(self, function_test_invocation: FunctionTestInvocation) -> None:
         unique_id = function_test_invocation.unique_invocation_loop_id

@@ -36,7 +36,6 @@ from codeflash.code_utils.code_utils import (
     diff_length,
     file_name_from_test_module_name,
     get_run_tmp_file,
-    has_any_async_functions,
     module_name_from_file_path,
     restore_conftest,
     unified_diff_strings,
@@ -85,14 +84,20 @@ from codeflash.models.models import (
     TestType,
 )
 from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
-from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
+from codeflash.result.critic import (
+    coverage_critic,
+    performance_gain,
+    quantity_of_tests_critic,
+    speedup_critic,
+    throughput_gain,
+)
 from codeflash.result.explanation import Explanation
 from codeflash.telemetry.posthog_cf import ph
 from codeflash.verification.concolic_testing import generate_concolic_tests
 from codeflash.verification.equivalence import compare_test_results
 from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
 from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
-from codeflash.verification.parse_test_output import parse_test_results
+from codeflash.verification.parse_test_output import calculate_function_throughput_from_test_results, parse_test_results
 from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_line_profile_tests
 from codeflash.verification.verification_utils import get_test_file_path
 from codeflash.verification.verifier import generate_tests
@ -199,7 +204,7 @@ class FunctionOptimizer:
test_cfg: TestConfig, test_cfg: TestConfig,
function_to_optimize_source_code: str = "", function_to_optimize_source_code: str = "",
function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None, function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
function_to_optimize_ast: ast.FunctionDef | None = None, function_to_optimize_ast: ast.FunctionDef | ast.AsyncFunctionDef | None = None,
aiservice_client: AiServiceClient | None = None, aiservice_client: AiServiceClient | None = None,
function_benchmark_timings: dict[BenchmarkKey, int] | None = None, function_benchmark_timings: dict[BenchmarkKey, int] | None = None,
total_benchmark_timings: dict[BenchmarkKey, int] | None = None, total_benchmark_timings: dict[BenchmarkKey, int] | None = None,
@ -259,11 +264,6 @@ class FunctionOptimizer:
helper_code = f.read() helper_code = f.read()
original_helper_code[helper_function_path] = helper_code original_helper_code[helper_function_path] = helper_code
async_code = any(
has_any_async_functions(code_string.code) for code_string in code_context.read_writable_code.code_strings
)
if async_code:
return Failure("Codeflash does not support async functions in the code to optimize.")
# Random here means that we still attempt optimization with a fractional chance to see if # Random here means that we still attempt optimization with a fractional chance to see if
# last time we could not find an optimization, maybe this time we do. # last time we could not find an optimization, maybe this time we do.
# The random check comes first as a performance optimization; swapping the two 'and' operands has the same effect # The random check comes first as a performance optimization; swapping the two 'and' operands has the same effect
@ -588,7 +588,11 @@ class FunctionOptimizer:
tree = Tree(f"Candidate #{candidate_index} - Runtime Information ⌛") tree = Tree(f"Candidate #{candidate_index} - Runtime Information ⌛")
benchmark_tree = None benchmark_tree = None
if speedup_critic( if speedup_critic(
candidate_result, original_code_baseline.runtime, best_runtime_until_now=None candidate_result,
original_code_baseline.runtime,
best_runtime_until_now=None,
original_async_throughput=original_code_baseline.async_throughput,
best_throughput_until_now=None,
) and quantity_of_tests_critic(candidate_result): ) and quantity_of_tests_critic(candidate_result):
tree.add("This candidate is faster than the original code. 🚀") # TODO: Change this description tree.add("This candidate is faster than the original code. 🚀") # TODO: Change this description
tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}") tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")
@ -599,6 +603,17 @@ class FunctionOptimizer:
) )
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X") tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
if (
original_code_baseline.async_throughput is not None
and candidate_result.async_throughput is not None
):
throughput_gain_value = throughput_gain(
original_throughput=original_code_baseline.async_throughput,
optimized_throughput=candidate_result.async_throughput,
)
tree.add(f"Original async throughput: {original_code_baseline.async_throughput} executions")
tree.add(f"Optimized async throughput: {candidate_result.async_throughput} executions")
tree.add(f"Throughput improvement: {throughput_gain_value * 100:.1f}%")
line_profile_test_results = self.line_profiler_step( line_profile_test_results = self.line_profiler_step(
code_context=code_context, code_context=code_context,
original_helper_code=original_helper_code, original_helper_code=original_helper_code,
@ -634,6 +649,7 @@ class FunctionOptimizer:
replay_performance_gain=replay_perf_gain if self.args.benchmark else None, replay_performance_gain=replay_perf_gain if self.args.benchmark else None,
winning_benchmarking_test_results=candidate_result.benchmarking_test_results, winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results, winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results,
async_throughput=candidate_result.async_throughput,
) )
valid_optimizations.append(best_optimization) valid_optimizations.append(best_optimization)
# queue corresponding refined optimization for best optimization # queue corresponding refined optimization for best optimization
@ -658,6 +674,15 @@ class FunctionOptimizer:
) )
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X") tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
if (
original_code_baseline.async_throughput is not None
and candidate_result.async_throughput is not None
):
throughput_gain_value = throughput_gain(
original_throughput=original_code_baseline.async_throughput,
optimized_throughput=candidate_result.async_throughput,
)
tree.add(f"Throughput gain: {throughput_gain_value * 100:.1f}%")
if is_LSP_enabled(): if is_LSP_enabled():
lsp_log(LspMarkdownMessage(markdown=tree_to_markdown(tree))) lsp_log(LspMarkdownMessage(markdown=tree_to_markdown(tree)))
@ -701,6 +726,7 @@ class FunctionOptimizer:
replay_performance_gain=valid_opt.replay_performance_gain, replay_performance_gain=valid_opt.replay_performance_gain,
winning_benchmarking_test_results=valid_opt.winning_benchmarking_test_results, winning_benchmarking_test_results=valid_opt.winning_benchmarking_test_results,
winning_replay_benchmarking_test_results=valid_opt.winning_replay_benchmarking_test_results, winning_replay_benchmarking_test_results=valid_opt.winning_replay_benchmarking_test_results,
async_throughput=valid_opt.async_throughput,
) )
valid_candidates_with_shorter_code.append(new_best_opt) valid_candidates_with_shorter_code.append(new_best_opt)
diff_lens_list.append( diff_lens_list.append(
@ -1080,6 +1106,7 @@ class FunctionOptimizer:
self.function_trace_id[:-4] + "EXP0" if run_experiment else self.function_trace_id, self.function_trace_id[:-4] + "EXP0" if run_experiment else self.function_trace_id,
n_candidates, n_candidates,
ExperimentMetadata(id=self.experiment_id, group="control") if run_experiment else None, ExperimentMetadata(id=self.experiment_id, group="control") if run_experiment else None,
is_async=self.function_to_optimize.is_async,
) )
future_candidates_exp = None future_candidates_exp = None
@ -1095,6 +1122,7 @@ class FunctionOptimizer:
self.function_trace_id[:-4] + "EXP1", self.function_trace_id[:-4] + "EXP1",
n_candidates, n_candidates,
ExperimentMetadata(id=self.experiment_id, group="experiment"), ExperimentMetadata(id=self.experiment_id, group="experiment"),
is_async=self.function_to_optimize.is_async,
) )
futures.append(future_candidates_exp) futures.append(future_candidates_exp)
@ -1281,6 +1309,8 @@ class FunctionOptimizer:
function_name=function_to_optimize_qualified_name, function_name=function_to_optimize_qualified_name,
file_path=self.function_to_optimize.file_path, file_path=self.function_to_optimize.file_path,
benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None, benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None,
original_async_throughput=original_code_baseline.async_throughput,
best_async_throughput=best_optimization.async_throughput,
) )
self.replace_function_and_helpers_with_optimized_code( self.replace_function_and_helpers_with_optimized_code(
@ -1363,6 +1393,23 @@ class FunctionOptimizer:
original_runtimes_all=original_runtime_by_test, original_runtimes_all=original_runtime_by_test,
optimized_runtimes_all=optimized_runtime_by_test, optimized_runtimes_all=optimized_runtime_by_test,
) )
original_throughput_str = None
optimized_throughput_str = None
throughput_improvement_str = None
if (
self.function_to_optimize.is_async
and original_code_baseline.async_throughput is not None
and best_optimization.async_throughput is not None
):
original_throughput_str = f"{original_code_baseline.async_throughput} operations/second"
optimized_throughput_str = f"{best_optimization.async_throughput} operations/second"
throughput_improvement_value = throughput_gain(
original_throughput=original_code_baseline.async_throughput,
optimized_throughput=best_optimization.async_throughput,
)
throughput_improvement_str = f"{throughput_improvement_value * 100:.1f}%"
new_explanation_raw_str = self.aiservice_client.get_new_explanation( new_explanation_raw_str = self.aiservice_client.get_new_explanation(
source_code=code_context.read_writable_code.flat, source_code=code_context.read_writable_code.flat,
dependency_code=code_context.read_only_context_code, dependency_code=code_context.read_only_context_code,
@ -1376,6 +1423,9 @@ class FunctionOptimizer:
annotated_tests=generated_tests_str, annotated_tests=generated_tests_str,
optimization_id=best_optimization.candidate.optimization_id, optimization_id=best_optimization.candidate.optimization_id,
original_explanation=best_optimization.candidate.explanation, original_explanation=best_optimization.candidate.explanation,
original_throughput=original_throughput_str,
optimized_throughput=optimized_throughput_str,
throughput_improvement=throughput_improvement_str,
) )
new_explanation = Explanation( new_explanation = Explanation(
raw_explanation_message=new_explanation_raw_str or explanation.raw_explanation_message, raw_explanation_message=new_explanation_raw_str or explanation.raw_explanation_message,
@ -1386,6 +1436,8 @@ class FunctionOptimizer:
function_name=explanation.function_name, function_name=explanation.function_name,
file_path=explanation.file_path, file_path=explanation.file_path,
benchmark_details=explanation.benchmark_details, benchmark_details=explanation.benchmark_details,
original_async_throughput=explanation.original_async_throughput,
best_async_throughput=explanation.best_async_throughput,
) )
self.log_successful_optimization(new_explanation, generated_tests, exp_type) self.log_successful_optimization(new_explanation, generated_tests, exp_type)
@ -1476,6 +1528,17 @@ class FunctionOptimizer:
test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1) test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)
if self.function_to_optimize.is_async:
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
success, instrumented_source = instrument_source_module_with_async_decorators(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
)
if success and instrumented_source:
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
f.write(instrumented_source)
logger.debug(f"Applied async instrumentation to {self.function_to_optimize.file_path}")
# Instrument codeflash capture # Instrument codeflash capture
with progress_bar("Running tests to establish original code behavior..."): with progress_bar("Running tests to establish original code behavior..."):
try: try:
@ -1515,15 +1578,38 @@ class FunctionOptimizer:
) )
console.rule() console.rule()
with progress_bar("Running performance benchmarks..."): with progress_bar("Running performance benchmarks..."):
benchmarking_results, _ = self.run_and_parse_tests( if self.function_to_optimize.is_async:
testing_type=TestingMode.PERFORMANCE, from codeflash.code_utils.instrument_existing_tests import (
test_env=test_env, instrument_source_module_with_async_decorators,
test_files=self.test_files, )
optimization_iteration=0,
testing_time=total_looping_time, success, instrumented_source = instrument_source_module_with_async_decorators(
enable_coverage=False, self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
code_context=code_context, )
) if success and instrumented_source:
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
f.write(instrumented_source)
logger.debug(
f"Applied async performance instrumentation to {self.function_to_optimize.file_path}"
)
try:
benchmarking_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=total_looping_time,
enable_coverage=False,
code_context=code_context,
)
finally:
if self.function_to_optimize.is_async:
self.write_code_and_helpers(
self.function_to_optimize_source_code,
original_helper_code,
self.function_to_optimize.file_path,
)
else: else:
benchmarking_results = TestResults() benchmarking_results = TestResults()
start_time: float = time.time() start_time: float = time.time()
@ -1577,6 +1663,14 @@ class FunctionOptimizer:
console.rule() console.rule()
logger.debug(f"Total original code runtime (ns): {total_timing}") logger.debug(f"Total original code runtime (ns): {total_timing}")
async_throughput = None
if self.function_to_optimize.is_async:
async_throughput = calculate_function_throughput_from_test_results(
benchmarking_results, self.function_to_optimize.function_name
)
logger.debug(f"Original async function throughput: {async_throughput} calls/second")
console.rule()
if self.args.benchmark: if self.args.benchmark:
replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks( replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
@ -1590,6 +1684,7 @@ class FunctionOptimizer:
runtime=total_timing, runtime=total_timing,
coverage_results=coverage_results, coverage_results=coverage_results,
line_profile_results=line_profile_results, line_profile_results=line_profile_results,
async_throughput=async_throughput,
), ),
functions_to_remove, functions_to_remove,
) )
@ -1618,6 +1713,21 @@ class FunctionOptimizer:
candidate_helper_code = {} candidate_helper_code = {}
for module_abspath in original_helper_code: for module_abspath in original_helper_code:
candidate_helper_code[module_abspath] = Path(module_abspath).read_text("utf-8") candidate_helper_code[module_abspath] = Path(module_abspath).read_text("utf-8")
if self.function_to_optimize.is_async:
from codeflash.code_utils.instrument_existing_tests import (
instrument_source_module_with_async_decorators,
)
success, instrumented_source = instrument_source_module_with_async_decorators(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
)
if success and instrumented_source:
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
f.write(instrumented_source)
logger.debug(
f"Applied async behavioral instrumentation to {self.function_to_optimize.file_path} for candidate {optimization_candidate_index}"
)
try: try:
instrument_codeflash_capture( instrument_codeflash_capture(
self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
@ -1655,14 +1765,37 @@ class FunctionOptimizer:
logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...") logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")
if test_framework == "pytest": if test_framework == "pytest":
candidate_benchmarking_results, _ = self.run_and_parse_tests( # For async functions, instrument at definition site for performance benchmarking
testing_type=TestingMode.PERFORMANCE, if self.function_to_optimize.is_async:
test_env=test_env, from codeflash.code_utils.instrument_existing_tests import (
test_files=self.test_files, instrument_source_module_with_async_decorators,
optimization_iteration=optimization_candidate_index, )
testing_time=total_looping_time,
enable_coverage=False, success, instrumented_source = instrument_source_module_with_async_decorators(
) self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
)
if success and instrumented_source:
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
f.write(instrumented_source)
logger.debug(
f"Applied async performance instrumentation to {self.function_to_optimize.file_path} for candidate {optimization_candidate_index}"
)
try:
candidate_benchmarking_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=optimization_candidate_index,
testing_time=total_looping_time,
enable_coverage=False,
)
finally:
# Restore original source if we instrumented it
if self.function_to_optimize.is_async:
self.write_code_and_helpers(
candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
)
loop_count = ( loop_count = (
max(all_loop_indices) max(all_loop_indices)
if ( if (
@ -1698,6 +1831,14 @@ class FunctionOptimizer:
console.rule() console.rule()
logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}") logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
candidate_async_throughput = None
if self.function_to_optimize.is_async:
candidate_async_throughput = calculate_function_throughput_from_test_results(
candidate_benchmarking_results, self.function_to_optimize.function_name
)
logger.debug(f"Candidate async function throughput: {candidate_async_throughput} calls/second")
if self.args.benchmark: if self.args.benchmark:
candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks( candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(
self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
@ -1717,6 +1858,7 @@ class FunctionOptimizer:
else None, else None,
optimization_candidate_index=optimization_candidate_index, optimization_candidate_index=optimization_candidate_index,
total_candidate_timing=total_candidate_timing, total_candidate_timing=total_candidate_timing,
async_throughput=candidate_async_throughput,
) )
) )
@ -1808,8 +1950,10 @@ class FunctionOptimizer:
coverage_database_file=coverage_database_file, coverage_database_file=coverage_database_file,
coverage_config_file=coverage_config_file, coverage_config_file=coverage_config_file,
) )
else: if testing_type == TestingMode.PERFORMANCE:
results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file) results.perf_stdout = run_result.stdout
return results, coverage_results
results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
return results, coverage_results return results, coverage_results
def submit_test_generation_tasks( def submit_test_generation_tasks(
View file
@ -134,12 +134,13 @@ class Optimizer:
project_root=self.args.project_root, project_root=self.args.project_root,
module_root=self.args.module_root, module_root=self.args.module_root,
previous_checkpoint_functions=self.args.previous_checkpoint_functions, previous_checkpoint_functions=self.args.previous_checkpoint_functions,
enable_async=getattr(self.args, "async", False),
) )
def create_function_optimizer( def create_function_optimizer(
self, self,
function_to_optimize: FunctionToOptimize, function_to_optimize: FunctionToOptimize,
function_to_optimize_ast: ast.FunctionDef | None = None, function_to_optimize_ast: ast.FunctionDef | ast.AsyncFunctionDef | None = None,
function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None, function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
function_to_optimize_source_code: str | None = "", function_to_optimize_source_code: str | None = "",
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None, function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,
View file
@ -8,8 +8,9 @@ from codeflash.code_utils.config_consts import (
COVERAGE_THRESHOLD, COVERAGE_THRESHOLD,
MIN_IMPROVEMENT_THRESHOLD, MIN_IMPROVEMENT_THRESHOLD,
MIN_TESTCASE_PASSED_THRESHOLD, MIN_TESTCASE_PASSED_THRESHOLD,
MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD,
) )
from codeflash.models.test_type import TestType from codeflash.models import models
if TYPE_CHECKING: if TYPE_CHECKING:
from codeflash.models.models import CoverageData, OptimizedCandidateResult, OriginalCodeBaseline from codeflash.models.models import CoverageData, OptimizedCandidateResult, OriginalCodeBaseline
@ -25,20 +26,41 @@ def performance_gain(*, original_runtime_ns: int, optimized_runtime_ns: int) ->
return (original_runtime_ns - optimized_runtime_ns) / optimized_runtime_ns return (original_runtime_ns - optimized_runtime_ns) / optimized_runtime_ns
def throughput_gain(*, original_throughput: int, optimized_throughput: int) -> float:
"""Calculate the throughput gain of an optimized code over the original code.
This value multiplied by 100 gives the percentage improvement in throughput.
For throughput, higher values are better (more executions per time period).
"""
if original_throughput == 0:
return 0.0
return (optimized_throughput - original_throughput) / original_throughput
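A quick worked example of the formula above (hypothetical values; assumes throughput_gain is imported from codeflash.result.critic):
# Hypothetical numbers: 100 completed executions before, 150 after.
gain = throughput_gain(original_throughput=100, optimized_throughput=150)
assert gain == 0.5  # (150 - 100) / 100, i.e. a 50% throughput improvement
# A zero original throughput short-circuits to 0.0 to avoid division by zero.
assert throughput_gain(original_throughput=0, optimized_throughput=50) == 0.0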
def speedup_critic( def speedup_critic(
candidate_result: OptimizedCandidateResult, candidate_result: OptimizedCandidateResult,
original_code_runtime: int, original_code_runtime: int,
best_runtime_until_now: int | None, best_runtime_until_now: int | None,
*, *,
disable_gh_action_noise: bool = False, disable_gh_action_noise: bool = False,
original_async_throughput: int | None = None,
best_throughput_until_now: int | None = None,
) -> bool: ) -> bool:
"""Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user. """Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.
Ensure that the optimization is actually faster than the original code, above the noise floor. Evaluates both runtime performance and async throughput improvements.
The noise floor is a function of the original code runtime. Currently, the noise floor is 3xMIN_IMPROVEMENT_THRESHOLD
when the original runtime is less than 10 microseconds, and becomes MIN_IMPROVEMENT_THRESHOLD for any higher runtime. For runtime performance:
The noise floor is doubled when benchmarking on a (noisy) GitHub Action virtual instance, also we want to be more confident there. - Ensures the optimization is actually faster than the original code, above the noise floor.
- The noise floor is a function of the original code runtime. Currently, the noise floor is 3xMIN_IMPROVEMENT_THRESHOLD
when the original runtime is less than 10 microseconds, and becomes MIN_IMPROVEMENT_THRESHOLD for any higher runtime.
- The noise floor is doubled when benchmarking on a (noisy) GitHub Action virtual instance.
For async throughput (when available):
- Evaluates throughput improvements using MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD
- Throughput improvements complement runtime improvements for async functions
""" """
# Runtime performance evaluation
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
if not disable_gh_action_noise and env_utils.is_ci(): if not disable_gh_action_noise and env_utils.is_ci():
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode
@ -46,10 +68,31 @@ def speedup_critic(
perf_gain = performance_gain( perf_gain = performance_gain(
original_runtime_ns=original_code_runtime, optimized_runtime_ns=candidate_result.best_test_runtime original_runtime_ns=original_code_runtime, optimized_runtime_ns=candidate_result.best_test_runtime
) )
if best_runtime_until_now is None: runtime_improved = perf_gain > noise_floor
# collect all optimizations with this
return bool(perf_gain > noise_floor) # Check runtime comparison with best so far
return bool(perf_gain > noise_floor and candidate_result.best_test_runtime < best_runtime_until_now) runtime_is_best = best_runtime_until_now is None or candidate_result.best_test_runtime < best_runtime_until_now
throughput_improved = True # Default to True if no throughput data
throughput_is_best = True # Default to True if no throughput data
if original_async_throughput is not None and candidate_result.async_throughput is not None:
if original_async_throughput > 0:
throughput_gain_value = throughput_gain(
original_throughput=original_async_throughput, optimized_throughput=candidate_result.async_throughput
)
throughput_improved = throughput_gain_value > MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD
throughput_is_best = (
best_throughput_until_now is None or candidate_result.async_throughput > best_throughput_until_now
)
if original_async_throughput is not None and candidate_result.async_throughput is not None:
# When throughput data is available, accept if EITHER throughput OR runtime improves significantly
throughput_acceptance = throughput_improved and throughput_is_best
runtime_acceptance = runtime_improved and runtime_is_best
return throughput_acceptance or runtime_acceptance
return runtime_improved and runtime_is_best
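To summarize the acceptance rule above, a sketch of the outcomes when async throughput data is present (illustrative only; each path also requires beating the best candidate seen so far):
# runtime_improved  throughput_improved  -> surfaced?
# True              True                    yes (either path suffices)
# True              False                   yes (runtime path)
# False             True                    yes (throughput path)
# False             False                   no
# Without throughput data, only the runtime path applies.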
def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | OriginalCodeBaseline) -> bool: def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | OriginalCodeBaseline) -> bool:
@ -63,7 +106,7 @@ def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | Origin
if pass_count >= MIN_TESTCASE_PASSED_THRESHOLD: if pass_count >= MIN_TESTCASE_PASSED_THRESHOLD:
return True return True
# If one or more tests passed, check if at least one of them was a successful REPLAY_TEST # If one or more tests passed, check if at least one of them was a successful REPLAY_TEST
return bool(pass_count >= 1 and report[TestType.REPLAY_TEST]["passed"] >= 1) return bool(pass_count >= 1 and report[models.TestType.REPLAY_TEST]["passed"] >= 1) # type: ignore # noqa: PGH003
def coverage_critic(original_code_coverage: CoverageData | None, test_framework: str) -> bool: def coverage_critic(original_code_coverage: CoverageData | None, test_framework: str) -> bool:

View file

@ -12,6 +12,7 @@ from rich.table import Table
from codeflash.code_utils.time_utils import humanize_runtime from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.lsp.helpers import is_LSP_enabled from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.models.models import BenchmarkDetail, TestResults from codeflash.models.models import BenchmarkDetail, TestResults
from codeflash.result.critic import throughput_gain
@dataclass(frozen=True, config={"arbitrary_types_allowed": True}) @dataclass(frozen=True, config={"arbitrary_types_allowed": True})
@ -24,9 +25,28 @@ class Explanation:
function_name: str function_name: str
file_path: Path file_path: Path
benchmark_details: Optional[list[BenchmarkDetail]] = None benchmark_details: Optional[list[BenchmarkDetail]] = None
original_async_throughput: Optional[int] = None
best_async_throughput: Optional[int] = None
@property @property
def perf_improvement_line(self) -> str: def perf_improvement_line(self) -> str:
runtime_improvement = self.speedup
if (
self.original_async_throughput is not None
and self.best_async_throughput is not None
and self.original_async_throughput > 0
):
throughput_improvement = throughput_gain(
original_throughput=self.original_async_throughput, optimized_throughput=self.best_async_throughput
)
# Use throughput metrics if throughput improvement is better or runtime got worse
if throughput_improvement > runtime_improvement or runtime_improvement <= 0:
throughput_pct = f"{throughput_improvement * 100:,.0f}%"
throughput_x = f"{throughput_improvement + 1:,.2f}x"
return f"{throughput_pct} improvement ({throughput_x} faster)."
return f"{self.speedup_pct} improvement ({self.speedup_x} faster)." return f"{self.speedup_pct} improvement ({self.speedup_x} faster)."
@property @property
@ -46,6 +66,23 @@ class Explanation:
# TODO: Sometimes the explanation says something similar to "This is the code that was optimized", remove such parts # TODO: Sometimes the explanation says something similar to "This is the code that was optimized", remove such parts
original_runtime_human = humanize_runtime(self.original_runtime_ns) original_runtime_human = humanize_runtime(self.original_runtime_ns)
best_runtime_human = humanize_runtime(self.best_runtime_ns) best_runtime_human = humanize_runtime(self.best_runtime_ns)
# Determine if we're showing throughput or runtime improvements
runtime_improvement = self.speedup
is_using_throughput_metric = False
if (
self.original_async_throughput is not None
and self.best_async_throughput is not None
and self.original_async_throughput > 0
):
throughput_improvement = throughput_gain(
original_throughput=self.original_async_throughput, optimized_throughput=self.best_async_throughput
)
if throughput_improvement > runtime_improvement or runtime_improvement <= 0:
is_using_throughput_metric = True
benchmark_info = "" benchmark_info = ""
if self.benchmark_details: if self.benchmark_details:
@ -86,13 +123,18 @@ class Explanation:
console.print(table) console.print(table)
benchmark_info = cast("StringIO", console.file).getvalue() + "\n" # Cast for mypy benchmark_info = cast("StringIO", console.file).getvalue() + "\n" # Cast for mypy
test_report = self.winning_behavior_test_results.get_test_pass_fail_report_by_type() if is_using_throughput_metric:
test_report_str = TestResults.report_to_string(test_report) performance_description = (
f"Throughput improved from {self.original_async_throughput} to {self.best_async_throughput} operations/second "
f"(runtime: {original_runtime_human}{best_runtime_human})\n\n"
)
else:
performance_description = f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n"
return ( return (
f"Optimized {self.function_name} in {self.file_path}\n" f"Optimized {self.function_name} in {self.file_path}\n"
f"{self.perf_improvement_line}\n" f"{self.perf_improvement_line}\n"
f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n" + performance_description
+ (benchmark_info if benchmark_info else "") + (benchmark_info if benchmark_info else "")
+ self.raw_explanation_message + self.raw_explanation_message
+ " \n\n" + " \n\n"
@ -101,7 +143,7 @@ class Explanation:
"" ""
if is_LSP_enabled() if is_LSP_enabled()
else "The new optimized code was tested for correctness. The results are listed below.\n" else "The new optimized code was tested for correctness. The results are listed below.\n"
+ test_report_str f"{TestResults.report_to_string(self.winning_behavior_test_results.get_test_pass_fail_report_by_type())}\n"
) )
) )
View file
@ -38,10 +38,10 @@ class CoverageUtils:
cov = Coverage(data_file=database_path, config_file=config_path, data_suffix=True, auto_data=True, branch=True) cov = Coverage(data_file=database_path, config_file=config_path, data_suffix=True, auto_data=True, branch=True)
if not database_path.stat().st_size or not database_path.exists(): if not database_path.exists() or not database_path.stat().st_size:
logger.debug(f"Coverage database {database_path} is empty or does not exist") logger.debug(f"Coverage database {database_path} is empty or does not exist")
sentry_sdk.capture_message(f"Coverage database {database_path} is empty or does not exist") sentry_sdk.capture_message(f"Coverage database {database_path} is empty or does not exist")
return CoverageUtils.create_empty(source_code_path, function_name, code_context) return CoverageData.create_empty(source_code_path, function_name, code_context)
cov.load() cov.load()
reporter = JsonReporter(cov) reporter = JsonReporter(cov)
@ -51,8 +51,8 @@ class CoverageUtils:
reporter.report(morfs=[source_code_path.as_posix()], outfile=f) reporter.report(morfs=[source_code_path.as_posix()], outfile=f)
except NoDataError: except NoDataError:
sentry_sdk.capture_message(f"No coverage data found for {function_name} in {source_code_path}") sentry_sdk.capture_message(f"No coverage data found for {function_name} in {source_code_path}")
return CoverageUtils.create_empty(source_code_path, function_name, code_context) return CoverageData.create_empty(source_code_path, function_name, code_context)
with temp_json_file.open(encoding="utf-8") as f: with temp_json_file.open() as f:
original_coverage_data = json.load(f) original_coverage_data = json.load(f)
coverage_data, status = CoverageUtils._parse_coverage_file(temp_json_file, source_code_path) coverage_data, status = CoverageUtils._parse_coverage_file(temp_json_file, source_code_path)
View file
@ -40,6 +40,30 @@ matches_re_start = re.compile(r"!\$######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)
matches_re_end = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!") matches_re_end = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!")
start_pattern = re.compile(r"!\$######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+)######\$!")
end_pattern = re.compile(r"!######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+):([^:]+)######!")
def calculate_function_throughput_from_test_results(test_results: TestResults, function_name: str) -> int:
"""Calculate function throughput from TestResults by extracting performance stdout.
A completed execution is defined as having both a start tag and a matching end tag from the performance wrappers.
Start: !$######test_module:test_function:function_name:loop_index:iteration_id######$!
End: !######test_module:test_function:function_name:loop_index:iteration_id:duration######!
"""
start_matches = start_pattern.findall(test_results.perf_stdout or "")
end_matches = end_pattern.findall(test_results.perf_stdout or "")
end_matches_truncated = [end_match[:5] for end_match in end_matches]
end_matches_set = set(end_matches_truncated)
function_throughput = 0
for start_match in start_matches:
if start_match in end_matches_set and len(start_match) > 2 and start_match[2] == function_name:
function_throughput += 1
return function_throughput
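A minimal sketch of the counting, using hypothetical tag values that follow the documented field layout:
# Two starts, one matching end: the unmatched start at iteration_id 1 is not counted.
stdout = (
    "!$######tests.test_mod:test_fetch:fetch_data:1:0######$!\n"
    "!######tests.test_mod:test_fetch:fetch_data:1:0:123456######!\n"
    "!$######tests.test_mod:test_fetch:fetch_data:1:1######$!\n"
)
# With perf_stdout set to this string, the function returns 1 for "fetch_data".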
def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults: def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
test_results = TestResults() test_results = TestResults()
if not file_location.exists(): if not file_location.exists():
View file
@ -450,3 +450,26 @@ class PytestLoops:
metafunc.parametrize( metafunc.parametrize(
"__pytest_loop_step_number", range(count), indirect=True, ids=make_progress_id, scope=scope "__pytest_loop_step_number", range(count), indirect=True, ids=make_progress_id, scope=scope
) )
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_setup(self, item: pytest.Item) -> None:
"""Set test context environment variables before each test."""
test_module_name = item.module.__name__ if item.module else "unknown_module"
test_class_name = None
if item.cls:
test_class_name = item.cls.__name__
test_function_name = item.name
if "[" in test_function_name:
test_function_name = test_function_name.split("[", 1)[0]
os.environ["CODEFLASH_TEST_MODULE"] = test_module_name
os.environ["CODEFLASH_TEST_CLASS"] = test_class_name or ""
os.environ["CODEFLASH_TEST_FUNCTION"] = test_function_name
@pytest.hookimpl(trylast=True)
def pytest_runtest_teardown(self, item: pytest.Item) -> None: # noqa: ARG002
"""Clean up test context environment variables after each test."""
for var in ["CODEFLASH_TEST_MODULE", "CODEFLASH_TEST_CLASS", "CODEFLASH_TEST_FUNCTION"]:
os.environ.pop(var, None)
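A sketch of how instrumentation code might consume these variables (current_test_context is a hypothetical helper, not part of this diff):
import os

def current_test_context() -> tuple[str, str, str]:
    # Reads the per-test identity published by the hooks above so async
    # wrappers can stamp their start/end tags with the right test names.
    return (
        os.environ.get("CODEFLASH_TEST_MODULE", "unknown_module"),
        os.environ.get("CODEFLASH_TEST_CLASS", ""),
        os.environ.get("CODEFLASH_TEST_FUNCTION", "unknown_function"),
    )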
View file
@ -1,2 +1,2 @@
# These version placeholders will be replaced by uv-dynamic-versioning during build. # These version placeholders will be replaced by uv-dynamic-versioning during build.
__version__ = "0.16.7.post46.dev0+444ff121" __version__ = "0.16.7.post77.dev0+02f96e77"
View file
@ -52,8 +52,14 @@ Homepage = "https://codeflash.ai"
[project.scripts] [project.scripts]
codeflash = "codeflash.main:main" codeflash = "codeflash.main:main"
[project.optional-dependencies]
asyncio = [
"pytest-asyncio>=1.2.0",
]
[dependency-groups] [dependency-groups]
dev = [ dev = [
{include-group = "asyncio"},
"ipython>=8.12.0", "ipython>=8.12.0",
"mypy>=1.13", "mypy>=1.13",
"ruff>=0.7.0", "ruff>=0.7.0",
@ -76,6 +82,9 @@ dev = [
"uv>=0.6.2", "uv>=0.6.2",
"pre-commit>=4.2.0,<5", "pre-commit>=4.2.0,<5",
] ]
asyncio = [
"pytest-asyncio>=1.2.0",
]
[tool.hatch.build.targets.sdist] [tool.hatch.build.targets.sdist]
include = ["codeflash"] include = ["codeflash"]
View file
@ -0,0 +1,28 @@
import os
import pathlib
from end_to_end_test_utilities import CoverageExpectation, TestConfig, run_codeflash_command, run_with_retries
def run_test(expected_improvement_pct: int) -> bool:
config = TestConfig(
file_path="main.py",
expected_unit_tests=0,
min_improvement_x=0.1,
enable_async=True,
coverage_expectations=[
CoverageExpectation(
function_name="retry_with_backoff",
expected_coverage=100.0,
expected_lines=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
)
],
)
cwd = (
pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "code_directories" / "async_e2e"
).resolve()
return run_codeflash_command(cwd, config, expected_improvement_pct)
if __name__ == "__main__":
exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10))))
View file
@ -11,7 +11,7 @@ def run_test(expected_improvement_pct: int) -> bool:
function_name="sorter", function_name="sorter",
benchmarks_root=cwd / "tests" / "pytest" / "benchmarks", benchmarks_root=cwd / "tests" / "pytest" / "benchmarks",
test_framework="pytest", test_framework="pytest",
min_improvement_x=1.0, min_improvement_x=0.70,
coverage_expectations=[ coverage_expectations=[
CoverageExpectation( CoverageExpectation(
function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8, 9, 10] function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8, 9, 10]
View file
@ -9,7 +9,7 @@ def run_test(expected_improvement_pct: int) -> bool:
file_path="bubble_sort.py", file_path="bubble_sort.py",
function_name="sorter", function_name="sorter",
test_framework="pytest", test_framework="pytest",
min_improvement_x=1.0, min_improvement_x=0.70,
coverage_expectations=[ coverage_expectations=[
CoverageExpectation( CoverageExpectation(
function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8, 9, 10] function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8, 9, 10]
View file
@ -6,7 +6,7 @@ from end_to_end_test_utilities import TestConfig, run_codeflash_command, run_wit
def run_test(expected_improvement_pct: int) -> bool: def run_test(expected_improvement_pct: int) -> bool:
config = TestConfig( config = TestConfig(
file_path="bubble_sort.py", function_name="sorter", test_framework="unittest", min_improvement_x=3.0 file_path="bubble_sort.py", function_name="sorter", test_framework="unittest", min_improvement_x=0.40
) )
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve() cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
return run_codeflash_command(cwd, config, expected_improvement_pct) return run_codeflash_command(cwd, config, expected_improvement_pct)
View file
@ -27,6 +27,7 @@ class TestConfig:
trace_mode: bool = False trace_mode: bool = False
coverage_expectations: list[CoverageExpectation] = field(default_factory=list) coverage_expectations: list[CoverageExpectation] = field(default_factory=list)
benchmarks_root: Optional[pathlib.Path] = None benchmarks_root: Optional[pathlib.Path] = None
enable_async: bool = False
def clear_directory(directory_path: str | pathlib.Path) -> None: def clear_directory(directory_path: str | pathlib.Path) -> None:
@ -134,6 +135,8 @@ def build_command(
) )
if benchmarks_root: if benchmarks_root:
base_command.extend(["--benchmark", "--benchmarks-root", str(benchmarks_root)]) base_command.extend(["--benchmark", "--benchmarks-root", str(benchmarks_root)])
if config.enable_async:
base_command.append("--async")
return base_command return base_command
View file
@ -3,6 +3,10 @@ from pathlib import Path
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
import tempfile
from codeflash.code_utils.code_extractor import resolve_star_import, DottedImportCollector
import libcst as cst
from codeflash.models.models import FunctionParent
def test_add_needed_imports_from_module0() -> None: def test_add_needed_imports_from_module0() -> None:
src_module = '''import ast src_module = '''import ast
@ -349,3 +353,141 @@ class DbtAdapter(BaseAdapter):
project_root_path=Path(__file__).resolve().parent.resolve(), project_root_path=Path(__file__).resolve().parent.resolve(),
) )
assert new_code == expected assert new_code == expected
def test_resolve_star_import_with_all_defined():
"""Test resolve_star_import when __all__ is explicitly defined."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)
test_module = project_root / 'test_module.py'
# Create a test module with __all__ definition
test_module.write_text('''
__all__ = ['public_function', 'PublicClass']
def public_function():
pass
def _private_function():
pass
class PublicClass:
pass
class AnotherPublicClass:
"""Not in __all__ so should be excluded."""
pass
''')
symbols = resolve_star_import('test_module', project_root)
expected_symbols = {'public_function', 'PublicClass'}
assert symbols == expected_symbols
def test_resolve_star_import_without_all_defined():
"""Test resolve_star_import when __all__ is not defined - should include all public symbols."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)
test_module = project_root / 'test_module.py'
# Create a test module without __all__ definition
test_module.write_text('''
def public_func():
pass
def _private_func():
pass
class PublicClass:
pass
PUBLIC_VAR = 42
_private_var = 'secret'
''')
symbols = resolve_star_import('test_module', project_root)
expected_symbols = {'public_func', 'PublicClass', 'PUBLIC_VAR'}
assert symbols == expected_symbols
def test_resolve_star_import_nonexistent_module():
"""Test resolve_star_import with non-existent module - should return empty set."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)
symbols = resolve_star_import('nonexistent_module', project_root)
assert symbols == set()
def test_dotted_import_collector_skips_star_imports():
"""Test that DottedImportCollector correctly skips star imports."""
code_with_star_import = '''
from typing import *
from pathlib import Path
from collections import defaultdict
import os
'''
module = cst.parse_module(code_with_star_import)
collector = DottedImportCollector()
module.visit(collector)
# Should collect regular imports but skip the star import
expected_imports = {'collections.defaultdict', 'os', 'pathlib.Path'}
assert collector.imports == expected_imports
def test_add_needed_imports_with_star_import_resolution():
"""Test add_needed_imports_from_module correctly handles star imports by resolving them."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)
# Create a source module that exports symbols
src_module = project_root / 'source_module.py'
src_module.write_text('''
__all__ = ['UtilFunction', 'HelperClass']
def UtilFunction():
pass
class HelperClass:
pass
''')
# Create source code that uses star import
src_code = '''
from source_module import *
def my_function():
helper = HelperClass()
UtilFunction()
return helper
'''
# Destination code that needs the imports resolved
dst_code = '''
def my_function():
helper = HelperClass()
UtilFunction()
return helper
'''
src_path = project_root / 'src.py'
dst_path = project_root / 'dst.py'
src_path.write_text(src_code)
result = add_needed_imports_from_module(
src_code, dst_code, src_path, dst_path, project_root
)
# The result should have individual imports instead of star import
expected_result = '''from source_module import HelperClass, UtilFunction
def my_function():
helper = HelperClass()
UtilFunction()
return helper
'''
assert result == expected_result
View file
@ -1902,4 +1902,210 @@ def test_bubble_sort(input, expected_output):
# Check that comments were added # Check that comments were added
modified_source = result.generated_tests[0].generated_original_test_source modified_source = result.generated_tests[0].generated_original_test_source
assert modified_source == expected assert modified_source == expected
def test_async_basic_runtime_comment_addition(self, test_config):
"""Test basic functionality of adding runtime comments to async test functions."""
os.chdir(test_config.project_root_path)
test_source = """async def test_async_bubble_sort():
codeflash_output = await async_bubble_sort([3, 1, 2])
assert codeflash_output == [1, 2, 3]
"""
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
perf_file_path=test_config.tests_root / "test_perf.py",
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
original_test_results = TestResults()
optimized_test_results = TestResults()
original_invocation = self.create_test_invocation("test_async_bubble_sort", 500_000, iteration_id='0') # 500μs
optimized_invocation = self.create_test_invocation("test_async_bubble_sort", 300_000, iteration_id='0') # 300μs
original_test_results.add(original_invocation)
optimized_test_results.add(optimized_invocation)
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
modified_source = result.generated_tests[0].generated_original_test_source
assert "# 500μs -> 300μs" in modified_source
assert "codeflash_output = await async_bubble_sort([3, 1, 2]) # 500μs -> 300μs" in modified_source
def test_async_multiple_test_functions(self, test_config):
os.chdir(test_config.project_root_path)
test_source = """async def test_async_bubble_sort():
codeflash_output = await async_quick_sort([3, 1, 2])
assert codeflash_output == [1, 2, 3]
async def test_async_quick_sort():
codeflash_output = await async_quick_sort([5, 2, 8])
assert codeflash_output == [2, 5, 8]
def helper_function():
return "not a test"
"""
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
perf_file_path=test_config.tests_root / "test_perf.py"
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
original_test_results = TestResults()
optimized_test_results = TestResults()
original_test_results.add(self.create_test_invocation("test_async_bubble_sort", 500_000, iteration_id='0'))
original_test_results.add(self.create_test_invocation("test_async_quick_sort", 800_000, iteration_id='0'))
optimized_test_results.add(self.create_test_invocation("test_async_bubble_sort", 300_000, iteration_id='0'))
optimized_test_results.add(self.create_test_invocation("test_async_quick_sort", 600_000, iteration_id='0'))
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
modified_source = result.generated_tests[0].generated_original_test_source
assert "# 500μs -> 300μs" in modified_source
assert "# 800μs -> 600μs" in modified_source
assert (
"helper_function():" in modified_source
and "# " not in modified_source.split("helper_function():")[1].split("\n")[0]
)
def test_async_class_method(self, test_config):
os.chdir(test_config.project_root_path)
test_source = '''class TestAsyncClass:
async def test_async_function(self):
codeflash_output = await some_async_function()
assert codeflash_output == expected
'''
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
perf_file_path=test_config.tests_root / "test_perf.py"
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
invocation_id = InvocationId(
test_module_path="tests.test_module__unit_test_0",
test_class_name="TestAsyncClass",
test_function_name="test_async_function",
function_getting_tested="some_async_function",
iteration_id="0",
)
original_runtimes = {invocation_id: [2000000000]} # 2s in nanoseconds
optimized_runtimes = {invocation_id: [1000000000]} # 1s in nanoseconds
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
expected_source = '''class TestAsyncClass:
async def test_async_function(self):
codeflash_output = await some_async_function() # 2.00s -> 1.00s (100% faster)
assert codeflash_output == expected
'''
assert len(result.generated_tests) == 1
assert result.generated_tests[0].generated_original_test_source == expected_source
def test_async_mixed_sync_and_async_functions(self, test_config):
os.chdir(test_config.project_root_path)
test_source = """def test_sync_function():
codeflash_output = sync_function([1, 2, 3])
assert codeflash_output == [1, 2, 3]
async def test_async_function():
codeflash_output = await async_function([4, 5, 6])
assert codeflash_output == [4, 5, 6]
def test_another_sync():
result = another_sync_func()
assert result is True
"""
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
perf_file_path=test_config.tests_root / "test_perf.py"
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
original_test_results = TestResults()
optimized_test_results = TestResults()
# Add test invocations for all test functions
original_test_results.add(self.create_test_invocation("test_sync_function", 400_000, iteration_id='0'))
original_test_results.add(self.create_test_invocation("test_async_function", 600_000, iteration_id='0'))
original_test_results.add(self.create_test_invocation("test_another_sync", 200_000, iteration_id='0'))
optimized_test_results.add(self.create_test_invocation("test_sync_function", 200_000, iteration_id='0'))
optimized_test_results.add(self.create_test_invocation("test_async_function", 300_000, iteration_id='0'))
optimized_test_results.add(self.create_test_invocation("test_another_sync", 100_000, iteration_id='0'))
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
modified_source = result.generated_tests[0].generated_original_test_source
assert "# 400μs -> 200μs" in modified_source
assert "# 600μs -> 300μs" in modified_source
assert "# 200μs -> 100μs" in modified_source
assert "async def test_async_function():" in modified_source
assert "await async_function([4, 5, 6])" in modified_source
def test_async_complex_await_patterns(self, test_config):
os.chdir(test_config.project_root_path)
test_source = """async def test_complex_async():
# Multiple await calls
result1 = await async_func1()
codeflash_output = await async_func2(result1)
result3 = await async_func3(codeflash_output)
assert result3 == expected
# Await in context manager
async with async_context() as ctx:
final_result = await ctx.process()
assert final_result is not None
"""
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
perf_file_path=test_config.tests_root / "test_perf.py"
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
original_test_results = TestResults()
optimized_test_results = TestResults()
original_test_results.add(self.create_test_invocation("test_complex_async", 750_000, iteration_id='1')) # 750μs
optimized_test_results.add(self.create_test_invocation("test_complex_async", 450_000, iteration_id='1')) # 450μs
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
modified_source = result.generated_tests[0].generated_original_test_source
assert "# 750μs -> 450μs" in modified_source
View file
@ -0,0 +1,337 @@
import tempfile
from pathlib import Path
import pytest
from codeflash.discovery.functions_to_optimize import (
find_all_functions_in_file,
get_functions_to_optimize,
inspect_top_level_functions_or_methods,
)
from codeflash.verification.verification_utils import TestConfig
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as temp:
yield Path(temp)
def test_async_function_detection(temp_dir):
async_function = """
async def async_function_with_return():
await some_async_operation()
return 42
async def async_function_without_return():
await some_async_operation()
print("No return")
def regular_function():
return 10
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(async_function)
functions_found = find_all_functions_in_file(file_path)
function_names = [fn.function_name for fn in functions_found[file_path]]
assert "async_function_with_return" in function_names
assert "regular_function" in function_names
assert "async_function_without_return" not in function_names
def test_async_method_in_class(temp_dir):
code_with_async_method = """
class AsyncClass:
async def async_method(self):
await self.do_something()
return "result"
async def async_method_no_return(self):
await self.do_something()
pass
def sync_method(self):
return "sync result"
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(code_with_async_method)
functions_found = find_all_functions_in_file(file_path)
found_functions = functions_found[file_path]
function_names = [fn.function_name for fn in found_functions]
qualified_names = [fn.qualified_name for fn in found_functions]
assert "async_method" in function_names
assert "AsyncClass.async_method" in qualified_names
assert "sync_method" in function_names
assert "AsyncClass.sync_method" in qualified_names
assert "async_method_no_return" not in function_names
def test_nested_async_functions(temp_dir):
nested_async = """
async def outer_async():
async def inner_async():
return "inner"
result = await inner_async()
return result
def outer_sync():
async def inner_async():
return "inner from sync"
return inner_async
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(nested_async)
functions_found = find_all_functions_in_file(file_path)
function_names = [fn.function_name for fn in functions_found[file_path]]
assert "outer_async" in function_names
assert "outer_sync" in function_names
assert "inner_async" not in function_names
def test_async_staticmethod_and_classmethod(temp_dir):
async_decorators = """
class MyClass:
@staticmethod
async def async_static_method():
await some_operation()
return "static result"
@classmethod
async def async_class_method(cls):
await cls.some_operation()
return "class result"
@property
async def async_property(self):
return await self.get_value()
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(async_decorators)
functions_found = find_all_functions_in_file(file_path)
function_names = [fn.function_name for fn in functions_found[file_path]]
assert "async_static_method" in function_names
assert "async_class_method" in function_names
assert "async_property" not in function_names
def test_async_generator_functions(temp_dir):
async_generators = """
async def async_generator_with_return():
for i in range(10):
yield i
return "done"
async def async_generator_no_return():
for i in range(10):
yield i
async def regular_async_with_return():
result = await compute()
return result
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(async_generators)
functions_found = find_all_functions_in_file(file_path)
function_names = [fn.function_name for fn in functions_found[file_path]]
assert "async_generator_with_return" in function_names
assert "regular_async_with_return" in function_names
assert "async_generator_no_return" not in function_names
def test_inspect_async_top_level_functions(temp_dir):
code = """
async def top_level_async():
return 42
class AsyncContainer:
async def async_method(self):
async def nested_async():
return 1
return await nested_async()
@staticmethod
async def async_static():
return "static"
@classmethod
async def async_classmethod(cls):
return "classmethod"
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(code)
result = inspect_top_level_functions_or_methods(file_path, "top_level_async")
assert result.is_top_level
result = inspect_top_level_functions_or_methods(file_path, "async_method", class_name="AsyncContainer")
assert result.is_top_level
result = inspect_top_level_functions_or_methods(file_path, "nested_async", class_name="AsyncContainer")
assert not result.is_top_level
result = inspect_top_level_functions_or_methods(file_path, "async_static", class_name="AsyncContainer")
assert result.is_top_level
assert result.is_staticmethod
result = inspect_top_level_functions_or_methods(file_path, "async_classmethod", class_name="AsyncContainer")
assert result.is_top_level
assert result.is_classmethod
def test_get_functions_to_optimize_with_async(temp_dir):
mixed_code = """
async def async_func_one():
return await operation_one()
def sync_func_one():
return operation_one()
async def async_func_two():
print("no return")
class MixedClass:
async def async_method(self):
return await self.operation()
def sync_method(self):
return self.operation()
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(mixed_code)
test_config = TestConfig(
tests_root="tests",
project_root_path=".",
test_framework="pytest",
tests_project_rootdir=Path()
)
functions, functions_count, _ = get_functions_to_optimize(
optimize_all=None,
replay_test=None,
file=file_path,
only_get_this_function=None,
test_cfg=test_config,
ignore_paths=[],
project_root=file_path.parent,
module_root=file_path.parent,
enable_async=True,
)
assert functions_count == 4
function_names = [fn.function_name for fn in functions[file_path]]
assert "async_func_one" in function_names
assert "sync_func_one" in function_names
assert "async_method" in function_names
assert "sync_method" in function_names
assert "async_func_two" not in function_names
def test_no_async_functions_finding(temp_dir):
mixed_code = """
async def async_func_one():
return await operation_one()
def sync_func_one():
return operation_one()
async def async_func_two():
print("no return")
class MixedClass:
async def async_method(self):
return await self.operation()
def sync_method(self):
return self.operation()
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(mixed_code)
test_config = TestConfig(
tests_root="tests",
project_root_path=".",
test_framework="pytest",
tests_project_rootdir=Path()
)
functions, functions_count, _ = get_functions_to_optimize(
optimize_all=None,
replay_test=None,
file=file_path,
only_get_this_function=None,
test_cfg=test_config,
ignore_paths=[],
project_root=file_path.parent,
module_root=file_path.parent,
enable_async=False,
)
assert functions_count == 2
function_names = [fn.function_name for fn in functions[file_path]]
assert "sync_func_one" in function_names
assert "sync_method" in function_names
assert "async_func_one" not in function_names
assert "async_method" not in function_names
def test_async_function_parents(temp_dir):
complex_structure = """
class OuterClass:
async def outer_method(self):
return 1
class InnerClass:
async def inner_method(self):
return 2
async def module_level_async():
class LocalClass:
async def local_method(self):
return 3
return LocalClass()
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(complex_structure)
functions_found = find_all_functions_in_file(file_path)
found_functions = functions_found[file_path]
for fn in found_functions:
if fn.function_name == "outer_method":
assert len(fn.parents) == 1
assert fn.parents[0].name == "OuterClass"
assert fn.qualified_name == "OuterClass.outer_method"
elif fn.function_name == "inner_method":
assert len(fn.parents) == 2
assert fn.parents[0].name == "OuterClass"
assert fn.parents[1].name == "InnerClass"
elif fn.function_name == "module_level_async":
assert len(fn.parents) == 0
assert fn.qualified_name == "module_level_async"

File diff suppressed because it is too large

View file

@ -0,0 +1,285 @@
from __future__ import annotations
import asyncio
import os
import sqlite3
import tempfile
from pathlib import Path
import pytest
import dill as pickle
from codeflash.code_utils.codeflash_wrap_decorator import (
codeflash_behavior_async,
codeflash_performance_async,
)
from codeflash.verification.codeflash_capture import VerificationType
class TestAsyncWrapperSQLiteValidation:
@pytest.fixture
def test_env_setup(self, request):
original_env = {}
test_env = {
"CODEFLASH_LOOP_INDEX": "1",
"CODEFLASH_TEST_ITERATION": "0",
"CODEFLASH_TEST_MODULE": __name__,
"CODEFLASH_TEST_CLASS": "TestAsyncWrapperSQLiteValidation",
"CODEFLASH_TEST_FUNCTION": request.node.name,
"CODEFLASH_CURRENT_LINE_ID": "test_unit",
}
for key, value in test_env.items():
original_env[key] = os.environ.get(key)
os.environ[key] = value
yield test_env
for key, original_value in original_env.items():
if original_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = original_value
@pytest.fixture
def temp_db_path(self, test_env_setup):
iteration = test_env_setup["CODEFLASH_TEST_ITERATION"]
from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file
db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
yield db_path
if db_path.exists():
db_path.unlink()
@pytest.mark.asyncio
async def test_behavior_async_basic_function(self, test_env_setup, temp_db_path):
@codeflash_behavior_async
async def simple_async_add(a: int, b: int) -> int:
await asyncio.sleep(0.001)
return a + b
os.environ['CODEFLASH_CURRENT_LINE_ID'] = 'simple_async_add_59'
result = await simple_async_add(5, 3)
assert result == 8
assert temp_db_path.exists()
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test_results'")
assert cur.fetchone() is not None
cur.execute("SELECT * FROM test_results")
rows = cur.fetchall()
assert len(rows) == 1
row = rows[0]
(test_module_path, test_class_name, test_function_name, function_getting_tested,
loop_index, iteration_id, runtime, return_value_blob, verification_type) = row
assert test_module_path == __name__
assert test_class_name == "TestAsyncWrapperSQLiteValidation"
assert test_function_name == "test_behavior_async_basic_function"
assert function_getting_tested == "simple_async_add"
assert loop_index == 1
# Line ID will be the actual line number from the source code, not a simple counter
assert iteration_id.startswith("simple_async_add_") and iteration_id.endswith("_0")
assert runtime > 0
assert verification_type == VerificationType.FUNCTION_CALL.value
unpickled_data = pickle.loads(return_value_blob)
args, kwargs, return_val = unpickled_data
assert args == (5, 3)
assert kwargs == {}
assert return_val == 8
con.close()
@pytest.mark.asyncio
async def test_behavior_async_exception_handling(self, test_env_setup, temp_db_path):
@codeflash_behavior_async
async def async_divide(a: int, b: int) -> float:
await asyncio.sleep(0.001)
if b == 0:
raise ValueError("Cannot divide by zero")
return a / b
result = await async_divide(10, 2)
assert result == 5.0
with pytest.raises(ValueError, match="Cannot divide by zero"):
await async_divide(10, 0)
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("SELECT * FROM test_results ORDER BY iteration_id")
rows = cur.fetchall()
assert len(rows) == 2
success_row = rows[0]
success_data = pickle.loads(success_row[7]) # return_value_blob
args, kwargs, return_val = success_data
assert args == (10, 2)
assert return_val == 5.0
# Check exception record
exception_row = rows[1]
exception_data = pickle.loads(exception_row[7]) # return_value_blob
assert isinstance(exception_data, ValueError)
assert str(exception_data) == "Cannot divide by zero"
con.close()
@pytest.mark.asyncio
async def test_performance_async_no_database_storage(self, test_env_setup, temp_db_path, capsys):
"""Test performance async decorator doesn't store to database."""
@codeflash_performance_async
async def async_multiply(a: int, b: int) -> int:
"""Async function for performance testing."""
await asyncio.sleep(0.002)
return a * b
result = await async_multiply(4, 7)
assert result == 28
assert not temp_db_path.exists()
captured = capsys.readouterr()
output_lines = captured.out.strip().split('\n')
assert len([line for line in output_lines if "!$######" in line]) == 1
assert len([line for line in output_lines if "!######" in line and "######!" in line]) == 1
closing_tag = [line for line in output_lines if "!######" in line and "######!" in line][0]
assert "async_multiply" in closing_tag
timing_part = closing_tag.split(":")[-1].replace("######!", "")
timing_value = int(timing_part)
assert timing_value > 0 # Should have positive timing
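The stdout tags give the performance path a zero-database channel for timings; a hedged sketch of reading the closing tag, mirroring the string handling in this test (the tag payload shown is a made-up example, not the real field layout):

def parse_closing_tag_ns(line: str) -> int:
    # "!######<test id fields>:<timing ns>######!" is assumed from the assertions above.
    assert line.startswith("!######") and line.endswith("######!")
    return int(line.split(":")[-1].replace("######!", ""))

assert parse_closing_tag_ns("!######tests.test_x:test_perf:async_multiply:0:2100000######!") == 2100000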
@pytest.mark.asyncio
async def test_multiple_calls_indexing(self, test_env_setup, temp_db_path):
@codeflash_behavior_async
async def async_increment(value: int) -> int:
await asyncio.sleep(0.001)
return value + 1
# Call the function multiple times
results = []
for i in range(3):
result = await async_increment(i)
results.append(result)
assert results == [1, 2, 3]
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("SELECT iteration_id, return_value FROM test_results ORDER BY iteration_id")
rows = cur.fetchall()
assert len(rows) == 3
actual_ids = [row[0] for row in rows]
assert len(actual_ids) == 3
base_pattern = actual_ids[0].rsplit('_', 1)[0] # e.g., "async_increment_199"
expected_pattern = [f"{base_pattern}_{i}" for i in range(3)]
assert actual_ids == expected_pattern
for i, (_, return_value_blob) in enumerate(rows):
args, kwargs, return_val = pickle.loads(return_value_blob)
assert args == (i,)
assert return_val == i + 1
con.close()
@pytest.mark.asyncio
async def test_complex_async_function_with_kwargs(self, test_env_setup, temp_db_path):
@codeflash_behavior_async
async def complex_async_func(
pos_arg: str,
*args: int,
keyword_arg: str = "default",
**kwargs: str
) -> dict:
await asyncio.sleep(0.001)
return {
"pos_arg": pos_arg,
"args": args,
"keyword_arg": keyword_arg,
"kwargs": kwargs,
}
result = await complex_async_func(
"hello",
1, 2, 3,
keyword_arg="custom",
extra1="value1",
extra2="value2"
)
expected_result = {
"pos_arg": "hello",
"args": (1, 2, 3),
"keyword_arg": "custom",
"kwargs": {"extra1": "value1", "extra2": "value2"}
}
assert result == expected_result
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("SELECT return_value FROM test_results")
row = cur.fetchone()
stored_args, stored_kwargs, stored_result = pickle.loads(row[0])
assert stored_args == ("hello", 1, 2, 3)
assert stored_kwargs == {"keyword_arg": "custom", "extra1": "value1", "extra2": "value2"}
assert stored_result == expected_result
con.close()
@pytest.mark.asyncio
async def test_database_schema_validation(self, test_env_setup, temp_db_path):
@codeflash_behavior_async
async def schema_test_func() -> str:
return "schema_test"
await schema_test_func()
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("PRAGMA table_info(test_results)")
columns = cur.fetchall()
expected_columns = [
(0, 'test_module_path', 'TEXT', 0, None, 0),
(1, 'test_class_name', 'TEXT', 0, None, 0),
(2, 'test_function_name', 'TEXT', 0, None, 0),
(3, 'function_getting_tested', 'TEXT', 0, None, 0),
(4, 'loop_index', 'INTEGER', 0, None, 0),
(5, 'iteration_id', 'TEXT', 0, None, 0),
(6, 'runtime', 'INTEGER', 0, None, 0),
(7, 'return_value', 'BLOB', 0, None, 0),
(8, 'verification_type', 'TEXT', 0, None, 0)
]
assert columns == expected_columns
con.close()
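For reference, the PRAGMA expectations above correspond to a table of this shape; a sketch of equivalent DDL inferred from the test, not copied from the library:

import sqlite3

con = sqlite3.connect(":memory:")
con.execute(
    "CREATE TABLE test_results ("
    "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, "
    "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, "
    "runtime INTEGER, return_value BLOB, verification_type TEXT)"
)
cols = con.execute("PRAGMA table_info(test_results)").fetchall()
assert [c[1] for c in cols] == [
    "test_module_path", "test_class_name", "test_function_name",
    "function_getting_tested", "loop_index", "iteration_id",
    "runtime", "return_value", "verification_type",
]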

View file

@ -12,7 +12,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize
 from codeflash.models.models import FunctionParent
 from codeflash.optimization.optimizer import Optimizer
 from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
-from codeflash.code_utils.code_extractor import add_global_assignments
+from codeflash.code_utils.code_extractor import add_global_assignments, GlobalAssignmentCollector
 class HelperClass:
@ -1793,9 +1793,10 @@ def get_system_details():
     # Set up the optimizer
     file_path = main_file_path.resolve()
+    project_root = package_dir.resolve()
     opt = Optimizer(
         Namespace(
-            project_root=package_dir.resolve(),
+            project_root=project_root,
             disable_telemetry=True,
             tests_root="tests",
             test_framework="pytest",
@ -1819,6 +1820,8 @@ def get_system_details():
     read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
     hashing_context = code_ctx.hashing_code_context
     # The expected contexts
+    # Resolve both paths to handle symlink issues on macOS
+    relative_path = file_path.relative_to(project_root)
     expected_read_write_context = f"""
 ```python:{main_file_path.resolve().relative_to(opt.args.project_root.resolve())}
 import utility_module
@ -2038,9 +2041,10 @@ def get_system_details():
     # Set up the optimizer
     file_path = main_file_path.resolve()
+    project_root = package_dir.resolve()
     opt = Optimizer(
         Namespace(
-            project_root=package_dir.resolve(),
+            project_root=project_root,
             disable_telemetry=True,
             tests_root="tests",
             test_framework="pytest",
@ -2063,6 +2067,7 @@ def get_system_details():
     code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
     read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
     # The expected contexts
+    relative_path = file_path.relative_to(project_root)
     expected_read_write_context = f"""
 ```python:utility_module.py
 # Function that will be used in the main code
@ -2463,3 +2468,148 @@ def test_circular_deps():
     assert "import ApiClient" not in new_code, "Error: Circular dependency found"
     assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"
def test_global_assignment_collector_with_async_function():
"""Test GlobalAssignmentCollector correctly identifies global assignments outside async functions."""
import libcst as cst
source_code = """
# Global assignment
GLOBAL_VAR = "global_value"
OTHER_GLOBAL = 42
async def async_function():
# This should not be collected (inside async function)
local_var = "local_value"
INNER_ASSIGNMENT = "should_not_be_global"
return local_var
# Another global assignment
ANOTHER_GLOBAL = "another_global"
"""
tree = cst.parse_module(source_code)
collector = GlobalAssignmentCollector()
tree.visit(collector)
# Should collect global assignments but not the ones inside async function
assert len(collector.assignments) == 3
assert "GLOBAL_VAR" in collector.assignments
assert "OTHER_GLOBAL" in collector.assignments
assert "ANOTHER_GLOBAL" in collector.assignments
# Should not collect assignments from inside async function
assert "local_var" not in collector.assignments
assert "INNER_ASSIGNMENT" not in collector.assignments
# Verify assignment order
expected_order = ["GLOBAL_VAR", "OTHER_GLOBAL", "ANOTHER_GLOBAL"]
assert collector.assignment_order == expected_order
def test_global_assignment_collector_nested_async_functions():
"""Test GlobalAssignmentCollector handles nested async functions correctly."""
import libcst as cst
source_code = """
# Global assignment
CONFIG = {"key": "value"}
def sync_function():
# Inside sync function - should not be collected
sync_local = "sync"
async def nested_async():
# Inside nested async function - should not be collected
nested_var = "nested"
return nested_var
return sync_local
async def async_function():
# Inside async function - should not be collected
async_local = "async"
def nested_sync():
# Inside nested function - should not be collected
deeply_nested = "deep"
return deeply_nested
return async_local
# Another global assignment
FINAL_GLOBAL = "final"
"""
tree = cst.parse_module(source_code)
collector = GlobalAssignmentCollector()
tree.visit(collector)
# Should only collect global-level assignments
assert len(collector.assignments) == 2
assert "CONFIG" in collector.assignments
assert "FINAL_GLOBAL" in collector.assignments
# Should not collect any assignments from inside functions
assert "sync_local" not in collector.assignments
assert "nested_var" not in collector.assignments
assert "async_local" not in collector.assignments
assert "deeply_nested" not in collector.assignments
def test_global_assignment_collector_mixed_async_sync_with_classes():
"""Test GlobalAssignmentCollector with async functions, sync functions, and classes."""
import libcst as cst
source_code = """
# Global assignments
GLOBAL_CONSTANT = "constant"
class TestClass:
# Class-level assignment - should not be collected
class_var = "class_value"
def sync_method(self):
# Method assignment - should not be collected
method_var = "method"
return method_var
async def async_method(self):
# Async method assignment - should not be collected
async_method_var = "async_method"
return async_method_var
def sync_function():
# Function assignment - should not be collected
func_var = "function"
return func_var
async def async_function():
# Async function assignment - should not be collected
async_func_var = "async_function"
return async_func_var
# More global assignments
ANOTHER_CONSTANT = 100
FINAL_ASSIGNMENT = {"data": "value"}
"""
tree = cst.parse_module(source_code)
collector = GlobalAssignmentCollector()
tree.visit(collector)
# Should only collect global-level assignments
assert len(collector.assignments) == 3
assert "GLOBAL_CONSTANT" in collector.assignments
assert "ANOTHER_CONSTANT" in collector.assignments
assert "FINAL_ASSIGNMENT" in collector.assignments
# Should not collect assignments from inside any scoped blocks
assert "class_var" not in collector.assignments
assert "method_var" not in collector.assignments
assert "async_method_var" not in collector.assignments
assert "func_var" not in collector.assignments
assert "async_func_var" not in collector.assignments
# Verify correct order
expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"]
assert collector.assignment_order == expected_order
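These three tests pin down the collector's scoping rule: record only module-level assignments, in source order, and never descend into function or class bodies. A minimal libcst visitor with that behavior might look like this (a sketch; the real GlobalAssignmentCollector may track more state, e.g. annotated assignments):

import libcst as cst

class GlobalAssignmentsSketch(cst.CSTVisitor):
    def __init__(self) -> None:
        self.assignments: dict[str, cst.Assign] = {}
        self.assignment_order: list[str] = []

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
        return False  # skip function bodies; async defs are also FunctionDef in libcst

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        return False  # skip class bodies, including methods

    def visit_Assign(self, node: cst.Assign) -> None:
        for target in node.targets:
            if isinstance(target.target, cst.Name):
                name = target.target.value
                if name not in self.assignments:
                    self.assignment_order.append(name)
                self.assignments[name] = node

collector = GlobalAssignmentsSketch()
cst.parse_module("X = 1\ndef f():\n    y = 2\n").visit(collector)
assert collector.assignment_order == ["X"]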

View file

@ -12,6 +12,7 @@ from codeflash.code_utils.code_replacer import (
     is_zero_diff,
     replace_functions_and_add_imports,
     replace_functions_in_file,
+    OptimFunctionCollector,
 )
 from codeflash.discovery.functions_to_optimize import FunctionToOptimize
 from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent
@ -3448,156 +3449,173 @@ def hydrate_input_text_actions_with_field_names(
     assert new_code == expected
-def test_duplicate_global_assignments_when_reverting_helpers():
-    root_dir = Path(__file__).parent.parent.resolve()
-    main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()
-    original_code = '''"""Chunking objects not specific to a particular chunking strategy."""
-from __future__ import annotations
-import collections
-import copy
-from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
-import regex
-from typing_extensions import Self, TypeAlias
-from unstructured.utils import lazyproperty
-from unstructured.documents.elements import Element
-# ================================================================================================
-# MODEL
-# ================================================================================================
-CHUNK_MAX_CHARS_DEFAULT: int = 500
-# ================================================================================================
-# PRE-CHUNKER
-# ================================================================================================
-class PreChunker:
-    """Gathers sequential elements into pre-chunks as length constraints allow.
-    The pre-chunker's responsibilities are:
-    - **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
-      either side of those boundaries into different sections. In this case, the primary indicator
-      of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
-      semantic boundary when `multipage_sections` is `False`.
-    - **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
-      into sections as big as possible without exceeding the chunk window size.
-    - **Minimize chunks that must be split mid-text.** Precompute the text length of each section
-      and only produce a section that exceeds the chunk window size when there is a single element
-      with text longer than that window.
-    A Table element is placed into a section by itself. CheckBox elements are dropped.
-    The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
-    a new "section", hence the "by-title" designation.
-    """
-    def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
-        self._elements = elements
-        self._opts = opts
-    @lazyproperty
-    def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
-        """The semantic-boundary detectors to be applied to break pre-chunks."""
-        return self._opts.boundary_predicates
-    def _is_in_new_semantic_unit(self, element: Element) -> bool:
-        """True when `element` begins a new semantic unit such as a section or page."""
-        # -- all detectors need to be called to update state and avoid double counting
-        # -- boundaries that happen to coincide, like Table and new section on same element.
-        # -- Using `any()` would short-circuit on first True.
-        semantic_boundaries = [pred(element) for pred in self._boundary_predicates]
-        return any(semantic_boundaries)
-'''
-    main_file.write_text(original_code, encoding="utf-8")
-    optim_code = f'''```python:{main_file.relative_to(root_dir)}
-# ================================================================================================
-# PRE-CHUNKER
-# ================================================================================================
-from __future__ import annotations
-from typing import Iterable
-from unstructured.documents.elements import Element
-from unstructured.utils import lazyproperty
-class PreChunker:
-    def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
-        self._elements = elements
-        self._opts = opts
-    @lazyproperty
-    def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
-        """The semantic-boundary detectors to be applied to break pre-chunks."""
-        return self._opts.boundary_predicates
-    def _is_in_new_semantic_unit(self, element: Element) -> bool:
-        """True when `element` begins a new semantic unit such as a section or page."""
-        # Use generator expression for lower memory usage and avoid building intermediate list
-        for pred in self._boundary_predicates:
-            if pred(element):
-                return True
-        return False
-```
-'''
-    func = FunctionToOptimize(function_name="_is_in_new_semantic_unit", parents=[FunctionParent("PreChunker", "ClassDef")], file_path=main_file)
-    test_config = TestConfig(
-        tests_root=root_dir / "tests/pytest",
-        tests_project_rootdir=root_dir,
-        project_root_path=root_dir,
-        test_framework="pytest",
-        pytest_cmd="pytest",
-    )
-    func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
-    code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
-    original_helper_code: dict[Path, str] = {}
-    helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
-    for helper_function_path in helper_function_paths:
-        with helper_function_path.open(encoding="utf8") as f:
-            helper_code = f.read()
-        original_helper_code[helper_function_path] = helper_code
-    func_optimizer.args = Args()
-    func_optimizer.replace_function_and_helpers_with_optimized_code(
-        code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code
-    )
-    new_code = main_file.read_text(encoding="utf-8")
-    main_file.unlink(missing_ok=True)
-    expected = '''"""Chunking objects not specific to a particular chunking strategy."""
-from __future__ import annotations
-import collections
-import copy
-from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
-import regex
-from typing_extensions import Self, TypeAlias
-from unstructured.utils import lazyproperty
-from unstructured.documents.elements import Element
-# ================================================================================================
-# MODEL
-# ================================================================================================
-CHUNK_MAX_CHARS_DEFAULT: int = 500
-# ================================================================================================
-# PRE-CHUNKER
-# ================================================================================================
-class PreChunker:
-    """Gathers sequential elements into pre-chunks as length constraints allow.
-    The pre-chunker's responsibilities are:
-    - **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
-      either side of those boundaries into different sections. In this case, the primary indicator
-      of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
-      semantic boundary when `multipage_sections` is `False`.
-    - **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
-      into sections as big as possible without exceeding the chunk window size.
-    - **Minimize chunks that must be split mid-text.** Precompute the text length of each section
-      and only produce a section that exceeds the chunk window size when there is a single element
-      with text longer than that window.
-    A Table element is placed into a section by itself. CheckBox elements are dropped.
-    The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
-    a new "section", hence the "by-title" designation.
-    """
-    def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
-        self._elements = elements
-        self._opts = opts
-    @lazyproperty
-    def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
-        """The semantic-boundary detectors to be applied to break pre-chunks."""
-        return self._opts.boundary_predicates
-    def _is_in_new_semantic_unit(self, element: Element) -> bool:
-        """True when `element` begins a new semantic unit such as a section or page."""
-        # Use generator expression for lower memory usage and avoid building intermediate list
-        for pred in self._boundary_predicates:
-            if pred(element):
-                return True
-        return False
-'''
-    assert new_code == expected
+# OptimFunctionCollector async function tests
+def test_optim_function_collector_with_async_functions():
+    """Test OptimFunctionCollector correctly collects async functions."""
+    import libcst as cst
+
+    source_code = """
+def sync_function():
+    return "sync"
+
+async def async_function():
+    return "async"
+
+class TestClass:
+    def sync_method(self):
+        return "sync_method"
+
+    async def async_method(self):
+        return "async_method"
+"""
+    tree = cst.parse_module(source_code)
+    collector = OptimFunctionCollector(
+        function_names={(None, "sync_function"), (None, "async_function"), ("TestClass", "sync_method"), ("TestClass", "async_method")},
+        preexisting_objects=None
+    )
+    tree.visit(collector)
+
+    # Should collect both sync and async functions
+    assert len(collector.modified_functions) == 4
+    assert (None, "sync_function") in collector.modified_functions
+    assert (None, "async_function") in collector.modified_functions
+    assert ("TestClass", "sync_method") in collector.modified_functions
+    assert ("TestClass", "async_method") in collector.modified_functions
+
+def test_optim_function_collector_new_async_functions():
+    """Test OptimFunctionCollector identifies new async functions not in preexisting objects."""
+    import libcst as cst
+
+    source_code = """
+def existing_function():
+    return "existing"
+
+async def new_async_function():
+    return "new_async"
+
+def new_sync_function():
+    return "new_sync"
+
+class ExistingClass:
+    async def new_class_async_method(self):
+        return "new_class_async"
+"""
+    # Only existing_function is in preexisting objects
+    preexisting_objects = {("existing_function", ())}
+    tree = cst.parse_module(source_code)
+    collector = OptimFunctionCollector(
+        function_names=set(),  # Not looking for specific functions
+        preexisting_objects=preexisting_objects
+    )
+    tree.visit(collector)
+
+    # Should identify new functions (both sync and async)
+    assert len(collector.new_functions) == 2
+    function_names = [func.name.value for func in collector.new_functions]
+    assert "new_async_function" in function_names
+    assert "new_sync_function" in function_names
+
+    # Should identify new class methods
+    assert "ExistingClass" in collector.new_class_functions
+    assert len(collector.new_class_functions["ExistingClass"]) == 1
+    assert collector.new_class_functions["ExistingClass"][0].name.value == "new_class_async_method"
+
+def test_optim_function_collector_mixed_scenarios():
+    """Test OptimFunctionCollector with complex mix of sync/async functions and classes."""
+    import libcst as cst
+
+    source_code = """
+# Global functions
+def global_sync():
+    pass
+
+async def global_async():
+    pass
+
+class ParentClass:
+    def __init__(self):
+        pass
+
+    def sync_method(self):
+        pass
+
+    async def async_method(self):
+        pass
+
+class ChildClass:
+    async def child_async_method(self):
+        pass
+
+    def child_sync_method(self):
+        pass
+"""
+    # Looking for specific functions
+    function_names = {
+        (None, "global_sync"),
+        (None, "global_async"),
+        ("ParentClass", "sync_method"),
+        ("ParentClass", "async_method"),
+        ("ChildClass", "child_async_method")
+    }
+    tree = cst.parse_module(source_code)
+    collector = OptimFunctionCollector(
+        function_names=function_names,
+        preexisting_objects=None
+    )
+    tree.visit(collector)
+
+    # Should collect all specified functions (mix of sync and async)
+    assert len(collector.modified_functions) == 5
+    assert (None, "global_sync") in collector.modified_functions
+    assert (None, "global_async") in collector.modified_functions
+    assert ("ParentClass", "sync_method") in collector.modified_functions
+    assert ("ParentClass", "async_method") in collector.modified_functions
+    assert ("ChildClass", "child_async_method") in collector.modified_functions
+
+    # Should collect __init__ method
+    assert "ParentClass" in collector.modified_init_functions
+
+def test_is_zero_diff_async_sleep():
+    original_code = '''
+import time
+async def task():
+    time.sleep(1)
+    return "done"
+'''
+    optimized_code = '''
+import asyncio
+async def task():
+    await asyncio.sleep(1)
+    return "done"
+'''
+    assert not is_zero_diff(original_code, optimized_code)
+
+def test_is_zero_diff_with_equivalent_code():
+    original_code = '''
+import asyncio
+async def task():
+    await asyncio.sleep(1)
+    return "done"
+'''
+    optimized_code = '''
+import asyncio
+async def task():
+    """A task that does something."""
+    await asyncio.sleep(1)
+    return "done"
+'''
+    assert is_zero_diff(original_code, optimized_code)
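This pair of tests pins the intended semantics of is_zero_diff: a docstring-only change is a zero diff, while time.sleep -> await asyncio.sleep is not. A sketch of one way to get that behavior, normalizing both sources to a docstring-free AST dump before comparing (an assumption, not the library's actual implementation):

import ast

def _strip_docstrings(tree: ast.AST) -> ast.AST:
    for node in ast.walk(tree):
        if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            body = node.body
            if (body and isinstance(body[0], ast.Expr)
                    and isinstance(body[0].value, ast.Constant)
                    and isinstance(body[0].value.value, str)):
                node.body = body[1:] or [ast.Pass()]
    return tree

def is_zero_diff_sketch(original: str, optimized: str) -> bool:
    # Comments never reach the AST, and docstrings are stripped above, so
    # only semantic changes (e.g. time.sleep -> await asyncio.sleep) differ.
    return ast.dump(_strip_docstrings(ast.parse(original))) == ast.dump(_strip_docstrings(ast.parse(optimized)))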

View file

@ -17,10 +17,10 @@ from codeflash.code_utils.code_utils import (
     is_class_defined_in_file,
     module_name_from_file_path,
     path_belongs_to_site_packages,
-    has_any_async_functions,
+    validate_python_code,
 )
 from codeflash.code_utils.concolic_utils import clean_concolic_tests
-from codeflash.code_utils.coverage_utils import generate_candidates, prepare_coverage_files
+from codeflash.code_utils.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files
 @pytest.fixture
@ -368,6 +368,86 @@ def my_function():
     assert is_class_defined_in_file("MyClass", test_file) is False
@pytest.fixture
def mock_code_context():
"""Mock CodeOptimizationContext for testing extract_dependent_function."""
from unittest.mock import MagicMock
from codeflash.models.models import CodeOptimizationContext
context = MagicMock(spec=CodeOptimizationContext)
context.preexisting_objects = []
return context
def test_extract_dependent_function_sync_and_async(mock_code_context):
"""Test extract_dependent_function with both sync and async functions."""
# Test sync function extraction
mock_code_context.testgen_context_code = """
def main_function():
pass
def helper_function():
pass
"""
assert extract_dependent_function("main_function", mock_code_context) == "helper_function"
# Test async function extraction
mock_code_context.testgen_context_code = """
def main_function():
pass
async def async_helper_function():
pass
"""
assert extract_dependent_function("main_function", mock_code_context) == "async_helper_function"
def test_extract_dependent_function_edge_cases(mock_code_context):
"""Test extract_dependent_function edge cases."""
# No dependent functions
mock_code_context.testgen_context_code = """
def main_function():
pass
"""
assert extract_dependent_function("main_function", mock_code_context) is False
# Multiple dependent functions
mock_code_context.testgen_context_code = """
def main_function():
pass
def helper1():
pass
async def helper2():
pass
"""
assert extract_dependent_function("main_function", mock_code_context) is False
def test_extract_dependent_function_mixed_scenarios(mock_code_context):
"""Test extract_dependent_function with mixed sync/async scenarios."""
# Async main with sync helper
mock_code_context.testgen_context_code = """
async def async_main():
pass
def sync_helper():
pass
"""
assert extract_dependent_function("async_main", mock_code_context) == "sync_helper"
# Only async functions
mock_code_context.testgen_context_code = """
async def async_main():
pass
async def async_helper():
pass
"""
assert extract_dependent_function("async_main", mock_code_context) == "async_helper"
 def test_is_class_defined_in_file_with_non_existing_file() -> None:
     non_existing_file = Path("/non/existing/file.py")
@ -505,25 +585,41 @@ def test_Grammar_copy():
     assert cleaned_code == expected_cleaned_code.strip()
-def test_has_any_async_functions_with_async_code() -> None:
-    code = """
-def normal_function():
-    pass
-
-async def async_function():
-    pass
-
-class MyClass:
-    def __init__(self):
-        self.value = 42
-"""
-    result = has_any_async_functions(code)
-    assert result is True
-
-def test_has_any_async_functions_without_async_code() -> None:
-    code = """
-def normal_function():
-    pass
-
-def another_function():
-    pass
-"""
-    result = has_any_async_functions(code)
-    assert result is False
+def test_validate_python_code_valid() -> None:
+    code = "def hello():\n return 'world'"
+    result = validate_python_code(code)
+    assert result == code
+
+def test_validate_python_code_invalid() -> None:
+    code = "def hello(:\n return 'world'"
+    with pytest.raises(ValueError, match="Invalid Python code"):
+        validate_python_code(code)
+
+def test_validate_python_code_empty() -> None:
+    code = ""
+    result = validate_python_code(code)
+    assert result == code
+
+def test_validate_python_code_complex_invalid() -> None:
+    code = "if True\n print('missing colon')"
+    with pytest.raises(ValueError, match="Invalid Python code.*line 1.*column 8"):
+        validate_python_code(code)
+
+def test_validate_python_code_valid_complex() -> None:
+    code = """
+def calculate(a, b):
+    if a > b:
+        return a + b
+    else:
+        return a * b
+"""
+    result = validate_python_code(code)
+    assert result == code
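The validate_python_code expectations (empty input passes, syntax errors raise ValueError with line and column) can be met by a thin wrapper over the parser; a minimal sketch, assuming the real function's message format merely has to match the regexes above:

import ast

def validate_python_code_sketch(code: str) -> str:
    # An empty string parses fine; only genuine syntax errors raise.
    try:
        ast.parse(code)
    except SyntaxError as e:
        raise ValueError(f"Invalid Python code at line {e.lineno}, column {e.offset}: {e.msg}") from e
    return code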

View file

@ -14,7 +14,13 @@ from codeflash.models.models import (
     TestResults,
     TestType,
 )
-from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
+from codeflash.result.critic import (
+    coverage_critic,
+    performance_gain,
+    quantity_of_tests_critic,
+    speedup_critic,
+    throughput_gain,
+)
 def test_performance_gain() -> None:
@ -429,3 +435,159 @@ def test_coverage_critic() -> None:
     )
     assert coverage_critic(unittest_coverage, "unittest") is True
def test_throughput_gain() -> None:
"""Test throughput_gain calculation."""
# Test basic throughput improvement
assert throughput_gain(original_throughput=100, optimized_throughput=150) == 0.5 # 50% improvement
# Test no improvement
assert throughput_gain(original_throughput=100, optimized_throughput=100) == 0.0
# Test regression
assert throughput_gain(original_throughput=100, optimized_throughput=80) == -0.2 # 20% regression
# Test zero original throughput (edge case)
assert throughput_gain(original_throughput=0, optimized_throughput=50) == 0.0
# Test large improvement
assert throughput_gain(original_throughput=50, optimized_throughput=200) == 3.0 # 300% improvement
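The expected values encode a plain relative-gain formula with a zero guard; as a sketch:

def throughput_gain_sketch(*, original_throughput: int, optimized_throughput: int) -> float:
    if original_throughput == 0:
        return 0.0  # guard: no meaningful baseline to compare against
    return (optimized_throughput - original_throughput) / original_throughput

assert throughput_gain_sketch(original_throughput=100, optimized_throughput=80) == -0.2
assert throughput_gain_sketch(original_throughput=50, optimized_throughput=200) == 3.0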
def test_speedup_critic_with_async_throughput() -> None:
"""Test speedup_critic with async throughput evaluation."""
original_code_runtime = 10000 # 10 microseconds
original_async_throughput = 100
# Test case 1: Both runtime and throughput improve significantly
candidate_result = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=8000, # 20% runtime improvement
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=8000,
async_throughput=120, # 20% throughput improvement
)
assert speedup_critic(
candidate_result=candidate_result,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
disable_gh_action_noise=True
)
# Test case 2: Runtime improves significantly, throughput doesn't meet threshold (should pass)
candidate_result = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=8000, # 20% runtime improvement
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=8000,
async_throughput=105, # Only 5% throughput improvement (below 10% threshold)
)
assert speedup_critic(
candidate_result=candidate_result,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
disable_gh_action_noise=True
)
# Test case 3: Throughput improves significantly, runtime doesn't meet threshold (should pass)
candidate_result = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=9800, # Only 2% runtime improvement (below 5% threshold)
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=9800,
async_throughput=120, # 20% throughput improvement
)
assert speedup_critic(
candidate_result=candidate_result,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
disable_gh_action_noise=True
)
# Test case 4: No throughput data - should fall back to runtime-only evaluation
candidate_result = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=8000, # 20% runtime improvement
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=8000,
async_throughput=None, # No throughput data
)
assert speedup_critic(
candidate_result=candidate_result,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=None, # No original throughput data
best_throughput_until_now=None,
disable_gh_action_noise=True
)
# Test case 5: Test best_throughput_until_now comparison
candidate_result = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=8000, # 20% runtime improvement
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=8000,
async_throughput=115, # 15% throughput improvement
)
# Should pass when no best throughput yet
assert speedup_critic(
candidate_result=candidate_result,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
disable_gh_action_noise=True
)
# Should fail when there's a better throughput already
assert not speedup_critic(
candidate_result=candidate_result,
original_code_runtime=original_code_runtime,
best_runtime_until_now=7000, # Better runtime already exists
original_async_throughput=original_async_throughput,
best_throughput_until_now=120, # Better throughput already exists
disable_gh_action_noise=True
)
# Test case 6: Zero original throughput (edge case)
candidate_result = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=8000, # 20% runtime improvement
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=8000,
async_throughput=50,
)
# Should pass when original throughput is 0 (throughput evaluation skipped)
assert speedup_critic(
candidate_result=candidate_result,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=0, # Zero original throughput
best_throughput_until_now=None,
disable_gh_action_noise=True
)
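Taken together, these cases imply an either/or gate: a candidate passes if runtime improves past its noise floor or async throughput improves past its own, higher threshold, and the throughput check is skipped when data is missing or the baseline is zero. A condensed sketch of that decision, with the 5%/10% thresholds taken from the comments above (assumed, not the library's constants) and the best-so-far comparisons omitted:

from typing import Optional

def passes_noise_gate_sketch(
    candidate_runtime: int,
    original_runtime: int,
    candidate_throughput: Optional[int],
    original_throughput: Optional[int],
) -> bool:
    runtime_ok = (original_runtime - candidate_runtime) / original_runtime > 0.05
    if not candidate_throughput or not original_throughput:
        return runtime_ok  # no throughput data (or zero baseline): runtime-only
    throughput_ok = (candidate_throughput - original_throughput) / original_throughput > 0.10
    return runtime_ok or throughput_ok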

View file

@ -0,0 +1,793 @@
import tempfile
from pathlib import Path
import uuid
import os
import pytest
from codeflash.code_utils.instrument_existing_tests import (
add_async_decorator_to_function,
inject_profiling_into_existing_test,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, TestingMode
@pytest.fixture
def temp_dir():
"""Create a temporary directory for test files."""
with tempfile.TemporaryDirectory() as temp:
yield Path(temp)
# @pytest.fixture
# def unique_test_iteration():
# """Provide a unique test iteration ID and clean up database after test."""
# # Generate unique iteration ID
# iteration_id = str(uuid.uuid4())[:8]
# # Store original environment variable
# original_iteration = os.environ.get("CODEFLASH_TEST_ITERATION")
# # Set unique iteration for this test
# os.environ["CODEFLASH_TEST_ITERATION"] = iteration_id
# try:
# yield iteration_id
# finally:
# # Cleanup: restore original environment and delete database file
# if original_iteration is not None:
# os.environ["CODEFLASH_TEST_ITERATION"] = original_iteration
# elif "CODEFLASH_TEST_ITERATION" in os.environ:
# del os.environ["CODEFLASH_TEST_ITERATION"]
# # Clean up database file
# try:
# from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file
# db_path = get_run_tmp_file(Path(f"test_return_values_{iteration_id}.sqlite"))
# if db_path.exists():
# db_path.unlink()
# except Exception:
# pass # Ignore cleanup errors
def test_async_decorator_application_behavior_mode():
async_function_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
@codeflash_behavior_async
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
func = FunctionToOptimize(
function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
)
modified_code, decorator_added = add_async_decorator_to_function(async_function_code, func, TestingMode.BEHAVIOR)
assert decorator_added
assert modified_code.strip() == expected_decorated_code.strip()
def test_async_decorator_application_performance_mode():
async_function_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_performance_async
@codeflash_performance_async
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
func = FunctionToOptimize(
function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
)
modified_code, decorator_added = add_async_decorator_to_function(async_function_code, func, TestingMode.PERFORMANCE)
assert decorator_added
assert modified_code.strip() == expected_decorated_code.strip()
def test_async_class_method_decorator_application():
async_class_code = '''
import asyncio
class Calculator:
"""Test class with async methods."""
async def async_method(self, a: int, b: int) -> int:
"""Async method in class."""
await asyncio.sleep(0.005)
return a ** b
def sync_method(self, a: int, b: int) -> int:
"""Sync method in class."""
return a - b
'''
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
class Calculator:
"""Test class with async methods."""
@codeflash_behavior_async
async def async_method(self, a: int, b: int) -> int:
"""Async method in class."""
await asyncio.sleep(0.005)
return a ** b
def sync_method(self, a: int, b: int) -> int:
"""Sync method in class."""
return a - b
'''
func = FunctionToOptimize(
function_name="async_method",
file_path=Path("test_async.py"),
parents=[{"name": "Calculator", "type": "ClassDef"}],
is_async=True,
)
modified_code, decorator_added = add_async_decorator_to_function(async_class_code, func, TestingMode.BEHAVIOR)
assert decorator_added
assert modified_code.strip() == expected_decorated_code.strip()
def test_async_decorator_no_duplicate_application():
already_decorated_code = '''
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async
import asyncio
@codeflash_behavior_async
async def async_function(x: int, y: int) -> int:
"""Already decorated async function."""
await asyncio.sleep(0.01)
return x * y
'''
expected_reformatted_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
@codeflash_behavior_async
async def async_function(x: int, y: int) -> int:
"""Already decorated async function."""
await asyncio.sleep(0.01)
return x * y
'''
func = FunctionToOptimize(
function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
)
modified_code, decorator_added = add_async_decorator_to_function(already_decorated_code, func, TestingMode.BEHAVIOR)
assert not decorator_added
assert modified_code.strip() == expected_reformatted_code.strip()
def test_inject_profiling_async_function_behavior_mode(temp_dir):
source_module_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
source_file = temp_dir / "my_module.py"
source_file.write_text(source_module_code)
async_test_code = '''
import asyncio
import pytest
from my_module import async_function
@pytest.mark.asyncio
async def test_async_function():
"""Test async function behavior."""
result = await async_function(5, 3)
assert result == 15
result2 = await async_function(2, 4)
assert result2 == 8
'''
test_file = temp_dir / "test_async.py"
test_file.write_text(async_test_code)
func = FunctionToOptimize(function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True)
# First instrument the source module
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
source_success, instrumented_source = instrument_source_module_with_async_decorators(
source_file, func, TestingMode.BEHAVIOR
)
assert source_success is True
assert instrumented_source is not None
assert "@codeflash_behavior_async" in instrumented_source
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
assert "codeflash_behavior_async" in instrumented_source
source_file.write_text(instrumented_source)
success, instrumented_test_code = inject_profiling_into_existing_test(
test_file, [CodePosition(8, 18), CodePosition(11, 19)], func, temp_dir, "pytest", mode=TestingMode.BEHAVIOR
)
# For async functions, once source is decorated, test injection should fail
# This is expected behavior - async instrumentation happens at the decorator level
assert success is False
assert instrumented_test_code is None
def test_inject_profiling_async_function_performance_mode(temp_dir):
source_module_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
source_file = temp_dir / "my_module.py"
source_file.write_text(source_module_code)
# Create the test file
async_test_code = '''
import asyncio
import pytest
from my_module import async_function
@pytest.mark.asyncio
async def test_async_function():
"""Test async function performance."""
result = await async_function(5, 3)
assert result == 15
'''
test_file = temp_dir / "test_async.py"
test_file.write_text(async_test_code)
func = FunctionToOptimize(function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True)
# First instrument the source module
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
source_success, instrumented_source = instrument_source_module_with_async_decorators(
source_file, func, TestingMode.PERFORMANCE
)
assert source_success is True
assert instrumented_source is not None
assert "@codeflash_performance_async" in instrumented_source
# Check for the import with line continuation formatting
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
assert "codeflash_performance_async" in instrumented_source
# Write the instrumented source back
source_file.write_text(instrumented_source)
# Now test the full pipeline with source module path
success, instrumented_test_code = inject_profiling_into_existing_test(
test_file, [CodePosition(8, 18)], func, temp_dir, "pytest", mode=TestingMode.PERFORMANCE
)
# For async functions, once source is decorated, test injection should fail
# This is expected behavior - async instrumentation happens at the decorator level
assert success is False
assert instrumented_test_code is None
def test_mixed_sync_async_instrumentation(temp_dir):
source_module_code = '''
import asyncio
def sync_function(x: int, y: int) -> int:
"""Regular sync function."""
return x * y
async def async_function(x: int, y: int) -> int:
"""Simple async function."""
await asyncio.sleep(0.01)
return x * y
'''
source_file = temp_dir / "my_module.py"
source_file.write_text(source_module_code)
mixed_test_code = '''
import asyncio
import pytest
from my_module import sync_function, async_function
@pytest.mark.asyncio
async def test_mixed_functions():
"""Test both sync and async functions."""
sync_result = sync_function(10, 5)
assert sync_result == 50
async_result = await async_function(3, 4)
assert async_result == 12
'''
test_file = temp_dir / "test_mixed.py"
test_file.write_text(mixed_test_code)
async_func = FunctionToOptimize(
function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True
)
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
source_success, instrumented_source = instrument_source_module_with_async_decorators(
source_file, async_func, TestingMode.BEHAVIOR
)
assert source_success
assert instrumented_source is not None
assert "@codeflash_behavior_async" in instrumented_source
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
assert "codeflash_behavior_async" in instrumented_source
# Sync function should remain unchanged
assert "def sync_function(x: int, y: int) -> int:" in instrumented_source
# Write instrumented source back
source_file.write_text(instrumented_source)
success, instrumented_test_code = inject_profiling_into_existing_test(
test_file,
[CodePosition(8, 18), CodePosition(11, 19)],
async_func,
temp_dir,
"pytest",
mode=TestingMode.BEHAVIOR,
)
# Async functions should not be instrumented at the test level
assert not success
assert instrumented_test_code is None
def test_async_function_qualified_name_handling():
nested_async_code = '''
import asyncio
class OuterClass:
class InnerClass:
async def nested_async_method(self, x: int) -> int:
"""Nested async method."""
await asyncio.sleep(0.001)
return x * 2
'''
func = FunctionToOptimize(
function_name="nested_async_method",
file_path=Path("test_nested.py"),
parents=[{"name": "OuterClass", "type": "ClassDef"}, {"name": "InnerClass", "type": "ClassDef"}],
is_async=True,
)
modified_code, decorator_added = add_async_decorator_to_function(nested_async_code, func, TestingMode.BEHAVIOR)
expected_output = (
"""import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
class OuterClass:
class InnerClass:
@codeflash_behavior_async
async def nested_async_method(self, x: int) -> int:
\"\"\"Nested async method.\"\"\"
await asyncio.sleep(0.001)
return x * 2
"""
)
assert modified_code.strip() == expected_output.strip()
def test_async_decorator_with_existing_decorators():
"""Test async decorator application when function already has other decorators."""
decorated_async_code = '''
import asyncio
from functools import wraps
def my_decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
return await func(*args, **kwargs)
return wrapper
@my_decorator
async def async_function(x: int, y: int) -> int:
"""Async function with existing decorator."""
await asyncio.sleep(0.01)
return x * y
'''
func = FunctionToOptimize(
function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
)
modified_code, decorator_added = add_async_decorator_to_function(decorated_async_code, func, TestingMode.BEHAVIOR)
assert decorator_added
# Should add codeflash decorator above existing decorators
assert "@codeflash_behavior_async" in modified_code
assert "@my_decorator" in modified_code
# Codeflash decorator should come first
codeflash_pos = modified_code.find("@codeflash_behavior_async")
my_decorator_pos = modified_code.find("@my_decorator")
assert codeflash_pos < my_decorator_pos
def test_sync_function_not_affected_by_async_logic():
sync_function_code = '''
def sync_function(x: int, y: int) -> int:
"""Regular sync function."""
return x + y
'''
sync_func = FunctionToOptimize(
function_name="sync_function",
file_path=Path("test_sync.py"),
parents=[],
is_async=False,
)
modified_code, decorator_added = add_async_decorator_to_function(
sync_function_code, sync_func, TestingMode.BEHAVIOR
)
assert not decorator_added
assert modified_code == sync_function_code
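The no-op for sync functions suggests add_async_decorator_to_function gates on the is_async flag before doing any CST work. A minimal sketch of that guard plus decorator insertion for the simple top-level case, with the decorator prepended so it sits above existing decorators as the earlier test requires (the real helper also handles class nesting and import insertion):

import libcst as cst

def add_async_decorator_sketch(code: str, func, decorator_name: str) -> tuple[str, bool]:
    if not getattr(func, "is_async", False):
        return code, False  # sync functions pass through untouched

    class _AddDecorator(cst.CSTTransformer):
        added = False

        def leave_FunctionDef(self, original, updated):
            if updated.name.value == func.function_name and updated.asynchronous is not None:
                existing = [d.decorator for d in updated.decorators]
                if any(isinstance(d, cst.Name) and d.value == decorator_name for d in existing):
                    return updated  # already decorated; don't duplicate
                self.added = True
                new_dec = cst.Decorator(decorator=cst.Name(decorator_name))
                return updated.with_changes(decorators=[new_dec, *updated.decorators])
            return updated

    transformer = _AddDecorator()
    new_tree = cst.parse_module(code).visit(transformer)
    return new_tree.code, transformer.added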
def test_inject_profiling_async_multiple_calls_same_test(temp_dir):
"""Test that multiple async function calls within the same test function get correctly numbered 0, 1, 2, etc."""
source_module_code = '''
import asyncio
async def async_sorter(items):
"""Simple async sorter for testing."""
await asyncio.sleep(0.001)
return sorted(items)
'''
source_file = temp_dir / "async_sorter.py"
source_file.write_text(source_module_code)
test_code_multiple_calls = """
import asyncio
import pytest
from async_sorter import async_sorter
@pytest.mark.asyncio
async def test_single_call():
result = await async_sorter([42])
assert result == [42]
@pytest.mark.asyncio
async def test_multiple_calls():
result1 = await async_sorter([3, 1, 2])
result2 = await async_sorter([5, 4])
result3 = await async_sorter([9, 8, 7, 6])
assert result1 == [1, 2, 3]
assert result2 == [4, 5]
assert result3 == [6, 7, 8, 9]
"""
test_file = temp_dir / "test_async_sorter.py"
test_file.write_text(test_code_multiple_calls)
func = FunctionToOptimize(
function_name="async_sorter", parents=[], file_path=Path("async_sorter.py"), is_async=True
)
# First instrument the source module with async decorators
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
source_success, instrumented_source = instrument_source_module_with_async_decorators(
source_file, func, TestingMode.BEHAVIOR
)
assert source_success
assert instrumented_source is not None
assert "@codeflash_behavior_async" in instrumented_source
source_file.write_text(instrumented_source)
import ast
tree = ast.parse(test_code_multiple_calls)
call_positions = []
for node in ast.walk(tree):
if isinstance(node, ast.Await) and isinstance(node.value, ast.Call):
if (hasattr(node.value.func, "id") and node.value.func.id == "async_sorter") or (
hasattr(node.value.func, "attr") and node.value.func.attr == "async_sorter"
):
call_positions.append(CodePosition(node.lineno, node.col_offset))
assert len(call_positions) == 4
success, instrumented_test_code = inject_profiling_into_existing_test(
test_file, call_positions, func, temp_dir, "pytest", mode=TestingMode.BEHAVIOR
)
assert success
assert instrumented_test_code is not None
assert "os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'" in instrumented_test_code
# Count occurrences of each line_id to verify numbering
line_id_0_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'")
line_id_1_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '1'")
line_id_2_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '2'")
assert line_id_0_count == 2, f"Expected 2 occurrences of line_id '0', got {line_id_0_count}"
assert line_id_1_count == 1, f"Expected 1 occurrence of line_id '1', got {line_id_1_count}"
assert line_id_2_count == 1, f"Expected 1 occurrence of line_id '2', got {line_id_2_count}"
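The counts asserted above follow from line IDs being assigned per test function and restarting at 0 in each: `test_single_call` contributes one call with ID 0, and `test_multiple_calls` contributes IDs 0, 1, and 2. A rough sketch of that numbering rule, as a hypothetical helper rather than the instrumenter's actual code:

```python
import ast


def per_test_line_ids(test_source: str, target: str) -> dict[str, list[int]]:
    """Assign 0-based IDs to awaited `target` calls, restarting per test function."""
    ids: dict[str, list[int]] = {}
    for node in ast.parse(test_source).body:
        if isinstance(node, ast.AsyncFunctionDef) and node.name.startswith("test_"):
            counter = 0
            for sub in ast.walk(node):
                if (
                    isinstance(sub, ast.Await)
                    and isinstance(sub.value, ast.Call)
                    and getattr(sub.value.func, "id", None) == target
                ):
                    ids.setdefault(node.name, []).append(counter)
                    counter += 1
    return ids
```

Running `per_test_line_ids(test_code_multiple_calls, "async_sorter")` on the source above would yield `{"test_single_call": [0], "test_multiple_calls": [0, 1, 2]}`, matching the asserted occurrence counts.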
def test_async_behavior_decorator_return_values_and_test_ids():
"""Test that async behavior decorator correctly captures return values, test IDs, and stores data in database."""
import asyncio
import os
import sqlite3
from pathlib import Path
import dill as pickle
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async
@codeflash_behavior_async
async def test_async_multiply(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.001) # Small delay to simulate async work
return x * y
test_env = {
"CODEFLASH_TEST_MODULE": "test_module",
"CODEFLASH_TEST_CLASS": None,
"CODEFLASH_TEST_FUNCTION": "test_async_multiply_function",
"CODEFLASH_CURRENT_LINE_ID": "0",
"CODEFLASH_LOOP_INDEX": "1",
"CODEFLASH_TEST_ITERATION": "2",
}
original_env = {k: os.environ.get(k) for k in test_env}
for k, v in test_env.items():
if v is not None:
os.environ[k] = v
elif k in os.environ:
del os.environ[k]
try:
result = asyncio.run(test_async_multiply(6, 7))
assert result == 42, f"Expected return value 42, got {result}"
from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file
db_path = get_run_tmp_file(Path("test_return_values_2.sqlite"))
# Verify database exists and has data
assert db_path.exists(), f"Database file not created at {db_path}"
# Read and verify database contents
con = sqlite3.connect(db_path)
cur = con.cursor()
cur.execute("SELECT * FROM test_results")
rows = cur.fetchall()
assert len(rows) == 1, f"Expected 1 database row, got {len(rows)}"
row = rows[0]
(
test_module,
test_class,
test_function,
function_name,
loop_index,
iteration_id,
runtime,
return_value_blob,
verification_type,
) = row
assert test_module == "test_module", f"Expected test_module 'test_module', got '{test_module}'"
assert test_class is None, f"Expected test_class None, got '{test_class}'"
assert test_function == "test_async_multiply_function", (
f"Expected test_function 'test_async_multiply_function', got '{test_function}'"
)
assert function_name == "test_async_multiply", (
f"Expected function_name 'test_async_multiply', got '{function_name}'"
)
assert loop_index == 1, f"Expected loop_index 1, got {loop_index}"
assert iteration_id == "0_0", f"Expected iteration_id '0_0', got '{iteration_id}'"
assert verification_type == "function_call", (
f"Expected verification_type 'function_call', got '{verification_type}'"
)
unpickled_data = pickle.loads(return_value_blob)
args, kwargs, actual_return_value = unpickled_data
assert args == (6, 7), f"Expected args (6, 7), got {args}"
assert kwargs == {}, f"Expected empty kwargs, got {kwargs}"
assert actual_return_value == 42, f"Expected stored return value 42, got {actual_return_value}"
con.close()
finally:
for k, v in original_env.items():
if v is not None:
os.environ[k] = v
elif k in os.environ:
del os.environ[k]
def test_async_decorator_comprehensive_return_values_and_test_ids():
import asyncio
import os
import sqlite3
from pathlib import Path
import dill as pickle
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async, get_run_tmp_file
@codeflash_behavior_async
async def async_multiply_add(x: int, y: int, z: int = 1) -> int:
"""Async function that multiplies x*y then adds z."""
await asyncio.sleep(0.001)
result = (x * y) + z
return result
test_env = {
"CODEFLASH_TEST_MODULE": "test_comprehensive_module",
"CODEFLASH_TEST_CLASS": "AsyncTestClass",
"CODEFLASH_TEST_FUNCTION": "test_comprehensive_async_function",
"CODEFLASH_CURRENT_LINE_ID": "3",
"CODEFLASH_LOOP_INDEX": "2",
"CODEFLASH_TEST_ITERATION": "3",
}
original_env = {k: os.environ.get(k) for k in test_env}
for k, v in test_env.items():
if v is not None:
os.environ[k] = v
elif k in os.environ:
del os.environ[k]
try:
test_cases = [
{"args": (5, 3), "kwargs": {}, "expected": 16}, # (5 * 3) + 1 = 16
{"args": (2, 4), "kwargs": {"z": 10}, "expected": 18}, # (2 * 4) + 10 = 18
{"args": (7, 6), "kwargs": {}, "expected": 43}, # (7 * 6) + 1 = 43
]
results = []
for test_case in test_cases:
result = asyncio.run(async_multiply_add(*test_case["args"], **test_case["kwargs"]))
results.append(result)
# Verify each return value is exactly correct
assert result == test_case["expected"], (
f"Expected {test_case['expected']}, got {result} for args {test_case['args']}, kwargs {test_case['kwargs']}"
)
db_path = get_run_tmp_file(Path("test_return_values_3.sqlite"))
assert db_path.exists(), f"Database not created at {db_path}"
con = sqlite3.connect(db_path)
cur = con.cursor()
cur.execute(
"SELECT test_module_path, test_class_name, test_function_name, function_getting_tested, loop_index, iteration_id, runtime, return_value, verification_type FROM test_results ORDER BY rowid"
)
rows = cur.fetchall()
assert len(rows) == 3, f"Expected 3 database rows, got {len(rows)}"
for i, (
test_module,
test_class,
test_function,
function_name,
loop_index,
iteration_id,
runtime,
return_value_blob,
verification_type,
) in enumerate(rows):
assert test_module == "test_comprehensive_module", (
f"Row {i}: Expected test_module 'test_comprehensive_module', got '{test_module}'"
)
assert test_class == "AsyncTestClass", f"Row {i}: Expected test_class 'AsyncTestClass', got '{test_class}'"
assert test_function == "test_comprehensive_async_function", (
f"Row {i}: Expected test_function 'test_comprehensive_async_function', got '{test_function}'"
)
assert function_name == "async_multiply_add", (
f"Row {i}: Expected function_name 'async_multiply_add', got '{function_name}'"
)
assert loop_index == 2, f"Row {i}: Expected loop_index 2, got {loop_index}"
assert verification_type == "function_call", (
f"Row {i}: Expected verification_type 'function_call', got '{verification_type}'"
)
expected_iteration_id = f"3_{i}"
assert iteration_id == expected_iteration_id, (
f"Row {i}: Expected iteration_id '{expected_iteration_id}', got '{iteration_id}'"
)
args, kwargs, actual_return_value = pickle.loads(return_value_blob)
expected_args = test_cases[i]["args"]
expected_kwargs = test_cases[i]["kwargs"]
expected_return = test_cases[i]["expected"]
assert args == expected_args, f"Row {i}: Expected args {expected_args}, got {args}"
assert kwargs == expected_kwargs, f"Row {i}: Expected kwargs {expected_kwargs}, got {kwargs}"
assert actual_return_value == expected_return, (
f"Row {i}: Expected return value {expected_return}, got {actual_return_value}"
)
con.close()
finally:
for k, v in original_env.items():
if v is not None:
os.environ[k] = v
elif k in os.environ:
del os.environ[k]
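As the two tests above exercise it, the iteration ID appears to combine the `CODEFLASH_CURRENT_LINE_ID` environment value with a 0-based invocation counter: line ID `'0'` plus the first call gives `'0_0'`, and line ID `'3'` plus the third call gives `'3_2'`. A hedged sketch of that composition, inferred from the assertions and not from the decorator's internals:

```python
import os
from itertools import count

_calls = count()  # the real decorator scopes its counter per test invocation


def next_iteration_id() -> str:
    """Compose '<line_id>_<n>' the way the assertions above expect."""
    line_id = os.environ.get("CODEFLASH_CURRENT_LINE_ID", "0")
    return f"{line_id}_{next(_calls)}"


os.environ["CODEFLASH_CURRENT_LINE_ID"] = "3"
print(next_iteration_id(), next_iteration_id(), next_iteration_id())  # 3_0 3_1 3_2
```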

View file

@@ -6,9 +6,11 @@ from pathlib import Path
import pytest
from codeflash.context.unused_definition_remover import detect_unused_helper_functions
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
-from codeflash.models.models import CodeStringsMarkdown, FunctionParent
+from codeflash.models.models import CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
+from codeflash.context.unused_definition_remover import revert_unused_helper_functions

@pytest.fixture
@@ -225,6 +227,15 @@ def helper_function_2(x):
def test_no_unused_helpers_no_revert(temp_project):
"""Test that when all helpers are still used, nothing is reverted."""
temp_dir, main_file, test_cfg = temp_project
+# Store original content to verify nothing changes
+original_content = main_file.read_text()
+revert_unused_helper_functions(temp_dir, [], {})
+# Verify the file content remains unchanged
+assert main_file.read_text() == original_content, "File should remain unchanged when no helpers to revert"
# Optimized version that still calls both helpers
optimized_code = """
@@ -308,17 +319,23 @@ def helper_function_1(x):
def helper_function_2(x):
\"\"\"Second helper function.\"\"\"
return x * 3
+def helper_function_1(y): # Duplicate name to test line 575
+\"\"\"Overloaded helper function.\"\"\"
+return y + 10
""")
-# Optimized version that only calls one helper
+# Optimized version that only calls one helper with aliased import
optimized_code = """
```python:main.py
-from helpers import helper_function_1
+from helpers import helper_function_1 as h1
+import helpers as h_module
def entrypoint_function(n):
-\"\"\"Optimized function that only calls one helper.\"\"\"
-result1 = helper_function_1(n)
-return result1 + n * 3 # Inlined helper_function_2
+\"\"\"Optimized function that only calls one helper with aliasing.\"\"\"
+result1 = h1(n) # Using aliased import
+# Inlined helper_function_2 functionality: n * 3
+return result1 + n * 3 # Fully inlined helper_function_2
```
"""
@@ -1462,113 +1479,44 @@ class MathUtils:
shutil.rmtree(temp_dir, ignore_errors=True)
-def test_unused_helper_detection_with_duplicated_function_name_in_different_classes():
-"""Test detection when helpers are called via module.function style."""
+def test_async_entrypoint_with_async_helpers():
+"""Test that unused async helper functions are correctly detected when entrypoint is async."""
temp_dir = Path(tempfile.mkdtemp())
try:
-# Main file
+# Main file with async entrypoint and async helpers
main_file = temp_dir / "main.py"
-main_file.write_text("""from __future__ import annotations
-import json
-from helpers import replace_quotes_with_backticks, simplify_worktree_paths
-from dataclasses import asdict, dataclass
-@dataclass
-class LspMessage:
-def serialize(self) -> str:
-data = self._loop_through(asdict(self))
-# Important: keep type as the first key, for making it easy and fast for the client to know if this is a lsp message before parsing it
-ordered = {"type": self.type(), **data}
-return (
-message_delimiter
-+ json.dumps(ordered)
-+ message_delimiter
-)
-@dataclass
-class LspMarkdownMessage(LspMessage):
-def serialize(self) -> str:
-self.markdown = simplify_worktree_paths(self.markdown)
-self.markdown = replace_quotes_with_backticks(self.markdown)
-return super().serialize()
+main_file.write_text("""
+async def async_helper_1(x):
+\"\"\"First async helper function.\"\"\"
+return x * 2
+async def async_helper_2(x):
+\"\"\"Second async helper function.\"\"\"
+return x * 3
+async def async_entrypoint(n):
+\"\"\"Async entrypoint function that calls async helpers.\"\"\"
+result1 = await async_helper_1(n)
+result2 = await async_helper_2(n)
+return result1 + result2
""")
-# Helpers file
-helpers_file = temp_dir / "helpers.py"
-helpers_file.write_text("""def simplify_worktree_paths(msg: str, highlight: bool = True) -> str: # noqa: FBT001, FBT002
-path_in_msg = worktree_path_regex.search(msg)
-if path_in_msg:
-last_part_of_path = path_in_msg.group(0).split("/")[-1]
-if highlight:
-last_part_of_path = f"`{last_part_of_path}`"
-return msg.replace(path_in_msg.group(0), last_part_of_path)
-return msg
-def replace_quotes_with_backticks(text: str) -> str:
-# double-quoted strings
-text = _double_quote_pat.sub(r"`\1`", text)
-# single-quoted strings
-return _single_quote_pat.sub(r"`\1`", text)
-""")
-# Optimized version that only uses add_numbers
+# Optimized version that only calls one async helper
optimized_code = """
```python:main.py
-from __future__ import annotations
-import json
-from dataclasses import asdict, dataclass
-from codeflash.lsp.helpers import (replace_quotes_with_backticks,
-simplify_worktree_paths)
-@dataclass
-class LspMessage:
-def serialize(self) -> str:
-# Use local variable to minimize lookup costs and avoid unnecessary dictionary unpacking
-data = self._loop_through(asdict(self))
-msg_type = self.type()
-ordered = {'type': msg_type}
-ordered.update(data)
-return (
-message_delimiter
-+ json.dumps(ordered)
-+ message_delimiter # ␟ is the message delimiter becuase it can be more than one message sent over the same message, so we need something to separate each message
-)
-@dataclass
-class LspMarkdownMessage(LspMessage):
-def serialize(self) -> str:
-# Side effect required, must preserve for behavioral correctness
-self.markdown = simplify_worktree_paths(self.markdown)
-self.markdown = replace_quotes_with_backticks(self.markdown)
-return super().serialize()
-```
-```python:helpers.py
-def simplify_worktree_paths(msg: str, highlight: bool = True) -> str: # noqa: FBT001, FBT002
-m = worktree_path_regex.search(msg)
-if m:
-# More efficient way to get last path part
-last_part_of_path = m.group(0).rpartition('/')[-1]
-if highlight:
-last_part_of_path = f"`{last_part_of_path}`"
-return msg.replace(m.group(0), last_part_of_path)
-return msg
-def replace_quotes_with_backticks(text: str) -> str:
-# Efficient string substitution, reduces intermediate string allocations
-return _single_quote_pat.sub(
-r"`\1`",
-_double_quote_pat.sub(r"`\1`", text),
-)
+async def async_helper_1(x):
+\"\"\"First async helper function.\"\"\"
+return x * 2
+async def async_helper_2(x):
+\"\"\"Second async helper function - should be unused.\"\"\"
+return x * 3
+async def async_entrypoint(n):
+\"\"\"Optimized async entrypoint that only calls one helper.\"\"\"
+result1 = await async_helper_1(n)
+return result1 + n * 3 # Inlined async_helper_2
```
"""
@@ -1581,31 +1529,543 @@ def replace_quotes_with_backticks(text: str) -> str:
pytest_cmd="pytest",
)
-# Create FunctionToOptimize instance
+# Create FunctionToOptimize instance for async function
function_to_optimize = FunctionToOptimize(
-file_path=main_file, function_name="serialize", qualified_name="serialize", parents=[
-FunctionParent(name="LspMarkdownMessage", type="ClassDef"),
-]
+file_path=main_file,
+function_name="async_entrypoint",
+parents=[],
+is_async=True
)
+# Create function optimizer
optimizer = FunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
)
+# Get original code context
ctx_result = optimizer.get_code_optimization_context()
assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"
code_context = ctx_result.unwrap()
+# Test unused helper detection
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))
+# Should detect async_helper_2 as unused
unused_names = {uh.qualified_name for uh in unused_helpers}
-assert len(unused_names) == 0 # no unused helpers
+expected_unused = {"async_helper_2"}
+assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"
+finally:
+# Cleanup
+import shutil
+shutil.rmtree(temp_dir, ignore_errors=True)
def test_sync_entrypoint_with_async_helpers():
"""Test that unused async helper functions are detected when entrypoint is sync."""
temp_dir = Path(tempfile.mkdtemp())
try:
# Main file with sync entrypoint and async helpers
main_file = temp_dir / "main.py"
main_file.write_text("""
import asyncio
async def async_helper_1(x):
\"\"\"First async helper function.\"\"\"
return x * 2
async def async_helper_2(x):
\"\"\"Second async helper function.\"\"\"
return x * 3
def sync_entrypoint(n):
\"\"\"Sync entrypoint function that calls async helpers.\"\"\"
result1 = asyncio.run(async_helper_1(n))
result2 = asyncio.run(async_helper_2(n))
return result1 + result2
""")
# Optimized version that only calls one async helper
optimized_code = """
```python:main.py
import asyncio
async def async_helper_1(x):
\"\"\"First async helper function.\"\"\"
return x * 2
async def async_helper_2(x):
\"\"\"Second async helper function - should be unused.\"\"\"
return x * 3
def sync_entrypoint(n):
\"\"\"Optimized sync entrypoint that only calls one async helper.\"\"\"
result1 = asyncio.run(async_helper_1(n))
return result1 + n * 3 # Inlined async_helper_2
```
"""
# Create test config
test_cfg = TestConfig(
tests_root=temp_dir / "tests",
tests_project_rootdir=temp_dir,
project_root_path=temp_dir,
test_framework="pytest",
pytest_cmd="pytest",
)
# Create FunctionToOptimize instance for sync function
function_to_optimize = FunctionToOptimize(
file_path=main_file,
function_name="sync_entrypoint",
parents=[]
)
# Create function optimizer
optimizer = FunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
)
# Get original code context
ctx_result = optimizer.get_code_optimization_context()
assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"
code_context = ctx_result.unwrap()
# Test unused helper detection
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))
# Should detect async_helper_2 as unused
unused_names = {uh.qualified_name for uh in unused_helpers}
expected_unused = {"async_helper_2"}
assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"
finally:
# Cleanup
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)
def test_mixed_sync_and_async_helpers():
"""Test detection when both sync and async helpers are mixed."""
temp_dir = Path(tempfile.mkdtemp())
try:
# Main file with mixed sync and async helpers
main_file = temp_dir / "main.py"
main_file.write_text("""
import asyncio
def sync_helper_1(x):
\"\"\"Sync helper function.\"\"\"
return x * 2
async def async_helper_1(x):
\"\"\"Async helper function.\"\"\"
return x * 3
def sync_helper_2(x):
\"\"\"Another sync helper function.\"\"\"
return x * 4
async def async_helper_2(x):
\"\"\"Another async helper function.\"\"\"
return x * 5
async def mixed_entrypoint(n):
\"\"\"Async entrypoint function that calls both sync and async helpers.\"\"\"
sync_result = sync_helper_1(n)
async_result = await async_helper_1(n)
sync_result2 = sync_helper_2(n)
async_result2 = await async_helper_2(n)
return sync_result + async_result + sync_result2 + async_result2
""")
# Optimized version that only calls some helpers
optimized_code = """
```python:main.py
import asyncio
def sync_helper_1(x):
\"\"\"Sync helper function.\"\"\"
return x * 2
async def async_helper_1(x):
\"\"\"Async helper function.\"\"\"
return x * 3
def sync_helper_2(x):
\"\"\"Another sync helper function - should be unused.\"\"\"
return x * 4
async def async_helper_2(x):
\"\"\"Another async helper function - should be unused.\"\"\"
return x * 5
async def mixed_entrypoint(n):
\"\"\"Optimized async entrypoint that only calls some helpers.\"\"\"
sync_result = sync_helper_1(n)
async_result = await async_helper_1(n)
return sync_result + async_result + n * 4 + n * 5 # Inlined both helper_2 functions
```
"""
# Create test config
test_cfg = TestConfig(
tests_root=temp_dir / "tests",
tests_project_rootdir=temp_dir,
project_root_path=temp_dir,
test_framework="pytest",
pytest_cmd="pytest",
)
# Create FunctionToOptimize instance for async function
function_to_optimize = FunctionToOptimize(
file_path=main_file,
function_name="mixed_entrypoint",
parents=[],
is_async=True
)
# Create function optimizer
optimizer = FunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
)
# Get original code context
ctx_result = optimizer.get_code_optimization_context()
assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"
code_context = ctx_result.unwrap()
# Test unused helper detection
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))
# Should detect both sync_helper_2 and async_helper_2 as unused
unused_names = {uh.qualified_name for uh in unused_helpers}
expected_unused = {"sync_helper_2", "async_helper_2"}
assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"
finally:
# Cleanup
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)
def test_async_class_methods():
"""Test unused async method detection in classes."""
temp_dir = Path(tempfile.mkdtemp())
try:
# Main file with class containing async methods
main_file = temp_dir / "main.py"
main_file.write_text("""
class AsyncProcessor:
async def entrypoint_method(self, n):
\"\"\"Async main method that calls async helper methods.\"\"\"
result1 = await self.async_helper_method_1(n)
result2 = await self.async_helper_method_2(n)
return result1 + result2
async def async_helper_method_1(self, x):
\"\"\"First async helper method.\"\"\"
return x * 2
async def async_helper_method_2(self, x):
\"\"\"Second async helper method.\"\"\"
return x * 3
def sync_helper_method(self, x):
\"\"\"Sync helper method.\"\"\"
return x * 4
""")
# Optimized version that only calls one async helper
optimized_code = """
```python:main.py
class AsyncProcessor:
async def entrypoint_method(self, n):
\"\"\"Optimized async method that only calls one helper.\"\"\"
result1 = await self.async_helper_method_1(n)
return result1 + n * 3 # Inlined async_helper_method_2
async def async_helper_method_1(self, x):
\"\"\"First async helper method.\"\"\"
return x * 2
async def async_helper_method_2(self, x):
\"\"\"Second async helper method - should be unused.\"\"\"
return x * 3
def sync_helper_method(self, x):
\"\"\"Sync helper method - should be unused.\"\"\"
return x * 4
```
"""
# Create test config
test_cfg = TestConfig(
tests_root=temp_dir / "tests",
tests_project_rootdir=temp_dir,
project_root_path=temp_dir,
test_framework="pytest",
pytest_cmd="pytest",
)
# Create FunctionToOptimize instance for async class method
from codeflash.models.models import FunctionParent
function_to_optimize = FunctionToOptimize(
file_path=main_file,
function_name="entrypoint_method",
parents=[FunctionParent(name="AsyncProcessor", type="ClassDef")],
is_async=True
)
# Create function optimizer
optimizer = FunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
)
# Get original code context
ctx_result = optimizer.get_code_optimization_context()
assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"
code_context = ctx_result.unwrap()
# Test unused helper detection
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))
# Should detect async_helper_method_2 as unused (sync_helper_method may not be discovered as helper)
unused_names = {uh.qualified_name for uh in unused_helpers}
expected_unused = {"AsyncProcessor.async_helper_method_2"}
assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"
finally:
# Cleanup
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)
def test_async_helper_revert_functionality():
"""Test that unused async helper functions are correctly reverted to original definitions."""
temp_dir = Path(tempfile.mkdtemp())
try:
# Main file with async functions
main_file = temp_dir / "main.py"
main_file.write_text("""
async def async_helper_1(x):
\"\"\"First async helper function.\"\"\"
return x * 2
async def async_helper_2(x):
\"\"\"Second async helper function.\"\"\"
return x * 3
async def async_entrypoint(n):
\"\"\"Async entrypoint function that calls async helpers.\"\"\"
result1 = await async_helper_1(n)
result2 = await async_helper_2(n)
return result1 + result2
""")
# Optimized version that only calls one helper and modifies the unused one
optimized_code = """
```python:main.py
async def async_helper_1(x):
\"\"\"First async helper function.\"\"\"
return x * 2
async def async_helper_2(x):
\"\"\"Modified async helper function - should be reverted.\"\"\"
return x * 10 # This change should be reverted
async def async_entrypoint(n):
\"\"\"Optimized async entrypoint that only calls one helper.\"\"\"
result1 = await async_helper_1(n)
return result1 + n * 3 # Inlined async_helper_2
```
"""
# Create test config
test_cfg = TestConfig(
tests_root=temp_dir / "tests",
tests_project_rootdir=temp_dir,
project_root_path=temp_dir,
test_framework="pytest",
pytest_cmd="pytest",
)
# Create FunctionToOptimize instance for async function
function_to_optimize = FunctionToOptimize(
file_path=main_file,
function_name="async_entrypoint",
parents=[],
is_async=True
)
# Create function optimizer
optimizer = FunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
)
# Get original code context
ctx_result = optimizer.get_code_optimization_context()
assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"
code_context = ctx_result.unwrap()
# Store original helper code
original_helper_code = {main_file: main_file.read_text()}
# Apply optimization and test reversion
optimizer.replace_function_and_helpers_with_optimized_code(
code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code
)
# Check final file content
final_content = main_file.read_text()
# The entrypoint should be optimized
assert "result1 + n * 3" in final_content, "Async entrypoint function should be optimized"
# async_helper_2 should be reverted to original (return x * 3, not x * 10)
assert "return x * 3" in final_content, "async_helper_2 should be reverted to original"
assert "return x * 10" not in final_content, "async_helper_2 should not contain the modified version"
# async_helper_1 should remain (it's still called)
assert "async def async_helper_1(x):" in final_content, "async_helper_1 should still exist"
finally:
# Cleanup
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)
def test_async_generators_and_coroutines():
"""Test detection with async generators and coroutines."""
temp_dir = Path(tempfile.mkdtemp())
try:
# Main file with async generators and coroutines
main_file = temp_dir / "main.py"
main_file.write_text("""
import asyncio
async def async_generator_helper(n):
\"\"\"Async generator helper.\"\"\"
for i in range(n):
yield i * 2
async def coroutine_helper(x):
\"\"\"Coroutine helper.\"\"\"
await asyncio.sleep(0.1)
return x * 3
async def another_coroutine_helper(x):
\"\"\"Another coroutine helper.\"\"\"
await asyncio.sleep(0.1)
return x * 4
async def async_entrypoint_with_generators(n):
\"\"\"Async entrypoint function that uses generators and coroutines.\"\"\"
results = []
async for value in async_generator_helper(n):
results.append(value)
final_result = await coroutine_helper(sum(results))
another_result = await another_coroutine_helper(n)
return final_result + another_result
""")
# Optimized version that doesn't use one of the coroutines
optimized_code = """
```python:main.py
import asyncio
async def async_generator_helper(n):
\"\"\"Async generator helper.\"\"\"
for i in range(n):
yield i * 2
async def coroutine_helper(x):
\"\"\"Coroutine helper.\"\"\"
await asyncio.sleep(0.1)
return x * 3
async def another_coroutine_helper(x):
\"\"\"Another coroutine helper - should be unused.\"\"\"
await asyncio.sleep(0.1)
return x * 4
async def async_entrypoint_with_generators(n):
\"\"\"Optimized async entrypoint that inlines one coroutine.\"\"\"
results = []
async for value in async_generator_helper(n):
results.append(value)
final_result = await coroutine_helper(sum(results))
return final_result + n * 4 # Inlined another_coroutine_helper
```
"""
# Create test config
test_cfg = TestConfig(
tests_root=temp_dir / "tests",
tests_project_rootdir=temp_dir,
project_root_path=temp_dir,
test_framework="pytest",
pytest_cmd="pytest",
)
# Create FunctionToOptimize instance for async function
function_to_optimize = FunctionToOptimize(
file_path=main_file,
function_name="async_entrypoint_with_generators",
parents=[],
is_async=True
)
# Create function optimizer
optimizer = FunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
)
# Get original code context
ctx_result = optimizer.get_code_optimization_context()
assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"
code_context = ctx_result.unwrap()
# Test unused helper detection
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))
# Should detect another_coroutine_helper as unused
unused_names = {uh.qualified_name for uh in unused_helpers}
expected_unused = {"another_coroutine_helper"}
assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"
finally:
# Cleanup
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)
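The new tests above exercise unused-helper detection across plain calls, awaited calls, and `self.`-qualified async method calls. A rough sketch of collecting all three call shapes from an entrypoint body; this is an assumption about what detect_unused_helper_functions must handle internally, not its actual implementation:

```python
import ast


def called_helper_names(entrypoint_source: str) -> set[str]:
    """Gather names invoked via foo(...), await foo(...), and await self.foo(...),
    so unused helpers can be computed by set difference against all helpers."""
    used: set[str] = set()
    for node in ast.walk(ast.parse(entrypoint_source)):
        if isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name):
                used.add(node.func.id)
            elif isinstance(node.func, ast.Attribute):
                used.add(node.func.attr)  # covers self.async_helper_method_1(...)
    return used


src = (
    "async def entrypoint_method(self, n):\n"
    "    result1 = await self.async_helper_method_1(n)\n"
    "    return result1 + n * 3\n"
)
print(called_helper_names(src))  # {'async_helper_method_1'}
```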

1032 uv.lock
File diff suppressed because it is too large