Mirror of https://github.com/codeflash-ai/codeflash.git (synced 2026-05-04 18:25:17 +00:00)

Commit f978a406bb: Merge branch 'main' into part-1-windows-fixes
45 changed files with 5834 additions and 838 deletions

.github/workflows/e2e-async.yaml (vendored, new file, 69 lines)
@@ -0,0 +1,69 @@
name: E2E - Async

on:
  pull_request:
    paths:
      - '**' # Trigger for all paths
  workflow_dispatch:

jobs:
  async-optimization:
    # Dynamically determine if environment is needed only when workflow files change and contributor is external
    environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }}

    runs-on: ubuntu-latest
    env:
      CODEFLASH_AIS_SERVER: prod
      POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }}
      CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
      COLUMNS: 110
      MAX_RETRIES: 3
      RETRY_DELAY: 5
      EXPECTED_IMPROVEMENT_PCT: 10
      CODEFLASH_END_TO_END: 1
    steps:
      - name: 🛎️ Checkout
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event.pull_request.head.ref }}
          repository: ${{ github.event.pull_request.head.repo.full_name }}
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Validate PR
        run: |
          # Check for any workflow changes
          if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" | grep -q "^.github/workflows/"; then
            echo "⚠️ Workflow changes detected."

            # Get the PR author
            AUTHOR="${{ github.event.pull_request.user.login }}"
            echo "PR Author: $AUTHOR"

            # Allowlist check
            if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
              echo "✅ Authorized user ($AUTHOR). Proceeding."
            elif [[ "${{ github.event.pull_request.state }}" == "open" ]]; then
              echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
            else
              echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
              exit 1
            fi
          else
            echo "✅ No workflow file changes detected. Proceeding."
          fi

      - name: Set up Python 3.11 for CLI
        uses: astral-sh/setup-uv@v5
        with:
          python-version: 3.11.6

      - name: Install dependencies (CLI)
        run: |
          uv sync

      - name: Run Codeflash to optimize async code
        id: optimize_async_code
        run: |
          uv run python tests/scripts/end_to_end_test_async.py
@@ -20,7 +20,7 @@ jobs:
      COLUMNS: 110
      MAX_RETRIES: 3
      RETRY_DELAY: 5
-     EXPECTED_IMPROVEMENT_PCT: 300
+     EXPECTED_IMPROVEMENT_PCT: 70
      CODEFLASH_END_TO_END: 1
    steps:
      - name: 🛎️ Checkout
@@ -20,7 +20,7 @@ jobs:
      COLUMNS: 110
      MAX_RETRIES: 3
      RETRY_DELAY: 5
-     EXPECTED_IMPROVEMENT_PCT: 300
+     EXPECTED_IMPROVEMENT_PCT: 40
      CODEFLASH_END_TO_END: 1
    steps:
      - name: 🛎️ Checkout
code_to_optimize/async_bubble_sort.py (new file, 43 lines)

@@ -0,0 +1,43 @@
import asyncio
from typing import List, Union


async def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:
    """
    Async bubble sort implementation for testing.
    """
    print("codeflash stdout: Async sorting list")

    await asyncio.sleep(0.01)

    n = len(lst)
    for i in range(n):
        for j in range(0, n - i - 1):
            if lst[j] > lst[j + 1]:
                lst[j], lst[j + 1] = lst[j + 1], lst[j]

    result = lst.copy()
    print(f"result: {result}")
    return result


class AsyncBubbleSorter:
    """Class with async sorting method for testing."""

    async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:
        """
        Async bubble sort implementation within a class.
        """
        print("codeflash stdout: AsyncBubbleSorter.sorter() called")

        # Add some async delay
        await asyncio.sleep(0.005)

        n = len(lst)
        for i in range(n):
            for j in range(0, n - i - 1):
                if lst[j] > lst[j + 1]:
                    lst[j], lst[j + 1] = lst[j + 1], lst[j]

        result = lst.copy()
        return result
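For orientation, a minimal sketch of how these fixtures might be driven directly (the import path matches the file above; the driver script itself is not part of the diff):

import asyncio

from code_to_optimize.async_bubble_sort import AsyncBubbleSorter, async_sorter


async def demo() -> None:
    # Module-level async function under optimization
    print(await async_sorter([3, 1, 2]))
    # Method variant on the helper class
    print(await AsyncBubbleSorter().sorter([5.0, 4.0, 6.0]))


asyncio.run(demo())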
code_to_optimize/code_directories/async_e2e/main.py (new file, 16 lines)

@@ -0,0 +1,16 @@
import time
import asyncio


async def retry_with_backoff(func, max_retries=3):
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")
    last_exception = None
    for attempt in range(max_retries):
        try:
            return await func()
        except Exception as e:
            last_exception = e
            if attempt < max_retries - 1:
                time.sleep(0.0001 * attempt)
    raise last_exception
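A hedged usage sketch for retry_with_backoff (the import path is an assumption based on the directory shown above):

import asyncio

from code_to_optimize.code_directories.async_e2e.main import retry_with_backoff  # assumed import path

attempts = {"n": 0}


async def flaky() -> str:
    # Fails twice, then succeeds, so retry_with_backoff returns on the third attempt
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RuntimeError("transient failure")
    return "ok"


print(asyncio.run(retry_with_backoff(flaky, max_retries=3)))  # -> ok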
@@ -0,0 +1,6 @@
[tool.codeflash]
disable-telemetry = true
formatter-cmds = ["ruff check --exit-zero --fix $file", "ruff format $file"]
module-root = "."
test-framework = "pytest"
tests-root = "tests"
@@ -102,6 +102,8 @@ class AiServiceClient:
        trace_id: str,
        num_candidates: int = 10,
        experiment_metadata: ExperimentMetadata | None = None,
        *,
        is_async: bool = False,
    ) -> list[OptimizedCandidate]:
        """Optimize the given python code for performance by making a request to the Django endpoint.

@@ -133,6 +135,7 @@ class AiServiceClient:
            "repo_owner": git_repo_owner,
            "repo_name": git_repo_name,
            "n_candidates": N_CANDIDATES_EFFECTIVE,
            "is_async": is_async,
        }

        logger.info("!lsp|Generating optimized candidates…")

@@ -299,6 +302,9 @@ class AiServiceClient:
        annotated_tests: str,
        optimization_id: str,
        original_explanation: str,
        original_throughput: str | None = None,
        optimized_throughput: str | None = None,
        throughput_improvement: str | None = None,
    ) -> str:
        """Optimize the given python code for performance by making a request to the Django endpoint.

@@ -315,6 +321,9 @@ class AiServiceClient:
        - annotated_tests: str - test functions annotated with runtime
        - optimization_id: str - unique id of opt candidate
        - original_explanation: str - original_explanation generated for the opt candidate
        - original_throughput: str | None - throughput for the baseline code (operations per second)
        - optimized_throughput: str | None - throughput for the optimized code (operations per second)
        - throughput_improvement: str | None - throughput improvement percentage

        Returns
        -------

@@ -334,6 +343,9 @@ class AiServiceClient:
            "optimization_id": optimization_id,
            "original_explanation": original_explanation,
            "dependency_code": dependency_code,
            "original_throughput": original_throughput,
            "optimized_throughput": optimized_throughput,
            "throughput_improvement": throughput_improvement,
        }
        logger.info("loading|Generating explanation")
        console.rule()

@@ -488,6 +500,7 @@ class AiServiceClient:
            "test_index": test_index,
            "python_version": platform.python_version(),
            "codeflash_version": codeflash_version,
            "is_async": function_to_optimize.is_async,
        }
        try:
            response = self.make_ai_service_request("/testgen", payload=payload, timeout=600)
@@ -1,3 +1,4 @@
import importlib.util
import logging
import sys
from argparse import SUPPRESS, ArgumentParser, Namespace

@@ -96,6 +97,12 @@ def parse_args() -> Namespace:
    )
    parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
    parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization")
    parser.add_argument(
        "--async",
        default=False,
        action="store_true",
        help="Enable optimization of async functions. By default, async functions are excluded from optimization.",
    )

    args, unknown_args = parser.parse_known_args()
    sys.argv[:] = [sys.argv[0], *unknown_args]

@@ -139,6 +146,14 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
    if env_utils.is_ci():
        args.no_pr = True

    if getattr(args, "async", False) and importlib.util.find_spec("pytest_asyncio") is None:
        logger.warning(
            "Warning: The --async flag requires pytest-asyncio to be installed.\n"
            "Please install it using:\n"
            '  pip install "codeflash[asyncio]"'
        )
        raise SystemExit(1)

    return args
@@ -272,6 +272,8 @@ class DottedImportCollector(cst.CSTVisitor):
            if child.module is None:
                continue
            module = self.get_full_dotted_name(child.module)
            if isinstance(child.names, cst.ImportStar):
                continue
            for alias in child.names:
                if isinstance(alias, cst.ImportAlias):
                    name = alias.name.value

@@ -414,6 +416,73 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
    return transformed_module.code


def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
    try:
        module_path = module_name.replace(".", "/")
        possible_paths = [project_root / f"{module_path}.py", project_root / f"{module_path}/__init__.py"]

        module_file = None
        for path in possible_paths:
            if path.exists():
                module_file = path
                break

        if module_file is None:
            logger.warning(f"Could not find module file for {module_name}, skipping star import resolution")
            return set()

        with module_file.open(encoding="utf8") as f:
            module_code = f.read()

        tree = ast.parse(module_code)

        all_names = None
        for node in ast.walk(tree):
            if (
                isinstance(node, ast.Assign)
                and len(node.targets) == 1
                and isinstance(node.targets[0], ast.Name)
                and node.targets[0].id == "__all__"
            ):
                if isinstance(node.value, (ast.List, ast.Tuple)):
                    all_names = []
                    for elt in node.value.elts:
                        if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
                            all_names.append(elt.value)
                        elif isinstance(elt, ast.Str):  # Python < 3.8 compatibility
                            all_names.append(elt.s)
                break

        if all_names is not None:
            return set(all_names)

        public_names = set()
        for node in tree.body:
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                if not node.name.startswith("_"):
                    public_names.add(node.name)
            elif isinstance(node, ast.Assign):
                for target in node.targets:
                    if isinstance(target, ast.Name) and not target.id.startswith("_"):
                        public_names.add(target.id)
            elif isinstance(node, ast.AnnAssign):
                if isinstance(node.target, ast.Name) and not node.target.id.startswith("_"):
                    public_names.add(node.target.id)
            elif isinstance(node, ast.Import) or (
                isinstance(node, ast.ImportFrom) and not any(alias.name == "*" for alias in node.names)
            ):
                for alias in node.names:
                    name = alias.asname or alias.name
                    if not name.startswith("_"):
                        public_names.add(name)

        return public_names  # noqa: TRY300

    except Exception as e:
        logger.warning(f"Error resolving star import for {module_name}: {e}")
        return set()


def add_needed_imports_from_module(
    src_module_code: str,
    dst_module_code: str,

@@ -468,9 +537,23 @@ def add_needed_imports_from_module(
                f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod  # avoid circular deps
            ):
                continue  # Skip adding imports for helper functions already in the context
-           if f"{mod}.{obj}" not in dotted_import_collector.imports:
-               AddImportsVisitor.add_needed_import(dst_context, mod, obj)
-               RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
+
+           # Handle star imports by resolving them to actual symbol names
+           if obj == "*":
+               resolved_symbols = resolve_star_import(mod, project_root)
+               logger.debug(f"Resolved star import from {mod}: {resolved_symbols}")
+
+               for symbol in resolved_symbols:
+                   if (
+                       f"{mod}.{symbol}" not in helper_functions_fqn
+                       and f"{mod}.{symbol}" not in dotted_import_collector.imports
+                   ):
+                       AddImportsVisitor.add_needed_import(dst_context, mod, symbol)
+                       RemoveImportsVisitor.remove_unused_import(dst_context, mod, symbol)
+           else:
+               if f"{mod}.{obj}" not in dotted_import_collector.imports:
+                   AddImportsVisitor.add_needed_import(dst_context, mod, obj)
+                   RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
    except Exception as e:
        logger.exception(f"Error adding imports to destination module code: {e}")
        return dst_module_code
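To illustrate the fallback branch of resolve_star_import (used when a module defines no __all__), a self-contained sketch of the same public-name collection logic:

import ast

# Illustration only; resolve_star_import itself also handles __all__, imports, and annotated assignments.
source = """
def foo(): ...
async def bar(): ...
_private = 1
CONST = 2
"""
public_names = set()
for node in ast.parse(source).body:
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
        if not node.name.startswith("_"):
            public_names.add(node.name)
    elif isinstance(node, ast.Assign):
        for target in node.targets:
            if isinstance(target, ast.Name) and not target.id.startswith("_"):
                public_names.add(target.id)
print(public_names)  # {'foo', 'bar', 'CONST'}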
@@ -269,14 +269,6 @@ def validate_python_code(code: str) -> str:
    return code


-def has_any_async_functions(code: str) -> bool:
-    try:
-        module = ast.parse(code)
-    except SyntaxError:
-        return False
-    return any(isinstance(node, ast.AsyncFunctionDef) for node in ast.walk(module))


def cleanup_paths(paths: list[Path]) -> None:
    for path in paths:
        if path and path.exists():
codeflash/code_utils/codeflash_wrap_decorator.py (new file, 167 lines)

@@ -0,0 +1,167 @@
from __future__ import annotations

import asyncio
import gc
import os
import sqlite3
from enum import Enum
from functools import wraps
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Callable, TypeVar

import dill as pickle


class VerificationType(str, Enum):  # moved from codeflash/verification/codeflash_capture.py
    FUNCTION_CALL = (
        "function_call"  # Correctness verification for a test function, checks input values and output values)
    )
    INIT_STATE_FTO = "init_state_fto"  # Correctness verification for fto class instance attributes after init
    INIT_STATE_HELPER = "init_state_helper"  # Correctness verification for helper class instance attributes after init


F = TypeVar("F", bound=Callable[..., Any])


def get_run_tmp_file(file_path: Path) -> Path:  # moved from codeflash/code_utils/code_utils.py
    if not hasattr(get_run_tmp_file, "tmpdir"):
        get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
    return Path(get_run_tmp_file.tmpdir.name) / file_path


def extract_test_context_from_env() -> tuple[str, str | None, str]:
    test_module = os.environ["CODEFLASH_TEST_MODULE"]
    test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
    test_function = os.environ["CODEFLASH_TEST_FUNCTION"]

    if test_module and test_function:
        return (test_module, test_class if test_class else None, test_function)

    raise RuntimeError(
        "Test context environment variables not set - ensure tests are run through codeflash test runner"
    )


def codeflash_behavior_async(func: F) -> F:
    @wraps(func)
    async def async_wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
        loop = asyncio.get_running_loop()
        function_name = func.__name__
        line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
        loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
        test_module_name, test_class_name, test_name = extract_test_context_from_env()

        test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"

        if not hasattr(async_wrapper, "index"):
            async_wrapper.index = {}
        if test_id in async_wrapper.index:
            async_wrapper.index[test_id] += 1
        else:
            async_wrapper.index[test_id] = 0

        codeflash_test_index = async_wrapper.index[test_id]
        invocation_id = f"{line_id}_{codeflash_test_index}"
        test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"

        print(f"!$######{test_stdout_tag}######$!")

        iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
        db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
        codeflash_con = sqlite3.connect(db_path)
        codeflash_cur = codeflash_con.cursor()

        codeflash_cur.execute(
            "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, "
            "test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, "
            "runtime INTEGER, return_value BLOB, verification_type TEXT)"
        )

        exception = None
        counter = loop.time()
        gc.disable()
        try:
            ret = func(*args, **kwargs)  # coroutine creation has some overhead, though it is very small
            counter = loop.time()
            return_value = await ret  # let's measure the actual execution time of the code
            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
        except Exception as e:
            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
            exception = e
        finally:
            gc.enable()

        print(f"!######{test_stdout_tag}######!")

        pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps((args, kwargs, return_value))
        codeflash_cur.execute(
            "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (
                test_module_name,
                test_class_name,
                test_name,
                function_name,
                loop_index,
                invocation_id,
                codeflash_duration,
                pickled_return_value,
                VerificationType.FUNCTION_CALL.value,
            ),
        )
        codeflash_con.commit()
        codeflash_con.close()

        if exception:
            raise exception
        return return_value

    return async_wrapper


def codeflash_performance_async(func: F) -> F:
    @wraps(func)
    async def async_wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
        loop = asyncio.get_running_loop()
        function_name = func.__name__
        line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
        loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])

        test_module_name, test_class_name, test_name = extract_test_context_from_env()

        test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"

        if not hasattr(async_wrapper, "index"):
            async_wrapper.index = {}
        if test_id in async_wrapper.index:
            async_wrapper.index[test_id] += 1
        else:
            async_wrapper.index[test_id] = 0

        codeflash_test_index = async_wrapper.index[test_id]
        invocation_id = f"{line_id}_{codeflash_test_index}"
        test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"

        print(f"!$######{test_stdout_tag}######$!")

        exception = None
        counter = loop.time()
        gc.disable()
        try:
            ret = func(*args, **kwargs)
            counter = loop.time()
            return_value = await ret
            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
        except Exception as e:
            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
            exception = e
        finally:
            gc.enable()

        print(f"!######{test_stdout_tag}:{codeflash_duration}######!")

        if exception:
            raise exception
        return return_value

    return async_wrapper
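A minimal sketch of exercising the behavior decorator outside the normal test runner (these environment variables are normally injected by codeflash's test instrumentation; setting them by hand here is only for illustration):

import asyncio
import os

from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async

# Normally set by the instrumented test file / test runner.
os.environ.update(
    CODEFLASH_TEST_MODULE="tests.test_demo",
    CODEFLASH_TEST_FUNCTION="test_add",
    CODEFLASH_CURRENT_LINE_ID="0",
    CODEFLASH_LOOP_INDEX="1",
)


@codeflash_behavior_async
async def add(a: int, b: int) -> int:
    await asyncio.sleep(0)
    return a + b


print(asyncio.run(add(1, 2)))  # prints the stdout tags, records the timing to SQLite, then 3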
@@ -3,6 +3,7 @@ INDIVIDUAL_TESTCASE_TIMEOUT = 15
MAX_FUNCTION_TEST_SECONDS = 60
N_CANDIDATES = 5
MIN_IMPROVEMENT_THRESHOLD = 0.05
+MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10  # 10% minimum improvement for async throughput
MAX_TEST_FUNCTION_RUNS = 50
MAX_CUMULATIVE_TEST_RUNTIME_NANOSECONDS = 100e6  # 100ms
N_TESTS_TO_GENERATE = 2
@@ -14,7 +14,9 @@ def extract_dependent_function(main_function: str, code_context: CodeOptimizatio
    """Extract the single dependent function from the code context excluding the main function."""
    ast_tree = ast.parse(code_context.testgen_context_code)

-   dependent_functions = {node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)}
+   dependent_functions = {
+       node.name for node in ast_tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
+   }

    if main_function in dependent_functions:
        dependent_functions.discard(main_function)
@@ -32,9 +32,11 @@ class CommentMapper(ast.NodeVisitor):

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        self.context_stack.append(node.name)
-       for inner_node in ast.walk(node):
+       for inner_node in node.body:
            if isinstance(inner_node, ast.FunctionDef):
                self.visit_FunctionDef(inner_node)
+           elif isinstance(inner_node, ast.AsyncFunctionDef):
+               self.visit_AsyncFunctionDef(inner_node)
        self.context_stack.pop()
        return node

@@ -50,6 +52,14 @@ class CommentMapper(ast.NodeVisitor):
        return f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        self._process_function_def_common(node)
        return node

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
        self._process_function_def_common(node)
        return node

    def _process_function_def_common(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
        self.context_stack.append(node.name)
        i = len(node.body) - 1
        test_qualified_name = ".".join(self.context_stack)

@@ -60,8 +70,9 @@ class CommentMapper(ast.NodeVisitor):
                j = len(line_node.body) - 1
                while j >= 0:
                    compound_line_node: ast.stmt = line_node.body[j]
-                   internal_node: ast.AST
-                   for internal_node in ast.walk(compound_line_node):
+                   nodes_to_check = [compound_line_node]
+                   nodes_to_check.extend(getattr(compound_line_node, "body", []))
+                   for internal_node in nodes_to_check:
                        if isinstance(internal_node, (ast.stmt, ast.Assign)):
                            inv_id = str(i) + "_" + str(j)
                            match_key = key + "#" + inv_id

@@ -75,7 +86,6 @@ class CommentMapper(ast.NodeVisitor):
                    self.results[line_node.lineno] = self.get_comment(match_key)
            i -= 1
        self.context_stack.pop()
-       return node


def get_fn_call_linenos(

@@ -197,23 +207,41 @@ def add_runtime_comments_to_generated_tests(
def remove_functions_from_generated_tests(
    generated_tests: GeneratedTestsList, test_functions_to_remove: list[str]
) -> GeneratedTestsList:
+   # Pre-compile patterns for all function names to remove
+   function_patterns = _compile_function_patterns(test_functions_to_remove)
    new_generated_tests = []

    for generated_test in generated_tests.generated_tests:
-       for test_function in test_functions_to_remove:
-           function_pattern = re.compile(
-               rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\ndef\s|$)",
-               re.DOTALL,
-           )
-
-           match = function_pattern.search(generated_test.generated_original_test_source)
-
-           if match is None or "@pytest.mark.parametrize" in match.group(0):
-               continue
-
-           generated_test.generated_original_test_source = function_pattern.sub(
-               "", generated_test.generated_original_test_source
-           )
+       source = generated_test.generated_original_test_source
+       # Apply all patterns without redundant searches
+       for pattern in function_patterns:
+           # Use finditer and sub only if necessary to avoid unnecessary .search()/.sub() calls
+           for match in pattern.finditer(source):
+               # Skip if "@pytest.mark.parametrize" present
+               # Only the matched function's code is targeted
+               if "@pytest.mark.parametrize" in match.group(0):
+                   continue
+               # Remove function from source
+               # If match, remove the function by substitution in the source
+               # Replace using start/end indices for efficiency
+               start, end = match.span()
+               source = source[:start] + source[end:]
+               # After removal, break since .finditer() is from left to right, and only one match expected per function in source
+               break
+
+       generated_test.generated_original_test_source = source
        new_generated_tests.append(generated_test)

    return GeneratedTestsList(generated_tests=new_generated_tests)


+# Pre-compile all function removal regexes upfront for efficiency.
+def _compile_function_patterns(test_functions_to_remove: list[str]) -> list[re.Pattern[str]]:
+   return [
+       re.compile(
+           rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?(async\s+)?def\s+{re.escape(func)}\(.*?\):.*?(?=\n(async\s+)?def\s|$)",
+           re.DOTALL,
+       )
+       for func in test_functions_to_remove
+   ]
@@ -6,6 +6,7 @@ from pathlib import Path
from typing import TYPE_CHECKING

import isort
import libcst as cst

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path

@@ -77,6 +78,10 @@ class InjectPerfOnly(ast.NodeTransformer):
                call_node = node
                if isinstance(node.func, ast.Name):
                    function_name = node.func.id

                    if self.function_object.is_async:
                        return [test_node]

                    node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
                    node.args = [
                        ast.Name(id=function_name, ctx=ast.Load()),

@@ -98,6 +103,9 @@ class InjectPerfOnly(ast.NodeTransformer):
                if isinstance(node.func, ast.Attribute):
                    function_to_test = node.func.attr
                    if function_to_test == self.function_object.function_name:
                        if self.function_object.is_async:
                            return [test_node]

                        function_name = ast.unparse(node.func)
                        node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
                        node.args = [
@@ -289,6 +297,168 @@ class InjectPerfOnly(ast.NodeTransformer):
        return node


class AsyncCallInstrumenter(ast.NodeTransformer):
    def __init__(
        self,
        function: FunctionToOptimize,
        module_path: str,
        test_framework: str,
        call_positions: list[CodePosition],
        mode: TestingMode = TestingMode.BEHAVIOR,
    ) -> None:
        self.mode = mode
        self.function_object = function
        self.class_name = None
        self.only_function_name = function.function_name
        self.module_path = module_path
        self.test_framework = test_framework
        self.call_positions = call_positions
        self.did_instrument = False
        # Track function call count per test function
        self.async_call_counter: dict[str, int] = {}
        if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
            self.class_name = function.top_level_parent_name

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        # Add timeout decorator for unittest test classes if needed
        if self.test_framework == "unittest":
            timeout_decorator = ast.Call(
                func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
                args=[ast.Constant(value=15)],
                keywords=[],
            )
            for item in node.body:
                if (
                    isinstance(item, ast.FunctionDef)
                    and item.name.startswith("test_")
                    and not any(
                        isinstance(d, ast.Call)
                        and isinstance(d.func, ast.Name)
                        and d.func.id == "timeout_decorator.timeout"
                        for d in item.decorator_list
                    )
                ):
                    item.decorator_list.append(timeout_decorator)
        return self.generic_visit(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
        if not node.name.startswith("test_"):
            return node

        return self._process_test_function(node)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        # Only process test functions
        if not node.name.startswith("test_"):
            return node

        return self._process_test_function(node)

    def _process_test_function(
        self, node: ast.AsyncFunctionDef | ast.FunctionDef
    ) -> ast.AsyncFunctionDef | ast.FunctionDef:
        # Optimize the search for decorator presence
        if self.test_framework == "unittest":
            found_timeout = False
            for d in node.decorator_list:
                # Avoid isinstance(d.func, ast.Name) if d is not ast.Call
                if isinstance(d, ast.Call):
                    f = d.func
                    # Avoid attribute lookup if f is not ast.Name
                    if isinstance(f, ast.Name) and f.id == "timeout_decorator.timeout":
                        found_timeout = True
                        break
            if not found_timeout:
                timeout_decorator = ast.Call(
                    func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
                    args=[ast.Constant(value=15)],
                    keywords=[],
                )
                node.decorator_list.append(timeout_decorator)

        # Initialize counter for this test function
        if node.name not in self.async_call_counter:
            self.async_call_counter[node.name] = 0

        new_body = []

        # Optimize ast.walk calls inside _instrument_statement, by scanning only relevant nodes
        for _i, stmt in enumerate(node.body):
            transformed_stmt, added_env_assignment = self._optimized_instrument_statement(stmt)

            if added_env_assignment:
                current_call_index = self.async_call_counter[node.name]
                self.async_call_counter[node.name] += 1

                env_assignment = ast.Assign(
                    targets=[
                        ast.Subscript(
                            value=ast.Attribute(
                                value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
                            ),
                            slice=ast.Constant(value="CODEFLASH_CURRENT_LINE_ID"),
                            ctx=ast.Store(),
                        )
                    ],
                    value=ast.Constant(value=f"{current_call_index}"),
                    lineno=stmt.lineno if hasattr(stmt, "lineno") else 1,
                )
                new_body.append(env_assignment)
                self.did_instrument = True

            new_body.append(transformed_stmt)

        node.body = new_body
        return node

    def _instrument_statement(self, stmt: ast.stmt, _node_name: str) -> tuple[ast.stmt, bool]:
        for node in ast.walk(stmt):
            if (
                isinstance(node, ast.Await)
                and isinstance(node.value, ast.Call)
                and self._is_target_call(node.value)
                and self._call_in_positions(node.value)
            ):
                # Check if this call is in one of our target positions
                return stmt, True  # Return original statement but signal we added env var

        return stmt, False

    def _is_target_call(self, call_node: ast.Call) -> bool:
        """Check if this call node is calling our target async function."""
        if isinstance(call_node.func, ast.Name):
            return call_node.func.id == self.function_object.function_name
        if isinstance(call_node.func, ast.Attribute):
            return call_node.func.attr == self.function_object.function_name
        return False

    def _call_in_positions(self, call_node: ast.Call) -> bool:
        if not hasattr(call_node, "lineno") or not hasattr(call_node, "col_offset"):
            return False

        return node_in_call_position(call_node, self.call_positions)

    # Optimized version: only walk child nodes for Await
    def _optimized_instrument_statement(self, stmt: ast.stmt) -> tuple[ast.stmt, bool]:
        # Stack-based DFS, manual for relevant Await nodes
        stack = [stmt]
        while stack:
            node = stack.pop()
            # Favor direct ast.Await detection
            if isinstance(node, ast.Await):
                val = node.value
                if isinstance(val, ast.Call) and self._is_target_call(val) and self._call_in_positions(val):
                    return stmt, True
            # Use _fields instead of ast.walk for less allocations
            for fname in getattr(node, "_fields", ()):
                child = getattr(node, fname, None)
                if isinstance(child, list):
                    stack.extend(child)
                elif isinstance(child, ast.AST):
                    stack.append(child)
        return stmt, False


class FunctionImportedAsVisitor(ast.NodeVisitor):
    """Checks if a function has been imported as an alias. We only care about the alias then.
@@ -316,6 +486,7 @@ class FunctionImportedAsVisitor(ast.NodeVisitor):
                file_path=self.function.file_path,
                starting_line=self.function.starting_line,
                ending_line=self.function.ending_line,
                is_async=self.function.is_async,
            )
        else:
            self.imported_as = FunctionToOptimize(
@@ -324,9 +495,69 @@ class FunctionImportedAsVisitor(ast.NodeVisitor):
                file_path=self.function.file_path,
                starting_line=self.function.starting_line,
                ending_line=self.function.ending_line,
                is_async=self.function.is_async,
            )


def instrument_source_module_with_async_decorators(
    source_path: Path, function_to_optimize: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
) -> tuple[bool, str | None]:
    if not function_to_optimize.is_async:
        return False, None

    try:
        with source_path.open(encoding="utf8") as f:
            source_code = f.read()

        modified_code, decorator_added = add_async_decorator_to_function(source_code, function_to_optimize, mode)

        if decorator_added:
            return True, modified_code

    except Exception:
        return False, None
    else:
        return False, None


def inject_async_profiling_into_existing_test(
    test_path: Path,
    call_positions: list[CodePosition],
    function_to_optimize: FunctionToOptimize,
    tests_project_root: Path,
    test_framework: str,
    mode: TestingMode = TestingMode.BEHAVIOR,
) -> tuple[bool, str | None]:
    """Inject profiling for async function calls by setting environment variables before each call."""
    with test_path.open(encoding="utf8") as f:
        test_code = f.read()

    try:
        tree = ast.parse(test_code)
    except SyntaxError:
        logger.exception(f"Syntax error in code in file - {test_path}")
        return False, None
    # TODO: Pass the full name of function here, otherwise we can run into namespace clashes
    test_module_path = module_name_from_file_path(test_path, tests_project_root)
    import_visitor = FunctionImportedAsVisitor(function_to_optimize)
    import_visitor.visit(tree)
    func = import_visitor.imported_as

    async_instrumenter = AsyncCallInstrumenter(func, test_module_path, test_framework, call_positions, mode=mode)
    tree = async_instrumenter.visit(tree)

    if not async_instrumenter.did_instrument:
        return False, None

    # Add necessary imports
    new_imports = [ast.Import(names=[ast.alias(name="os")])]
    if test_framework == "unittest":
        new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))

    tree.body = [*new_imports, *tree.body]
    return True, isort.code(ast.unparse(tree), float_to_top=True)


def inject_profiling_into_existing_test(
    test_path: Path,
    call_positions: list[CodePosition],
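To make the flow concrete, a hedged sketch of what an async test might roughly look like after inject_async_profiling_into_existing_test runs (illustrative only; the exact output depends on isort, the call positions, and the imported-as name):

import os

import pytest

from code_to_optimize.async_bubble_sort import async_sorter


@pytest.mark.asyncio
async def test_async_sorter():
    # Inserted by AsyncCallInstrumenter: tells the async decorator which call site is being measured
    os.environ["CODEFLASH_CURRENT_LINE_ID"] = "0"
    result = await async_sorter([3, 1, 2])
    assert result == [1, 2, 3]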
@@ -335,6 +566,11 @@ def inject_profiling_into_existing_test(
    test_framework: str,
    mode: TestingMode = TestingMode.BEHAVIOR,
) -> tuple[bool, str | None]:
    if function_to_optimize.is_async:
        return inject_async_profiling_into_existing_test(
            test_path, call_positions, function_to_optimize, tests_project_root, test_framework, mode
        )

    with test_path.open(encoding="utf8") as f:
        test_code = f.read()
    try:
@@ -342,7 +578,7 @@ def inject_profiling_into_existing_test(
    except SyntaxError:
        logger.exception(f"Syntax error in code in file - {test_path}")
        return False, None
    # TODO: Pass the full name of function here, otherwise we can run into namespace clashes
    test_module_path = module_name_from_file_path(test_path, tests_project_root)
    import_visitor = FunctionImportedAsVisitor(function_to_optimize)
    import_visitor.visit(tree)
@@ -360,7 +596,9 @@ def inject_profiling_into_existing_test(
    )
    if test_framework == "unittest" and platform.system() != "Windows":
        new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
-   tree.body = [*new_imports, create_wrapper_function(mode), *tree.body]
+   additional_functions = [create_wrapper_function(mode)]
+
+   tree.body = [*new_imports, *additional_functions, *tree.body]
    return True, isort.code(ast.unparse(tree), float_to_top=True)
@@ -741,3 +979,162 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
        decorator_list=[],
        returns=None,
    )


class AsyncDecoratorAdder(cst.CSTTransformer):
    """Transformer that adds async decorator to async function definitions."""

    def __init__(self, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR) -> None:
        """Initialize the transformer.

        Args:
        ----
            function: The FunctionToOptimize object representing the target async function.
            mode: The testing mode to determine which decorator to apply.

        """
        super().__init__()
        self.function = function
        self.mode = mode
        self.qualified_name_parts = function.qualified_name.split(".")
        self.context_stack = []
        self.added_decorator = False

        # Choose decorator based on mode
        self.decorator_name = (
            "codeflash_behavior_async" if mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
        )

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        # Track when we enter a class
        self.context_stack.append(node.name.value)

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:  # noqa: ARG002
        # Pop the context when we leave a class
        self.context_stack.pop()
        return updated_node

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        # Track when we enter a function
        self.context_stack.append(node.name.value)

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
        # Check if this is an async function and matches our target
        if original_node.asynchronous is not None and self.context_stack == self.qualified_name_parts:
            # Check if the decorator is already present
            has_decorator = any(
                self._is_target_decorator(decorator.decorator) for decorator in original_node.decorators
            )

            # Only add the decorator if it's not already there
            if not has_decorator:
                new_decorator = cst.Decorator(decorator=cst.Name(value=self.decorator_name))

                # Add our new decorator to the existing decorators
                updated_decorators = [new_decorator, *list(updated_node.decorators)]
                updated_node = updated_node.with_changes(decorators=tuple(updated_decorators))
                self.added_decorator = True

        # Pop the context when we leave a function
        self.context_stack.pop()
        return updated_node

    def _is_target_decorator(self, decorator_node: cst.Name | cst.Attribute | cst.Call) -> bool:
        """Check if a decorator matches our target decorator name."""
        if isinstance(decorator_node, cst.Name):
            return decorator_node.value in {
                "codeflash_trace_async",
                "codeflash_behavior_async",
                "codeflash_performance_async",
            }
        if isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name):
            return decorator_node.func.value in {
                "codeflash_trace_async",
                "codeflash_behavior_async",
                "codeflash_performance_async",
            }
        return False


class AsyncDecoratorImportAdder(cst.CSTTransformer):
    """Transformer that adds the import for async decorators."""

    def __init__(self, mode: TestingMode = TestingMode.BEHAVIOR) -> None:
        self.mode = mode
        self.has_import = False

    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
        # Check if the async decorator import is already present
        if (
            isinstance(node.module, cst.Attribute)
            and isinstance(node.module.value, cst.Attribute)
            and isinstance(node.module.value.value, cst.Name)
            and node.module.value.value.value == "codeflash"
            and node.module.value.attr.value == "code_utils"
            and node.module.attr.value == "codeflash_wrap_decorator"
            and not isinstance(node.names, cst.ImportStar)
        ):
            decorator_name = (
                "codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
            )
            for import_alias in node.names:
                if import_alias.name.value == decorator_name:
                    self.has_import = True

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:  # noqa: ARG002
        # If the import is already there, don't add it again
        if self.has_import:
            return updated_node

        # Choose import based on mode
        decorator_name = (
            "codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
        )

        # Parse the import statement into a CST node
        import_node = cst.parse_statement(f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}")

        # Add the import to the module's body
        return updated_node.with_changes(body=[import_node, *list(updated_node.body)])


def add_async_decorator_to_function(
    source_code: str, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
) -> tuple[str, bool]:
    """Add async decorator to an async function definition.

    Args:
    ----
        source_code: The source code to modify.
        function: The FunctionToOptimize object representing the target async function.
        mode: The testing mode to determine which decorator to apply.

    Returns:
    -------
        Tuple of (modified_source_code, was_decorator_added).

    """
    if not function.is_async:
        return source_code, False

    try:
        module = cst.parse_module(source_code)

        # Add the decorator to the function
        decorator_transformer = AsyncDecoratorAdder(function, mode)
        module = module.visit(decorator_transformer)

        # Add the import if decorator was added
        if decorator_transformer.added_decorator:
            import_transformer = AsyncDecoratorImportAdder(mode)
            module = module.visit(import_transformer)

        return isort.code(module.code, float_to_top=True), decorator_transformer.added_decorator
    except Exception as e:
        logger.exception(f"Error adding async decorator to function {function.qualified_name}: {e}")
        return source_code, False


def create_instrumented_source_module_path(source_path: Path, temp_dir: Path) -> Path:
    instrumented_filename = f"instrumented_{source_path.name}"
    return temp_dir / instrumented_filename
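A rough usage sketch for add_async_decorator_to_function (the module paths in the imports and the FunctionToOptimize constructor arguments are assumptions based on other parts of this diff, not verified against the repository layout):

from pathlib import Path

from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function  # assumed module path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize  # assumed module path
from codeflash.models.models import TestingMode  # assumed module path

source = "async def fetch(url):\n    return url\n"
fto = FunctionToOptimize(function_name="fetch", file_path=Path("app.py"), parents=[], is_async=True)
new_source, added = add_async_decorator_to_function(source, fto, TestingMode.BEHAVIOR)
print(added)       # True if the decorator was inserted
print(new_source)  # source with @codeflash_behavior_async and its import added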
@@ -128,13 +128,19 @@ def get_first_top_level_object_def_ast(

def get_first_top_level_function_or_method_ast(
    function_name: str, parents: list[FunctionParent], node: ast.AST
-) -> ast.FunctionDef | None:
+) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
    if not parents:
-       return get_first_top_level_object_def_ast(function_name, ast.FunctionDef, node)
+       result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, node)
+       if result is not None:
+           return result
+       return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, node)
    if parents[0].type == "ClassDef" and (
        class_node := get_first_top_level_object_def_ast(parents[0].name, ast.ClassDef, node)
    ):
-       return get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node)
+       result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node)
+       if result is not None:
+           return result
+       return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, class_node)
    return None
@@ -86,6 +86,7 @@ class FunctionVisitor(cst.CSTVisitor):
                    parents=list(reversed(ast_parents)),
                    starting_line=pos.start.line,
                    ending_line=pos.end.line,
                    is_async=bool(node.asynchronous),
                )
            )
@@ -103,6 +104,15 @@ class FunctionWithReturnStatement(ast.NodeVisitor):
                FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
            )

    def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
        # Check if the async function has a return statement and add it to the list
        if function_has_return_statement(node) and not function_is_a_property(node):
            self.functions.append(
                FunctionToOptimize(
                    function_name=node.name, file_path=self.file_path, parents=self.ast_path[:], is_async=True
                )
            )

    def generic_visit(self, node: ast.AST) -> None:
        if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):
            self.ast_path.append(FunctionParent(node.name, node.__class__.__name__))
@@ -122,6 +132,7 @@ class FunctionToOptimize:
        parents: A list of parent scopes, which could be classes or functions.
        starting_line: The starting line number of the function in the file.
        ending_line: The ending line number of the function in the file.
        is_async: Whether this function is defined as async.

    The qualified_name property provides the full name of the function, including
    any parent class or function names. The qualified_name_with_modules_from_root
@@ -134,6 +145,7 @@ class FunctionToOptimize:
    parents: list[FunctionParent]  # list[ClassDef | FunctionDef | AsyncFunctionDef]
    starting_line: Optional[int] = None
    ending_line: Optional[int] = None
    is_async: bool = False

    @property
    def top_level_parent_name(self) -> str:
@@ -147,7 +159,11 @@ class FunctionToOptimize:

    @property
    def qualified_name(self) -> str:
-       return self.function_name if self.parents == [] else f"{self.parents[0].name}.{self.function_name}"
+       if not self.parents:
+           return self.function_name
+       # Join all parent names with dots to handle nested classes properly
+       parent_path = ".".join(parent.name for parent in self.parents)
+       return f"{parent_path}.{self.function_name}"

    def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
        return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
@@ -163,6 +179,8 @@ def get_functions_to_optimize(
    project_root: Path,
    module_root: Path,
    previous_checkpoint_functions: dict[str, dict[str, str]] | None = None,
    *,
    enable_async: bool = False,
) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
    assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, (
        "Only one of optimize_all, replay_test, or file should be provided"
@@ -216,7 +234,13 @@ def get_functions_to_optimize(
            ph("cli-optimizing-git-diff")
            functions = get_functions_within_git_diff(uncommitted_changes=False)
        filtered_modified_functions, functions_count = filter_functions(
-           functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions
+           functions,
+           test_cfg.tests_root,
+           ignore_paths,
+           project_root,
+           module_root,
+           previous_checkpoint_functions,
+           enable_async=enable_async,
        )

    logger.info(f"!lsp|Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
@@ -411,11 +435,27 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
                )
            )

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        if self.class_name is None and node.name == self.function_name:
            self.is_top_level = True
            self.function_has_args = any(
                (
                    bool(node.args.args),
                    bool(node.args.kwonlyargs),
                    bool(node.args.kwarg),
                    bool(node.args.posonlyargs),
                    bool(node.args.vararg),
                )
            )

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        # iterate over the class methods
        if node.name == self.class_name:
            for body_node in node.body:
-               if isinstance(body_node, ast.FunctionDef) and body_node.name == self.function_name:
+               if (
+                   isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
+                   and body_node.name == self.function_name
+               ):
                    self.is_top_level = True
                    if any(
                        isinstance(decorator, ast.Name) and decorator.id == "classmethod"
@@ -433,7 +473,7 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
            # This way, if we don't have the class name, we can still find the static method
            for body_node in node.body:
                if (
-                   isinstance(body_node, ast.FunctionDef)
+                   isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
                    and body_node.name == self.function_name
                    and body_node.lineno in {self.line_no, self.line_no + 1}
                    and any(
@@ -535,7 +575,9 @@ def filter_functions(
    project_root: Path,
    module_root: Path,
    previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None,
-   disable_logs: bool = False,  # noqa: FBT001, FBT002
+   *,
+   disable_logs: bool = False,
+   enable_async: bool = False,
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
    filtered_modified_functions: dict[str, list[FunctionToOptimize]] = {}
    blocklist_funcs = get_blocklisted_functions()
@@ -555,6 +597,7 @@ def filter_functions(
    submodule_ignored_paths_count: int = 0
    blocklist_funcs_removed_count: int = 0
    previous_checkpoint_functions_removed_count: int = 0
    async_functions_removed_count: int = 0
    tests_root_str = str(tests_root)
    module_root_str = str(module_root)
@@ -610,6 +653,15 @@ def filter_functions(
                functions_tmp.append(function)
            _functions = functions_tmp

        if not enable_async:
            functions_tmp = []
            for function in _functions:
                if function.is_async:
                    async_functions_removed_count += 1
                    continue
                functions_tmp.append(function)
            _functions = functions_tmp

        filtered_modified_functions[file_path] = _functions
        functions_count += len(_functions)
@@ -623,6 +675,7 @@ def filter_functions(
            "Files from ignored submodules": (submodule_ignored_paths_count, "bright_black"),
            "Blocklisted functions removed": (blocklist_funcs_removed_count, "bright_red"),
            "Functions skipped from checkpoint": (previous_checkpoint_functions_removed_count, "green"),
            "Async functions removed": (async_functions_removed_count, "bright_magenta"),
        }
        tree = Tree(Text("Ignored functions and files", style="bold"))
        for label, (count, color) in log_info.items():
@@ -103,6 +103,7 @@ class BestOptimization(BaseModel):
    winning_benchmarking_test_results: TestResults
    winning_replay_benchmarking_test_results: Optional[TestResults] = None
    line_profiler_test_results: dict
    async_throughput: Optional[int] = None


@dataclass(frozen=True)
@@ -277,6 +278,7 @@ class OptimizedCandidateResult(BaseModel):
    replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None
    optimization_candidate_index: int
    total_candidate_timing: int
    async_throughput: Optional[int] = None


class GeneratedTests(BaseModel):
@@ -383,6 +385,7 @@ class OriginalCodeBaseline(BaseModel):
    line_profile_results: dict
    runtime: int
    coverage_results: Optional[CoverageData]
    async_throughput: Optional[int] = None


class CoverageStatus(Enum):
@@ -545,6 +548,7 @@ class TestResults(BaseModel):  # noqa: PLW1641
    # also we don't support deletion of test results elements - caution is advised
    test_results: list[FunctionTestInvocation] = []
    test_result_idx: dict[str, int] = {}
    perf_stdout: Optional[str] = None

    def add(self, function_test_invocation: FunctionTestInvocation) -> None:
        unique_id = function_test_invocation.unique_invocation_loop_id
@@ -36,7 +36,6 @@ from codeflash.code_utils.code_utils import (
    diff_length,
    file_name_from_test_module_name,
    get_run_tmp_file,
-   has_any_async_functions,
    module_name_from_file_path,
    restore_conftest,
    unified_diff_strings,
@@ -85,14 +84,20 @@ from codeflash.models.models import (
    TestType,
)
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
-from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
+from codeflash.result.critic import (
+   coverage_critic,
+   performance_gain,
+   quantity_of_tests_critic,
+   speedup_critic,
+   throughput_gain,
+)
from codeflash.result.explanation import Explanation
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.concolic_testing import generate_concolic_tests
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
-from codeflash.verification.parse_test_output import parse_test_results
+from codeflash.verification.parse_test_output import calculate_function_throughput_from_test_results, parse_test_results
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_line_profile_tests
from codeflash.verification.verification_utils import get_test_file_path
from codeflash.verification.verifier import generate_tests
@@ -199,7 +204,7 @@ class FunctionOptimizer:
        test_cfg: TestConfig,
        function_to_optimize_source_code: str = "",
        function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
-       function_to_optimize_ast: ast.FunctionDef | None = None,
+       function_to_optimize_ast: ast.FunctionDef | ast.AsyncFunctionDef | None = None,
        aiservice_client: AiServiceClient | None = None,
        function_benchmark_timings: dict[BenchmarkKey, int] | None = None,
        total_benchmark_timings: dict[BenchmarkKey, int] | None = None,
@ -259,11 +264,6 @@ class FunctionOptimizer:
|
|||
helper_code = f.read()
|
||||
original_helper_code[helper_function_path] = helper_code
|
||||
|
||||
async_code = any(
|
||||
has_any_async_functions(code_string.code) for code_string in code_context.read_writable_code.code_strings
|
||||
)
|
||||
if async_code:
|
||||
return Failure("Codeflash does not support async functions in the code to optimize.")
|
||||
# Random here means that we still attempt optimization with a fractional chance to see if
|
||||
# last time we could not find an optimization, maybe this time we do.
|
||||
# Random is before as a performance optimization, swapping the two 'and' statements has the same effect
|
||||
|
|
@ -588,7 +588,11 @@ class FunctionOptimizer:
|
|||
tree = Tree(f"Candidate #{candidate_index} - Runtime Information ⌛")
|
||||
benchmark_tree = None
|
||||
if speedup_critic(
|
||||
candidate_result, original_code_baseline.runtime, best_runtime_until_now=None
|
||||
candidate_result,
|
||||
original_code_baseline.runtime,
|
||||
best_runtime_until_now=None,
|
||||
original_async_throughput=original_code_baseline.async_throughput,
|
||||
best_throughput_until_now=None,
|
||||
) and quantity_of_tests_critic(candidate_result):
|
||||
tree.add("This candidate is faster than the original code. 🚀") # TODO: Change this description
|
||||
tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")
|
||||
|
|
@ -599,6 +603,17 @@ class FunctionOptimizer:
|
|||
)
|
||||
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
|
||||
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
|
||||
if (
|
||||
original_code_baseline.async_throughput is not None
|
||||
and candidate_result.async_throughput is not None
|
||||
):
|
||||
throughput_gain_value = throughput_gain(
|
||||
original_throughput=original_code_baseline.async_throughput,
|
||||
optimized_throughput=candidate_result.async_throughput,
|
||||
)
|
||||
tree.add(f"Original async throughput: {original_code_baseline.async_throughput} executions")
|
||||
tree.add(f"Optimized async throughput: {candidate_result.async_throughput} executions")
|
||||
tree.add(f"Throughput improvement: {throughput_gain_value * 100:.1f}%")
|
||||
line_profile_test_results = self.line_profiler_step(
|
||||
code_context=code_context,
|
||||
original_helper_code=original_helper_code,
|
||||
|
|
@ -634,6 +649,7 @@ class FunctionOptimizer:
|
|||
replay_performance_gain=replay_perf_gain if self.args.benchmark else None,
|
||||
winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
|
||||
winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results,
|
||||
async_throughput=candidate_result.async_throughput,
|
||||
)
|
||||
valid_optimizations.append(best_optimization)
|
||||
# queue corresponding refined optimization for best optimization
|
||||
|
|
@ -658,6 +674,15 @@ class FunctionOptimizer:
|
|||
)
|
||||
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
|
||||
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
|
||||
if (
|
||||
original_code_baseline.async_throughput is not None
|
||||
and candidate_result.async_throughput is not None
|
||||
):
|
||||
throughput_gain_value = throughput_gain(
|
||||
original_throughput=original_code_baseline.async_throughput,
|
||||
optimized_throughput=candidate_result.async_throughput,
|
||||
)
|
||||
tree.add(f"Throughput gain: {throughput_gain_value * 100:.1f}%")
|
||||
|
||||
if is_LSP_enabled():
|
||||
lsp_log(LspMarkdownMessage(markdown=tree_to_markdown(tree)))
|
||||
|
|
@ -701,6 +726,7 @@ class FunctionOptimizer:
|
|||
replay_performance_gain=valid_opt.replay_performance_gain,
|
||||
winning_benchmarking_test_results=valid_opt.winning_benchmarking_test_results,
|
||||
winning_replay_benchmarking_test_results=valid_opt.winning_replay_benchmarking_test_results,
|
||||
async_throughput=valid_opt.async_throughput,
|
||||
)
|
||||
valid_candidates_with_shorter_code.append(new_best_opt)
|
||||
diff_lens_list.append(
|
||||
|
|
@ -1080,6 +1106,7 @@ class FunctionOptimizer:
|
|||
self.function_trace_id[:-4] + "EXP0" if run_experiment else self.function_trace_id,
|
||||
n_candidates,
|
||||
ExperimentMetadata(id=self.experiment_id, group="control") if run_experiment else None,
|
||||
is_async=self.function_to_optimize.is_async,
|
||||
)
|
||||
future_candidates_exp = None
|
||||
|
||||
|
|
@ -1095,6 +1122,7 @@ class FunctionOptimizer:
|
|||
self.function_trace_id[:-4] + "EXP1",
|
||||
n_candidates,
|
||||
ExperimentMetadata(id=self.experiment_id, group="experiment"),
|
||||
is_async=self.function_to_optimize.is_async,
|
||||
)
|
||||
futures.append(future_candidates_exp)
|
||||
|
||||
|
|
@ -1281,6 +1309,8 @@ class FunctionOptimizer:
|
|||
function_name=function_to_optimize_qualified_name,
|
||||
file_path=self.function_to_optimize.file_path,
|
||||
benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None,
|
||||
original_async_throughput=original_code_baseline.async_throughput,
|
||||
best_async_throughput=best_optimization.async_throughput,
|
||||
)
|
||||
|
||||
self.replace_function_and_helpers_with_optimized_code(
|
||||
|
|
@ -1363,6 +1393,23 @@ class FunctionOptimizer:
|
|||
original_runtimes_all=original_runtime_by_test,
|
||||
optimized_runtimes_all=optimized_runtime_by_test,
|
||||
)
|
||||
original_throughput_str = None
|
||||
optimized_throughput_str = None
|
||||
throughput_improvement_str = None
|
||||
|
||||
if (
|
||||
self.function_to_optimize.is_async
|
||||
and original_code_baseline.async_throughput is not None
|
||||
and best_optimization.async_throughput is not None
|
||||
):
|
||||
original_throughput_str = f"{original_code_baseline.async_throughput} operations/second"
|
||||
optimized_throughput_str = f"{best_optimization.async_throughput} operations/second"
|
||||
throughput_improvement_value = throughput_gain(
|
||||
original_throughput=original_code_baseline.async_throughput,
|
||||
optimized_throughput=best_optimization.async_throughput,
|
||||
)
|
||||
throughput_improvement_str = f"{throughput_improvement_value * 100:.1f}%"
|
||||
|
||||
new_explanation_raw_str = self.aiservice_client.get_new_explanation(
|
||||
source_code=code_context.read_writable_code.flat,
|
||||
dependency_code=code_context.read_only_context_code,
|
||||
|
|
@ -1376,6 +1423,9 @@ class FunctionOptimizer:
|
|||
annotated_tests=generated_tests_str,
|
||||
optimization_id=best_optimization.candidate.optimization_id,
|
||||
original_explanation=best_optimization.candidate.explanation,
|
||||
original_throughput=original_throughput_str,
|
||||
optimized_throughput=optimized_throughput_str,
|
||||
throughput_improvement=throughput_improvement_str,
|
||||
)
|
||||
new_explanation = Explanation(
|
||||
raw_explanation_message=new_explanation_raw_str or explanation.raw_explanation_message,
|
||||
|
|
@ -1386,6 +1436,8 @@ class FunctionOptimizer:
|
|||
function_name=explanation.function_name,
|
||||
file_path=explanation.file_path,
|
||||
benchmark_details=explanation.benchmark_details,
|
||||
original_async_throughput=explanation.original_async_throughput,
|
||||
best_async_throughput=explanation.best_async_throughput,
|
||||
)
|
||||
self.log_successful_optimization(new_explanation, generated_tests, exp_type)
|
||||
|
||||
|
|
@ -1476,6 +1528,17 @@ class FunctionOptimizer:
|
|||
|
||||
test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)
|
||||
|
||||
if self.function_to_optimize.is_async:
|
||||
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
|
||||
|
||||
success, instrumented_source = instrument_source_module_with_async_decorators(
|
||||
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
|
||||
)
|
||||
if success and instrumented_source:
|
||||
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
|
||||
f.write(instrumented_source)
|
||||
logger.debug(f"Applied async instrumentation to {self.function_to_optimize.file_path}")
|
||||
|
||||
# Instrument codeflash capture
|
||||
with progress_bar("Running tests to establish original code behavior..."):
|
||||
try:
|
||||
|
|
@ -1515,15 +1578,38 @@ class FunctionOptimizer:
|
|||
)
|
||||
console.rule()
|
||||
with progress_bar("Running performance benchmarks..."):
|
||||
benchmarking_results, _ = self.run_and_parse_tests(
|
||||
testing_type=TestingMode.PERFORMANCE,
|
||||
test_env=test_env,
|
||||
test_files=self.test_files,
|
||||
optimization_iteration=0,
|
||||
testing_time=total_looping_time,
|
||||
enable_coverage=False,
|
||||
code_context=code_context,
|
||||
)
|
||||
if self.function_to_optimize.is_async:
|
||||
from codeflash.code_utils.instrument_existing_tests import (
|
||||
instrument_source_module_with_async_decorators,
|
||||
)
|
||||
|
||||
success, instrumented_source = instrument_source_module_with_async_decorators(
|
||||
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
|
||||
)
|
||||
if success and instrumented_source:
|
||||
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
|
||||
f.write(instrumented_source)
|
||||
logger.debug(
|
||||
f"Applied async performance instrumentation to {self.function_to_optimize.file_path}"
|
||||
)
|
||||
|
||||
try:
|
||||
benchmarking_results, _ = self.run_and_parse_tests(
|
||||
testing_type=TestingMode.PERFORMANCE,
|
||||
test_env=test_env,
|
||||
test_files=self.test_files,
|
||||
optimization_iteration=0,
|
||||
testing_time=total_looping_time,
|
||||
enable_coverage=False,
|
||||
code_context=code_context,
|
||||
)
|
||||
finally:
|
||||
if self.function_to_optimize.is_async:
|
||||
self.write_code_and_helpers(
|
||||
self.function_to_optimize_source_code,
|
||||
original_helper_code,
|
||||
self.function_to_optimize.file_path,
|
||||
)
|
||||
else:
|
||||
benchmarking_results = TestResults()
|
||||
start_time: float = time.time()
|
||||
|
|
@ -1577,6 +1663,14 @@ class FunctionOptimizer:
|
|||
console.rule()
|
||||
logger.debug(f"Total original code runtime (ns): {total_timing}")
|
||||
|
||||
async_throughput = None
|
||||
if self.function_to_optimize.is_async:
|
||||
async_throughput = calculate_function_throughput_from_test_results(
|
||||
benchmarking_results, self.function_to_optimize.function_name
|
||||
)
|
||||
logger.debug(f"Original async function throughput: {async_throughput} calls/second")
|
||||
console.rule()
|
||||
|
||||
if self.args.benchmark:
|
||||
replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
|
||||
self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
|
||||
|
|
@ -1590,6 +1684,7 @@ class FunctionOptimizer:
|
|||
runtime=total_timing,
|
||||
coverage_results=coverage_results,
|
||||
line_profile_results=line_profile_results,
|
||||
async_throughput=async_throughput,
|
||||
),
|
||||
functions_to_remove,
|
||||
)
|
||||
|
|
@ -1618,6 +1713,21 @@ class FunctionOptimizer:
|
|||
candidate_helper_code = {}
|
||||
for module_abspath in original_helper_code:
|
||||
candidate_helper_code[module_abspath] = Path(module_abspath).read_text("utf-8")
|
||||
if self.function_to_optimize.is_async:
|
||||
from codeflash.code_utils.instrument_existing_tests import (
|
||||
instrument_source_module_with_async_decorators,
|
||||
)
|
||||
|
||||
success, instrumented_source = instrument_source_module_with_async_decorators(
|
||||
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
|
||||
)
|
||||
if success and instrumented_source:
|
||||
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
|
||||
f.write(instrumented_source)
|
||||
logger.debug(
|
||||
f"Applied async behavioral instrumentation to {self.function_to_optimize.file_path} for candidate {optimization_candidate_index}"
|
||||
)
|
||||
|
||||
try:
|
||||
instrument_codeflash_capture(
|
||||
self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
|
||||
|
|
@ -1655,14 +1765,37 @@ class FunctionOptimizer:
|
|||
logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")
|
||||
|
||||
if test_framework == "pytest":
|
||||
candidate_benchmarking_results, _ = self.run_and_parse_tests(
|
||||
testing_type=TestingMode.PERFORMANCE,
|
||||
test_env=test_env,
|
||||
test_files=self.test_files,
|
||||
optimization_iteration=optimization_candidate_index,
|
||||
testing_time=total_looping_time,
|
||||
enable_coverage=False,
|
||||
)
|
||||
# For async functions, instrument at definition site for performance benchmarking
|
||||
if self.function_to_optimize.is_async:
|
||||
from codeflash.code_utils.instrument_existing_tests import (
|
||||
instrument_source_module_with_async_decorators,
|
||||
)
|
||||
|
||||
success, instrumented_source = instrument_source_module_with_async_decorators(
|
||||
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
|
||||
)
|
||||
if success and instrumented_source:
|
||||
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
|
||||
f.write(instrumented_source)
|
||||
logger.debug(
|
||||
f"Applied async performance instrumentation to {self.function_to_optimize.file_path} for candidate {optimization_candidate_index}"
|
||||
)
|
||||
|
||||
try:
|
||||
candidate_benchmarking_results, _ = self.run_and_parse_tests(
|
||||
testing_type=TestingMode.PERFORMANCE,
|
||||
test_env=test_env,
|
||||
test_files=self.test_files,
|
||||
optimization_iteration=optimization_candidate_index,
|
||||
testing_time=total_looping_time,
|
||||
enable_coverage=False,
|
||||
)
|
||||
finally:
|
||||
# Restore original source if we instrumented it
|
||||
if self.function_to_optimize.is_async:
|
||||
self.write_code_and_helpers(
|
||||
candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
|
||||
)
|
||||
loop_count = (
|
||||
max(all_loop_indices)
|
||||
if (
|
||||
|
|
@ -1698,6 +1831,14 @@ class FunctionOptimizer:
|
|||
console.rule()
|
||||
|
||||
logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
|
||||
|
||||
candidate_async_throughput = None
|
||||
if self.function_to_optimize.is_async:
|
||||
candidate_async_throughput = calculate_function_throughput_from_test_results(
|
||||
candidate_benchmarking_results, self.function_to_optimize.function_name
|
||||
)
|
||||
logger.debug(f"Candidate async function throughput: {candidate_async_throughput} calls/second")
|
||||
|
||||
if self.args.benchmark:
|
||||
candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(
|
||||
self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
|
||||
|
|
@ -1717,6 +1858,7 @@ class FunctionOptimizer:
|
|||
else None,
|
||||
optimization_candidate_index=optimization_candidate_index,
|
||||
total_candidate_timing=total_candidate_timing,
|
||||
async_throughput=candidate_async_throughput,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -1808,8 +1950,10 @@ class FunctionOptimizer:
|
|||
coverage_database_file=coverage_database_file,
|
||||
coverage_config_file=coverage_config_file,
|
||||
)
|
||||
else:
|
||||
results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
|
||||
if testing_type == TestingMode.PERFORMANCE:
|
||||
results.perf_stdout = run_result.stdout
|
||||
return results, coverage_results
|
||||
results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
|
||||
return results, coverage_results
|
||||
|
||||
def submit_test_generation_tasks(
|
||||
|
|
|
|||
|
|
@ -134,12 +134,13 @@ class Optimizer:
|
|||
project_root=self.args.project_root,
|
||||
module_root=self.args.module_root,
|
||||
previous_checkpoint_functions=self.args.previous_checkpoint_functions,
|
||||
enable_async=getattr(self.args, "async", False),
|
||||
)
|
||||
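Note that getattr is required here because `async` is a Python keyword, so `self.args.async` would be a syntax error; a small illustration with a throwaway parser (names are illustrative):
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--async", action="store_true")   # argparse stores this under the attribute name "async"
args = parser.parse_args(["--async"])
assert getattr(args, "async", False) is True          # args.async itself would not even parse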
|
||||
def create_function_optimizer(
|
||||
self,
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
function_to_optimize_ast: ast.FunctionDef | None = None,
|
||||
function_to_optimize_ast: ast.FunctionDef | ast.AsyncFunctionDef | None = None,
|
||||
function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
|
||||
function_to_optimize_source_code: str | None = "",
|
||||
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,
|
||||
|
|
|
|||
|
|
@ -8,8 +8,9 @@ from codeflash.code_utils.config_consts import (
|
|||
COVERAGE_THRESHOLD,
|
||||
MIN_IMPROVEMENT_THRESHOLD,
|
||||
MIN_TESTCASE_PASSED_THRESHOLD,
|
||||
MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD,
|
||||
)
|
||||
from codeflash.models.test_type import TestType
|
||||
from codeflash.models import models
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.models.models import CoverageData, OptimizedCandidateResult, OriginalCodeBaseline
|
||||
|
|
@ -25,20 +26,41 @@ def performance_gain(*, original_runtime_ns: int, optimized_runtime_ns: int) ->
|
|||
return (original_runtime_ns - optimized_runtime_ns) / optimized_runtime_ns
|
||||
|
||||
|
||||
def throughput_gain(*, original_throughput: int, optimized_throughput: int) -> float:
|
||||
"""Calculate the throughput gain of an optimized code over the original code.
|
||||
|
||||
This value multiplied by 100 gives the percentage improvement in throughput.
|
||||
For throughput, higher values are better (more executions per time period).
|
||||
"""
|
||||
if original_throughput == 0:
|
||||
return 0.0
|
||||
return (optimized_throughput - original_throughput) / original_throughput
|
||||
|
||||
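A quick sanity check of the formula above with hypothetical execution counts (the import path matches the one used by explanation.py later in this diff):
from codeflash.result.critic import throughput_gain

assert throughput_gain(original_throughput=100, optimized_throughput=150) == 0.5  # 50% more executions
assert throughput_gain(original_throughput=0, optimized_throughput=50) == 0.0     # guarded divide-by-zero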
|
||||
def speedup_critic(
|
||||
candidate_result: OptimizedCandidateResult,
|
||||
original_code_runtime: int,
|
||||
best_runtime_until_now: int | None,
|
||||
*,
|
||||
disable_gh_action_noise: bool = False,
|
||||
original_async_throughput: int | None = None,
|
||||
best_throughput_until_now: int | None = None,
|
||||
) -> bool:
|
||||
"""Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.
|
||||
|
||||
Ensure that the optimization is actually faster than the original code, above the noise floor.
|
||||
The noise floor is a function of the original code runtime. Currently, the noise floor is 2xMIN_IMPROVEMENT_THRESHOLD
|
||||
when the original runtime is less than 10 microseconds, and becomes MIN_IMPROVEMENT_THRESHOLD for any higher runtime.
|
||||
The noise floor is doubled when benchmarking on a (noisy) GitHub Action virtual instance, also we want to be more confident there.
|
||||
Evaluates both runtime performance and async throughput improvements.
|
||||
|
||||
For runtime performance:
|
||||
- Ensures the optimization is actually faster than the original code, above the noise floor.
|
||||
- The noise floor is a function of the original code runtime. Currently, the noise floor is 2xMIN_IMPROVEMENT_THRESHOLD
|
||||
when the original runtime is less than 10 microseconds, and becomes MIN_IMPROVEMENT_THRESHOLD for any higher runtime.
|
||||
- The noise floor is doubled when benchmarking on a (noisy) GitHub Action virtual instance.
|
||||
|
||||
For async throughput (when available):
|
||||
- Evaluates throughput improvements using MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD
|
||||
- Throughput improvements complement runtime improvements for async functions
|
||||
"""
|
||||
# Runtime performance evaluation
|
||||
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
|
||||
if not disable_gh_action_noise and env_utils.is_ci():
|
||||
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode
|
||||
|
|
@ -46,10 +68,31 @@ def speedup_critic(
|
|||
perf_gain = performance_gain(
|
||||
original_runtime_ns=original_code_runtime, optimized_runtime_ns=candidate_result.best_test_runtime
|
||||
)
|
||||
if best_runtime_until_now is None:
|
||||
# collect all optimizations with this
|
||||
return bool(perf_gain > noise_floor)
|
||||
return bool(perf_gain > noise_floor and candidate_result.best_test_runtime < best_runtime_until_now)
|
||||
runtime_improved = perf_gain > noise_floor
|
||||
|
||||
# Check runtime comparison with best so far
|
||||
runtime_is_best = best_runtime_until_now is None or candidate_result.best_test_runtime < best_runtime_until_now
|
||||
|
||||
throughput_improved = True # Default to True if no throughput data
|
||||
throughput_is_best = True # Default to True if no throughput data
|
||||
|
||||
if original_async_throughput is not None and candidate_result.async_throughput is not None:
|
||||
if original_async_throughput > 0:
|
||||
throughput_gain_value = throughput_gain(
|
||||
original_throughput=original_async_throughput, optimized_throughput=candidate_result.async_throughput
|
||||
)
|
||||
throughput_improved = throughput_gain_value > MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD
|
||||
|
||||
throughput_is_best = (
|
||||
best_throughput_until_now is None or candidate_result.async_throughput > best_throughput_until_now
|
||||
)
|
||||
|
||||
if original_async_throughput is not None and candidate_result.async_throughput is not None:
|
||||
# When throughput data is available, accept if EITHER throughput OR runtime improves significantly
|
||||
throughput_acceptance = throughput_improved and throughput_is_best
|
||||
runtime_acceptance = runtime_improved and runtime_is_best
|
||||
return throughput_acceptance or runtime_acceptance
|
||||
return runtime_improved and runtime_is_best
|
||||
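A worked example of the acceptance rule above, using hypothetical numbers and illustrative thresholds rather than the real config_consts values:
MIN_IMPROVEMENT = 0.05             # assumed runtime noise floor
MIN_THROUGHPUT_IMPROVEMENT = 0.10  # assumed throughput threshold

runtime_improved = 0.12 > MIN_IMPROVEMENT                  # candidate is 12% faster
runtime_is_best = True                                     # and the fastest seen so far
throughput_improved = 0.08 > MIN_THROUGHPUT_IMPROVEMENT    # only 8% more executions
throughput_is_best = True

accepted = (throughput_improved and throughput_is_best) or (runtime_improved and runtime_is_best)
assert accepted  # the runtime path alone is enough when throughput falls short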
|
||||
|
||||
def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | OriginalCodeBaseline) -> bool:
|
||||
|
|
@ -63,7 +106,7 @@ def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | Origin
|
|||
if pass_count >= MIN_TESTCASE_PASSED_THRESHOLD:
|
||||
return True
|
||||
    # If one or more tests passed, check if at least one of them was a successful REPLAY_TEST
|
||||
return bool(pass_count >= 1 and report[TestType.REPLAY_TEST]["passed"] >= 1)
|
||||
return bool(pass_count >= 1 and report[models.TestType.REPLAY_TEST]["passed"] >= 1) # type: ignore # noqa: PGH003
|
||||
|
||||
|
||||
def coverage_critic(original_code_coverage: CoverageData | None, test_framework: str) -> bool:
|
||||
@ -12,6 +12,7 @@ from rich.table import Table
|
|||
from codeflash.code_utils.time_utils import humanize_runtime
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.models.models import BenchmarkDetail, TestResults
|
||||
from codeflash.result.critic import throughput_gain
|
||||
|
||||
|
||||
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
|
||||
|
|
@ -24,9 +25,28 @@ class Explanation:
|
|||
function_name: str
|
||||
file_path: Path
|
||||
benchmark_details: Optional[list[BenchmarkDetail]] = None
|
||||
original_async_throughput: Optional[int] = None
|
||||
best_async_throughput: Optional[int] = None
|
||||
|
||||
@property
|
||||
def perf_improvement_line(self) -> str:
|
||||
runtime_improvement = self.speedup
|
||||
|
||||
if (
|
||||
self.original_async_throughput is not None
|
||||
and self.best_async_throughput is not None
|
||||
and self.original_async_throughput > 0
|
||||
):
|
||||
throughput_improvement = throughput_gain(
|
||||
original_throughput=self.original_async_throughput, optimized_throughput=self.best_async_throughput
|
||||
)
|
||||
|
||||
# Use throughput metrics if throughput improvement is better or runtime got worse
|
||||
if throughput_improvement > runtime_improvement or runtime_improvement <= 0:
|
||||
throughput_pct = f"{throughput_improvement * 100:,.0f}%"
|
||||
throughput_x = f"{throughput_improvement + 1:,.2f}x"
|
||||
return f"{throughput_pct} improvement ({throughput_x} faster)."
|
||||
|
||||
return f"{self.speedup_pct} improvement ({self.speedup_x} faster)."
|
||||
|
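A minimal sketch of the throughput branch's formatting, with hypothetical throughput numbers:
throughput_improvement = (250 - 100) / 100  # original 100 ops/s -> optimized 250 ops/s
line = f"{throughput_improvement * 100:,.0f}% improvement ({throughput_improvement + 1:,.2f}x faster)."
assert line == "150% improvement (2.50x faster)."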
||||
@property
|
||||
|
|
@ -46,6 +66,23 @@ class Explanation:
|
|||
# TODO: Sometimes the explanation says something similar to "This is the code that was optimized", remove such parts
|
||||
original_runtime_human = humanize_runtime(self.original_runtime_ns)
|
||||
best_runtime_human = humanize_runtime(self.best_runtime_ns)
|
||||
|
||||
# Determine if we're showing throughput or runtime improvements
|
||||
runtime_improvement = self.speedup
|
||||
is_using_throughput_metric = False
|
||||
|
||||
if (
|
||||
self.original_async_throughput is not None
|
||||
and self.best_async_throughput is not None
|
||||
and self.original_async_throughput > 0
|
||||
):
|
||||
throughput_improvement = throughput_gain(
|
||||
original_throughput=self.original_async_throughput, optimized_throughput=self.best_async_throughput
|
||||
)
|
||||
|
||||
if throughput_improvement > runtime_improvement or runtime_improvement <= 0:
|
||||
is_using_throughput_metric = True
|
||||
|
||||
benchmark_info = ""
|
||||
|
||||
if self.benchmark_details:
|
||||
|
|
@ -86,13 +123,18 @@ class Explanation:
|
|||
console.print(table)
|
||||
benchmark_info = cast("StringIO", console.file).getvalue() + "\n" # Cast for mypy
|
||||
|
||||
test_report = self.winning_behavior_test_results.get_test_pass_fail_report_by_type()
|
||||
test_report_str = TestResults.report_to_string(test_report)
|
||||
if is_using_throughput_metric:
|
||||
performance_description = (
|
||||
f"Throughput improved from {self.original_async_throughput} to {self.best_async_throughput} operations/second "
|
||||
f"(runtime: {original_runtime_human} → {best_runtime_human})\n\n"
|
||||
)
|
||||
else:
|
||||
performance_description = f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n"
|
||||
|
||||
return (
|
||||
f"Optimized {self.function_name} in {self.file_path}\n"
|
||||
f"{self.perf_improvement_line}\n"
|
||||
f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n"
|
||||
+ performance_description
|
||||
+ (benchmark_info if benchmark_info else "")
|
||||
+ self.raw_explanation_message
|
||||
+ " \n\n"
|
||||
|
|
@ -101,7 +143,7 @@ class Explanation:
|
|||
""
|
||||
if is_LSP_enabled()
|
||||
else "The new optimized code was tested for correctness. The results are listed below.\n"
|
||||
+ test_report_str
|
||||
f"{TestResults.report_to_string(self.winning_behavior_test_results.get_test_pass_fail_report_by_type())}\n"
|
||||
)
|
||||
)
|
||||
|
||||
@ -38,10 +38,10 @@ class CoverageUtils:
|
|||
|
||||
cov = Coverage(data_file=database_path, config_file=config_path, data_suffix=True, auto_data=True, branch=True)
|
||||
|
||||
if not database_path.stat().st_size or not database_path.exists():
|
||||
if not database_path.exists() or not database_path.stat().st_size:
|
||||
logger.debug(f"Coverage database {database_path} is empty or does not exist")
|
||||
sentry_sdk.capture_message(f"Coverage database {database_path} is empty or does not exist")
|
||||
return CoverageUtils.create_empty(source_code_path, function_name, code_context)
|
||||
return CoverageData.create_empty(source_code_path, function_name, code_context)
|
||||
cov.load()
|
||||
|
||||
reporter = JsonReporter(cov)
|
||||
|
|
@ -51,8 +51,8 @@ class CoverageUtils:
|
|||
reporter.report(morfs=[source_code_path.as_posix()], outfile=f)
|
||||
except NoDataError:
|
||||
sentry_sdk.capture_message(f"No coverage data found for {function_name} in {source_code_path}")
|
||||
return CoverageUtils.create_empty(source_code_path, function_name, code_context)
|
||||
with temp_json_file.open(encoding="utf-8") as f:
|
||||
return CoverageData.create_empty(source_code_path, function_name, code_context)
|
||||
with temp_json_file.open() as f:
|
||||
original_coverage_data = json.load(f)
|
||||
|
||||
coverage_data, status = CoverageUtils._parse_coverage_file(temp_json_file, source_code_path)
|
||||
@ -40,6 +40,30 @@ matches_re_start = re.compile(r"!\$######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)
|
|||
matches_re_end = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!")
|
||||
|
||||
|
||||
start_pattern = re.compile(r"!\$######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+)######\$!")
|
||||
end_pattern = re.compile(r"!######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+):([^:]+)######!")
|
||||
|
||||
|
||||
def calculate_function_throughput_from_test_results(test_results: TestResults, function_name: str) -> int:
|
||||
"""Calculate function throughput from TestResults by extracting performance stdout.
|
||||
|
||||
A completed execution is defined as having both a start tag and matching end tag from performance wrappers.
|
||||
Start: !$######test_module:test_function:function_name:loop_index:iteration_id######$!
|
||||
End: !######test_module:test_function:function_name:loop_index:iteration_id:duration######!
|
||||
"""
|
||||
start_matches = start_pattern.findall(test_results.perf_stdout or "")
|
||||
end_matches = end_pattern.findall(test_results.perf_stdout or "")
|
||||
|
||||
end_matches_truncated = [end_match[:5] for end_match in end_matches]
|
||||
end_matches_set = set(end_matches_truncated)
|
||||
|
||||
function_throughput = 0
|
||||
for start_match in start_matches:
|
||||
if start_match in end_matches_set and len(start_match) > 2 and start_match[2] == function_name:
|
||||
function_throughput += 1
|
||||
return function_throughput
|
||||
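An illustrative, self-contained example of the counting rule above (the stdout is made up; the import paths mirror ones used earlier in this diff):
from codeflash.models.models import TestResults
from codeflash.verification.parse_test_output import calculate_function_throughput_from_test_results

results = TestResults(
    perf_stdout=(
        "!$######tests.test_mod:test_fetch:fetch_data:1:0######$!\n"
        "!######tests.test_mod:test_fetch:fetch_data:1:0:1200######!\n"
        "!$######tests.test_mod:test_fetch:fetch_data:1:1######$!\n"  # started but never finished
    )
)
# Only the invocation with both a start tag and a matching end tag is counted.
assert calculate_function_throughput_from_test_results(results, "fetch_data") == 1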
|
||||
|
||||
def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
|
||||
test_results = TestResults()
|
||||
if not file_location.exists():
|
||||
@ -450,3 +450,26 @@ class PytestLoops:
|
|||
metafunc.parametrize(
|
||||
"__pytest_loop_step_number", range(count), indirect=True, ids=make_progress_id, scope=scope
|
||||
)
|
||||
|
||||
@pytest.hookimpl(tryfirst=True)
|
||||
def pytest_runtest_setup(self, item: pytest.Item) -> None:
|
||||
"""Set test context environment variables before each test."""
|
||||
test_module_name = item.module.__name__ if item.module else "unknown_module"
|
||||
|
||||
test_class_name = None
|
||||
if item.cls:
|
||||
test_class_name = item.cls.__name__
|
||||
|
||||
test_function_name = item.name
|
||||
if "[" in test_function_name:
|
||||
test_function_name = test_function_name.split("[", 1)[0]
|
||||
|
||||
os.environ["CODEFLASH_TEST_MODULE"] = test_module_name
|
||||
os.environ["CODEFLASH_TEST_CLASS"] = test_class_name or ""
|
||||
os.environ["CODEFLASH_TEST_FUNCTION"] = test_function_name
|
||||
|
||||
@pytest.hookimpl(trylast=True)
|
||||
def pytest_runtest_teardown(self, item: pytest.Item) -> None: # noqa: ARG002
|
||||
"""Clean up test context environment variables after each test."""
|
||||
for var in ["CODEFLASH_TEST_MODULE", "CODEFLASH_TEST_CLASS", "CODEFLASH_TEST_FUNCTION"]:
|
||||
os.environ.pop(var, None)
|
||||
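A hedged sketch of how downstream instrumentation might read the context set by these hooks (the environment variable names are the real ones above; the helper itself is illustrative):
import os
from typing import Optional, Tuple

def current_test_context() -> Tuple[str, Optional[str], str]:
    module = os.environ.get("CODEFLASH_TEST_MODULE", "unknown_module")
    test_class = os.environ.get("CODEFLASH_TEST_CLASS") or None
    function = os.environ.get("CODEFLASH_TEST_FUNCTION", "unknown_test")
    return module, test_class, function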
@ -1,2 +1,2 @@
|
|||
# These version placeholders will be replaced by uv-dynamic-versioning during build.
|
||||
__version__ = "0.16.7.post46.dev0+444ff121"
|
||||
__version__ = "0.16.7.post77.dev0+02f96e77"
|
||||
@ -52,8 +52,14 @@ Homepage = "https://codeflash.ai"
|
|||
[project.scripts]
|
||||
codeflash = "codeflash.main:main"
|
||||
|
||||
[project.optional-dependencies]
|
||||
asyncio = [
|
||||
"pytest-asyncio>=1.2.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
{include-group = "asyncio"},
|
||||
"ipython>=8.12.0",
|
||||
"mypy>=1.13",
|
||||
"ruff>=0.7.0",
|
||||
|
|
@ -76,6 +82,9 @@ dev = [
|
|||
"uv>=0.6.2",
|
||||
"pre-commit>=4.2.0,<5",
|
||||
]
|
||||
asyncio = [
|
||||
"pytest-asyncio>=1.2.0",
|
||||
]
|
||||
|
||||
[tool.hatch.build.targets.sdist]
|
||||
include = ["codeflash"]
|
||||
|
|
|
|||
28
tests/scripts/end_to_end_test_async.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
import os
|
||||
import pathlib
|
||||
|
||||
from end_to_end_test_utilities import CoverageExpectation, TestConfig, run_codeflash_command, run_with_retries
|
||||
|
||||
|
||||
def run_test(expected_improvement_pct: int) -> bool:
|
||||
config = TestConfig(
|
||||
file_path="main.py",
|
||||
expected_unit_tests=0,
|
||||
min_improvement_x=0.1,
|
||||
enable_async=True,
|
||||
coverage_expectations=[
|
||||
CoverageExpectation(
|
||||
function_name="retry_with_backoff",
|
||||
expected_coverage=100.0,
|
||||
expected_lines=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
|
||||
)
|
||||
],
|
||||
)
|
||||
cwd = (
|
||||
pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "code_directories" / "async_e2e"
|
||||
).resolve()
|
||||
return run_codeflash_command(cwd, config, expected_improvement_pct)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10))))
|
||||
|
|
@ -11,7 +11,7 @@ def run_test(expected_improvement_pct: int) -> bool:
|
|||
function_name="sorter",
|
||||
benchmarks_root=cwd / "tests" / "pytest" / "benchmarks",
|
||||
test_framework="pytest",
|
||||
min_improvement_x=1.0,
|
||||
min_improvement_x=0.70,
|
||||
coverage_expectations=[
|
||||
CoverageExpectation(
|
||||
function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
@ -9,7 +9,7 @@ def run_test(expected_improvement_pct: int) -> bool:
|
|||
file_path="bubble_sort.py",
|
||||
function_name="sorter",
|
||||
test_framework="pytest",
|
||||
min_improvement_x=1.0,
|
||||
min_improvement_x=0.70,
|
||||
coverage_expectations=[
|
||||
CoverageExpectation(
|
||||
function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
@ -6,7 +6,7 @@ from end_to_end_test_utilities import TestConfig, run_codeflash_command, run_wit
|
|||
|
||||
def run_test(expected_improvement_pct: int) -> bool:
|
||||
config = TestConfig(
|
||||
file_path="bubble_sort.py", function_name="sorter", test_framework="unittest", min_improvement_x=3.0
|
||||
file_path="bubble_sort.py", function_name="sorter", test_framework="unittest", min_improvement_x=0.40
|
||||
)
|
||||
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
|
||||
return run_codeflash_command(cwd, config, expected_improvement_pct)
|
||||
@ -27,6 +27,7 @@ class TestConfig:
|
|||
trace_mode: bool = False
|
||||
coverage_expectations: list[CoverageExpectation] = field(default_factory=list)
|
||||
benchmarks_root: Optional[pathlib.Path] = None
|
||||
enable_async: bool = False
|
||||
|
||||
|
||||
def clear_directory(directory_path: str | pathlib.Path) -> None:
|
||||
|
|
@ -134,6 +135,8 @@ def build_command(
|
|||
)
|
||||
if benchmarks_root:
|
||||
base_command.extend(["--benchmark", "--benchmarks-root", str(benchmarks_root)])
|
||||
if config.enable_async:
|
||||
base_command.append("--async")
|
||||
return base_command
|
||||
|
||||
|
||||
@ -3,6 +3,10 @@ from pathlib import Path
|
|||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
|
||||
import tempfile
|
||||
from codeflash.code_utils.code_extractor import resolve_star_import, DottedImportCollector
|
||||
import libcst as cst
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
def test_add_needed_imports_from_module0() -> None:
|
||||
src_module = '''import ast
|
||||
|
|
@ -349,3 +353,141 @@ class DbtAdapter(BaseAdapter):
|
|||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
||||
|
||||
|
||||
def test_resolve_star_import_with_all_defined():
|
||||
"""Test resolve_star_import when __all__ is explicitly defined."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
test_module = project_root / 'test_module.py'
|
||||
|
||||
# Create a test module with __all__ definition
|
||||
test_module.write_text('''
|
||||
__all__ = ['public_function', 'PublicClass']
|
||||
|
||||
def public_function():
|
||||
pass
|
||||
|
||||
def _private_function():
|
||||
pass
|
||||
|
||||
class PublicClass:
|
||||
pass
|
||||
|
||||
class AnotherPublicClass:
|
||||
"""Not in __all__ so should be excluded."""
|
||||
pass
|
||||
''')
|
||||
|
||||
symbols = resolve_star_import('test_module', project_root)
|
||||
expected_symbols = {'public_function', 'PublicClass'}
|
||||
assert symbols == expected_symbols
|
||||
|
||||
|
||||
def test_resolve_star_import_without_all_defined():
|
||||
"""Test resolve_star_import when __all__ is not defined - should include all public symbols."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
test_module = project_root / 'test_module.py'
|
||||
|
||||
# Create a test module without __all__ definition
|
||||
test_module.write_text('''
|
||||
def public_func():
|
||||
pass
|
||||
|
||||
def _private_func():
|
||||
pass
|
||||
|
||||
class PublicClass:
|
||||
pass
|
||||
|
||||
PUBLIC_VAR = 42
|
||||
_private_var = 'secret'
|
||||
''')
|
||||
|
||||
symbols = resolve_star_import('test_module', project_root)
|
||||
expected_symbols = {'public_func', 'PublicClass', 'PUBLIC_VAR'}
|
||||
assert symbols == expected_symbols
|
||||
|
||||
|
||||
def test_resolve_star_import_nonexistent_module():
|
||||
"""Test resolve_star_import with non-existent module - should return empty set."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
|
||||
symbols = resolve_star_import('nonexistent_module', project_root)
|
||||
assert symbols == set()
|
||||
|
||||
|
||||
def test_dotted_import_collector_skips_star_imports():
|
||||
"""Test that DottedImportCollector correctly skips star imports."""
|
||||
code_with_star_import = '''
|
||||
from typing import *
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
import os
|
||||
'''
|
||||
|
||||
module = cst.parse_module(code_with_star_import)
|
||||
collector = DottedImportCollector()
|
||||
module.visit(collector)
|
||||
|
||||
# Should collect regular imports but skip the star import
|
||||
expected_imports = {'collections.defaultdict', 'os', 'pathlib.Path'}
|
||||
assert collector.imports == expected_imports
|
||||
|
||||
|
||||
def test_add_needed_imports_with_star_import_resolution():
|
||||
"""Test add_needed_imports_from_module correctly handles star imports by resolving them."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
|
||||
# Create a source module that exports symbols
|
||||
src_module = project_root / 'source_module.py'
|
||||
src_module.write_text('''
|
||||
__all__ = ['UtilFunction', 'HelperClass']
|
||||
|
||||
def UtilFunction():
|
||||
pass
|
||||
|
||||
class HelperClass:
|
||||
pass
|
||||
''')
|
||||
|
||||
# Create source code that uses star import
|
||||
src_code = '''
|
||||
from source_module import *
|
||||
|
||||
def my_function():
|
||||
helper = HelperClass()
|
||||
UtilFunction()
|
||||
return helper
|
||||
'''
|
||||
|
||||
# Destination code that needs the imports resolved
|
||||
dst_code = '''
|
||||
def my_function():
|
||||
helper = HelperClass()
|
||||
UtilFunction()
|
||||
return helper
|
||||
'''
|
||||
|
||||
src_path = project_root / 'src.py'
|
||||
dst_path = project_root / 'dst.py'
|
||||
src_path.write_text(src_code)
|
||||
|
||||
result = add_needed_imports_from_module(
|
||||
src_code, dst_code, src_path, dst_path, project_root
|
||||
)
|
||||
|
||||
# The result should have individual imports instead of star import
|
||||
expected_result = '''from source_module import HelperClass, UtilFunction
|
||||
|
||||
def my_function():
|
||||
helper = HelperClass()
|
||||
UtilFunction()
|
||||
return helper
|
||||
'''
|
||||
assert result == expected_result
|
||||
@ -1902,4 +1902,210 @@ def test_bubble_sort(input, expected_output):
|
|||
|
||||
# Check that comments were added
|
||||
modified_source = result.generated_tests[0].generated_original_test_source
|
||||
assert modified_source == expected
|
||||
assert modified_source == expected
|
||||
|
||||
def test_async_basic_runtime_comment_addition(self, test_config):
|
||||
"""Test basic functionality of adding runtime comments to async test functions."""
|
||||
os.chdir(test_config.project_root_path)
|
||||
test_source = """async def test_async_bubble_sort():
|
||||
codeflash_output = await async_bubble_sort([3, 1, 2])
|
||||
assert codeflash_output == [1, 2, 3]
|
||||
"""
|
||||
|
||||
generated_test = GeneratedTests(
|
||||
generated_original_test_source=test_source,
|
||||
instrumented_behavior_test_source="",
|
||||
instrumented_perf_test_source="",
|
||||
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
|
||||
perf_file_path=test_config.tests_root / "test_perf.py",
|
||||
)
|
||||
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
|
||||
|
||||
original_test_results = TestResults()
|
||||
optimized_test_results = TestResults()
|
||||
|
||||
original_invocation = self.create_test_invocation("test_async_bubble_sort", 500_000, iteration_id='0') # 500μs
|
||||
optimized_invocation = self.create_test_invocation("test_async_bubble_sort", 300_000, iteration_id='0') # 300μs
|
||||
|
||||
original_test_results.add(original_invocation)
|
||||
optimized_test_results.add(optimized_invocation)
|
||||
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
|
||||
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
|
||||
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
|
||||
|
||||
modified_source = result.generated_tests[0].generated_original_test_source
|
||||
assert "# 500μs -> 300μs" in modified_source
|
||||
assert "codeflash_output = await async_bubble_sort([3, 1, 2]) # 500μs -> 300μs" in modified_source
|
||||
|
||||
def test_async_multiple_test_functions(self, test_config):
|
||||
os.chdir(test_config.project_root_path)
|
||||
test_source = """async def test_async_bubble_sort():
|
||||
codeflash_output = await async_quick_sort([3, 1, 2])
|
||||
assert codeflash_output == [1, 2, 3]
|
||||
|
||||
async def test_async_quick_sort():
|
||||
codeflash_output = await async_quick_sort([5, 2, 8])
|
||||
assert codeflash_output == [2, 5, 8]
|
||||
|
||||
def helper_function():
|
||||
return "not a test"
|
||||
"""
|
||||
generated_test = GeneratedTests(
|
||||
generated_original_test_source=test_source,
|
||||
instrumented_behavior_test_source="",
|
||||
instrumented_perf_test_source="",
|
||||
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
|
||||
perf_file_path=test_config.tests_root / "test_perf.py"
|
||||
)
|
||||
|
||||
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
|
||||
|
||||
original_test_results = TestResults()
|
||||
optimized_test_results = TestResults()
|
||||
|
||||
original_test_results.add(self.create_test_invocation("test_async_bubble_sort", 500_000, iteration_id='0'))
|
||||
original_test_results.add(self.create_test_invocation("test_async_quick_sort", 800_000, iteration_id='0'))
|
||||
|
||||
optimized_test_results.add(self.create_test_invocation("test_async_bubble_sort", 300_000, iteration_id='0'))
|
||||
optimized_test_results.add(self.create_test_invocation("test_async_quick_sort", 600_000, iteration_id='0'))
|
||||
|
||||
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
|
||||
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
|
||||
|
||||
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
|
||||
|
||||
modified_source = result.generated_tests[0].generated_original_test_source
|
||||
|
||||
assert "# 500μs -> 300μs" in modified_source
|
||||
assert "# 800μs -> 600μs" in modified_source
|
||||
assert (
|
||||
"helper_function():" in modified_source
|
||||
and "# " not in modified_source.split("helper_function():")[1].split("\n")[0]
|
||||
)
|
||||
|
||||
def test_async_class_method(self, test_config):
|
||||
os.chdir(test_config.project_root_path)
|
||||
test_source = '''class TestAsyncClass:
|
||||
async def test_async_function(self):
|
||||
codeflash_output = await some_async_function()
|
||||
assert codeflash_output == expected
|
||||
'''
|
||||
generated_test = GeneratedTests(
|
||||
generated_original_test_source=test_source,
|
||||
instrumented_behavior_test_source="",
|
||||
instrumented_perf_test_source="",
|
||||
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
|
||||
perf_file_path=test_config.tests_root / "test_perf.py"
|
||||
)
|
||||
|
||||
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
|
||||
|
||||
invocation_id = InvocationId(
|
||||
test_module_path="tests.test_module__unit_test_0",
|
||||
test_class_name="TestAsyncClass",
|
||||
test_function_name="test_async_function",
|
||||
function_getting_tested="some_async_function",
|
||||
iteration_id="0",
|
||||
)
|
||||
|
||||
original_runtimes = {invocation_id: [2000000000]} # 2s in nanoseconds
|
||||
optimized_runtimes = {invocation_id: [1000000000]} # 1s in nanoseconds
|
||||
|
||||
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
|
||||
|
||||
expected_source = '''class TestAsyncClass:
|
||||
async def test_async_function(self):
|
||||
codeflash_output = await some_async_function() # 2.00s -> 1.00s (100% faster)
|
||||
assert codeflash_output == expected
|
||||
'''
|
||||
|
||||
assert len(result.generated_tests) == 1
|
||||
assert result.generated_tests[0].generated_original_test_source == expected_source
|
||||
|
||||
def test_async_mixed_sync_and_async_functions(self, test_config):
|
||||
os.chdir(test_config.project_root_path)
|
||||
test_source = """def test_sync_function():
|
||||
codeflash_output = sync_function([1, 2, 3])
|
||||
assert codeflash_output == [1, 2, 3]
|
||||
|
||||
async def test_async_function():
|
||||
codeflash_output = await async_function([4, 5, 6])
|
||||
assert codeflash_output == [4, 5, 6]
|
||||
|
||||
def test_another_sync():
|
||||
result = another_sync_func()
|
||||
assert result is True
|
||||
"""
|
||||
generated_test = GeneratedTests(
|
||||
generated_original_test_source=test_source,
|
||||
instrumented_behavior_test_source="",
|
||||
instrumented_perf_test_source="",
|
||||
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
|
||||
perf_file_path=test_config.tests_root / "test_perf.py"
|
||||
)
|
||||
|
||||
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
|
||||
|
||||
original_test_results = TestResults()
|
||||
optimized_test_results = TestResults()
|
||||
|
||||
# Add test invocations for all test functions
|
||||
original_test_results.add(self.create_test_invocation("test_sync_function", 400_000, iteration_id='0'))
|
||||
original_test_results.add(self.create_test_invocation("test_async_function", 600_000, iteration_id='0'))
|
||||
original_test_results.add(self.create_test_invocation("test_another_sync", 200_000, iteration_id='0'))
|
||||
|
||||
optimized_test_results.add(self.create_test_invocation("test_sync_function", 200_000, iteration_id='0'))
|
||||
optimized_test_results.add(self.create_test_invocation("test_async_function", 300_000, iteration_id='0'))
|
||||
optimized_test_results.add(self.create_test_invocation("test_another_sync", 100_000, iteration_id='0'))
|
||||
|
||||
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
|
||||
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
|
||||
|
||||
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
|
||||
|
||||
modified_source = result.generated_tests[0].generated_original_test_source
|
||||
|
||||
assert "# 400μs -> 200μs" in modified_source
|
||||
assert "# 600μs -> 300μs" in modified_source
|
||||
assert "# 200μs -> 100μs" in modified_source
|
||||
|
||||
assert "async def test_async_function():" in modified_source
|
||||
assert "await async_function([4, 5, 6])" in modified_source
|
||||
|
||||
def test_async_complex_await_patterns(self, test_config):
|
||||
os.chdir(test_config.project_root_path)
|
||||
test_source = """async def test_complex_async():
|
||||
# Multiple await calls
|
||||
result1 = await async_func1()
|
||||
codeflash_output = await async_func2(result1)
|
||||
result3 = await async_func3(codeflash_output)
|
||||
assert result3 == expected
|
||||
|
||||
# Await in context manager
|
||||
async with async_context() as ctx:
|
||||
final_result = await ctx.process()
|
||||
assert final_result is not None
|
||||
"""
|
||||
generated_test = GeneratedTests(
|
||||
generated_original_test_source=test_source,
|
||||
instrumented_behavior_test_source="",
|
||||
instrumented_perf_test_source="",
|
||||
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
|
||||
perf_file_path=test_config.tests_root / "test_perf.py"
|
||||
)
|
||||
|
||||
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
|
||||
|
||||
original_test_results = TestResults()
|
||||
optimized_test_results = TestResults()
|
||||
|
||||
original_test_results.add(self.create_test_invocation("test_complex_async", 750_000, iteration_id='1')) # 750μs
|
||||
optimized_test_results.add(self.create_test_invocation("test_complex_async", 450_000, iteration_id='1')) # 450μs
|
||||
|
||||
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
|
||||
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
|
||||
|
||||
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
|
||||
|
||||
modified_source = result.generated_tests[0].generated_original_test_source
|
||||
assert "# 750μs -> 450μs" in modified_source
|
||||
337
tests/test_async_function_discovery.py
Normal file
|
|
@ -0,0 +1,337 @@
|
|||
import tempfile
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import (
|
||||
find_all_functions_in_file,
|
||||
get_functions_to_optimize,
|
||||
inspect_top_level_functions_or_methods,
|
||||
)
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
with tempfile.TemporaryDirectory() as temp:
|
||||
yield Path(temp)
|
||||
|
||||
|
||||
def test_async_function_detection(temp_dir):
|
||||
async_function = """
|
||||
async def async_function_with_return():
|
||||
await some_async_operation()
|
||||
return 42
|
||||
|
||||
async def async_function_without_return():
|
||||
await some_async_operation()
|
||||
print("No return")
|
||||
|
||||
def regular_function():
|
||||
return 10
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(async_function)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
function_names = [fn.function_name for fn in functions_found[file_path]]
|
||||
|
||||
assert "async_function_with_return" in function_names
|
||||
assert "regular_function" in function_names
|
||||
assert "async_function_without_return" not in function_names
|
||||
|
||||
|
||||
def test_async_method_in_class(temp_dir):
|
||||
code_with_async_method = """
|
||||
class AsyncClass:
|
||||
async def async_method(self):
|
||||
await self.do_something()
|
||||
return "result"
|
||||
|
||||
async def async_method_no_return(self):
|
||||
await self.do_something()
|
||||
pass
|
||||
|
||||
def sync_method(self):
|
||||
return "sync result"
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(code_with_async_method)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
found_functions = functions_found[file_path]
|
||||
function_names = [fn.function_name for fn in found_functions]
|
||||
qualified_names = [fn.qualified_name for fn in found_functions]
|
||||
|
||||
assert "async_method" in function_names
|
||||
assert "AsyncClass.async_method" in qualified_names
|
||||
|
||||
assert "sync_method" in function_names
|
||||
assert "AsyncClass.sync_method" in qualified_names
|
||||
|
||||
assert "async_method_no_return" not in function_names
|
||||
|
||||
|
||||
def test_nested_async_functions(temp_dir):
|
||||
nested_async = """
|
||||
async def outer_async():
|
||||
async def inner_async():
|
||||
return "inner"
|
||||
|
||||
result = await inner_async()
|
||||
return result
|
||||
|
||||
def outer_sync():
|
||||
async def inner_async():
|
||||
return "inner from sync"
|
||||
|
||||
return inner_async
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(nested_async)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
function_names = [fn.function_name for fn in functions_found[file_path]]
|
||||
|
||||
assert "outer_async" in function_names
|
||||
assert "outer_sync" in function_names
|
||||
assert "inner_async" not in function_names
|
||||
|
||||
|
||||
def test_async_staticmethod_and_classmethod(temp_dir):
|
||||
async_decorators = """
|
||||
class MyClass:
|
||||
@staticmethod
|
||||
async def async_static_method():
|
||||
await some_operation()
|
||||
return "static result"
|
||||
|
||||
@classmethod
|
||||
async def async_class_method(cls):
|
||||
await cls.some_operation()
|
||||
return "class result"
|
||||
|
||||
@property
|
||||
async def async_property(self):
|
||||
return await self.get_value()
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(async_decorators)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
function_names = [fn.function_name for fn in functions_found[file_path]]
|
||||
|
||||
assert "async_static_method" in function_names
|
||||
assert "async_class_method" in function_names
|
||||
|
||||
assert "async_property" not in function_names
|
||||
|
||||
|
||||
def test_async_generator_functions(temp_dir):
|
||||
async_generators = """
|
||||
async def async_generator_with_return():
|
||||
for i in range(10):
|
||||
yield i
|
||||
return "done"
|
||||
|
||||
async def async_generator_no_return():
|
||||
for i in range(10):
|
||||
yield i
|
||||
|
||||
async def regular_async_with_return():
|
||||
result = await compute()
|
||||
return result
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(async_generators)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
function_names = [fn.function_name for fn in functions_found[file_path]]
|
||||
|
||||
assert "async_generator_with_return" in function_names
|
||||
assert "regular_async_with_return" in function_names
|
||||
assert "async_generator_no_return" not in function_names
|
||||
|
||||
|
||||
def test_inspect_async_top_level_functions(temp_dir):
|
||||
code = """
|
||||
async def top_level_async():
|
||||
return 42
|
||||
|
||||
class AsyncContainer:
|
||||
async def async_method(self):
|
||||
async def nested_async():
|
||||
return 1
|
||||
return await nested_async()
|
||||
|
||||
@staticmethod
|
||||
async def async_static():
|
||||
return "static"
|
||||
|
||||
@classmethod
|
||||
async def async_classmethod(cls):
|
||||
return "classmethod"
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(code)
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "top_level_async")
|
||||
assert result.is_top_level
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "async_method", class_name="AsyncContainer")
|
||||
assert result.is_top_level
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "nested_async", class_name="AsyncContainer")
|
||||
assert not result.is_top_level
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "async_static", class_name="AsyncContainer")
|
||||
assert result.is_top_level
|
||||
assert result.is_staticmethod
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "async_classmethod", class_name="AsyncContainer")
|
||||
assert result.is_top_level
|
||||
assert result.is_classmethod
|
||||
|
||||
|
||||
def test_get_functions_to_optimize_with_async(temp_dir):
|
||||
mixed_code = """
|
||||
async def async_func_one():
|
||||
return await operation_one()
|
||||
|
||||
def sync_func_one():
|
||||
return operation_one()
|
||||
|
||||
async def async_func_two():
|
||||
print("no return")
|
||||
|
||||
class MixedClass:
|
||||
async def async_method(self):
|
||||
return await self.operation()
|
||||
|
||||
def sync_method(self):
|
||||
return self.operation()
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(mixed_code)
|
||||
|
||||
test_config = TestConfig(
|
||||
tests_root="tests",
|
||||
project_root_path=".",
|
||||
test_framework="pytest",
|
||||
tests_project_rootdir=Path()
|
||||
)
|
||||
|
||||
functions, functions_count, _ = get_functions_to_optimize(
|
||||
optimize_all=None,
|
||||
replay_test=None,
|
||||
file=file_path,
|
||||
only_get_this_function=None,
|
||||
test_cfg=test_config,
|
||||
ignore_paths=[],
|
||||
project_root=file_path.parent,
|
||||
module_root=file_path.parent,
|
||||
enable_async=True,
|
||||
)
|
||||
|
||||
assert functions_count == 4
|
||||
|
||||
function_names = [fn.function_name for fn in functions[file_path]]
|
||||
assert "async_func_one" in function_names
|
||||
assert "sync_func_one" in function_names
|
||||
assert "async_method" in function_names
|
||||
assert "sync_method" in function_names
|
||||
|
||||
assert "async_func_two" not in function_names
|
||||
|
||||
|
||||
def test_no_async_functions_finding(temp_dir):
|
||||
mixed_code = """
|
||||
async def async_func_one():
|
||||
return await operation_one()
|
||||
|
||||
def sync_func_one():
|
||||
return operation_one()
|
||||
|
||||
async def async_func_two():
|
||||
print("no return")
|
||||
|
||||
class MixedClass:
|
||||
async def async_method(self):
|
||||
return await self.operation()
|
||||
|
||||
def sync_method(self):
|
||||
return self.operation()
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(mixed_code)
|
||||
|
||||
test_config = TestConfig(
|
||||
tests_root="tests",
|
||||
project_root_path=".",
|
||||
test_framework="pytest",
|
||||
tests_project_rootdir=Path()
|
||||
)
|
||||
|
||||
functions, functions_count, _ = get_functions_to_optimize(
|
||||
optimize_all=None,
|
||||
replay_test=None,
|
||||
file=file_path,
|
||||
only_get_this_function=None,
|
||||
test_cfg=test_config,
|
||||
ignore_paths=[],
|
||||
project_root=file_path.parent,
|
||||
module_root=file_path.parent,
|
||||
enable_async=False,
|
||||
)
|
||||
|
||||
assert functions_count == 2
|
||||
|
||||
function_names = [fn.function_name for fn in functions[file_path]]
|
||||
assert "sync_func_one" in function_names
|
||||
assert "sync_method" in function_names
|
||||
assert "async_func_one" not in function_names
|
||||
assert "async_method" not in function_names
|
||||
|
||||
|
||||
def test_async_function_parents(temp_dir):
|
||||
complex_structure = """
|
||||
class OuterClass:
|
||||
async def outer_method(self):
|
||||
return 1
|
||||
|
||||
class InnerClass:
|
||||
async def inner_method(self):
|
||||
return 2
|
||||
|
||||
async def module_level_async():
|
||||
class LocalClass:
|
||||
async def local_method(self):
|
||||
return 3
|
||||
return LocalClass()
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(complex_structure)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
found_functions = functions_found[file_path]
|
||||
|
||||
for fn in found_functions:
|
||||
if fn.function_name == "outer_method":
|
||||
assert len(fn.parents) == 1
|
||||
assert fn.parents[0].name == "OuterClass"
|
||||
assert fn.qualified_name == "OuterClass.outer_method"
|
||||
elif fn.function_name == "inner_method":
|
||||
assert len(fn.parents) == 2
|
||||
assert fn.parents[0].name == "OuterClass"
|
||||
assert fn.parents[1].name == "InnerClass"
|
||||
elif fn.function_name == "module_level_async":
|
||||
assert len(fn.parents) == 0
|
||||
assert fn.qualified_name == "module_level_async"
|
||||
1039
tests/test_async_run_and_parse_tests.py
Normal file
File diff suppressed because it is too large
285
tests/test_async_wrapper_sqlite_validation.py
Normal file
|
|
@ -0,0 +1,285 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import dill as pickle
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import (
|
||||
codeflash_behavior_async,
|
||||
codeflash_performance_async,
|
||||
)
|
||||
from codeflash.verification.codeflash_capture import VerificationType
|
||||
|
||||
|
||||
class TestAsyncWrapperSQLiteValidation:
|
||||
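# Tests in this class exercise the async behavior/performance wrapper decorators and validate what they record in the per-iteration SQLite results database.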
|
||||
@pytest.fixture
|
||||
def test_env_setup(self, request):
|
||||
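# Simulate the environment variables the instrumented test harness normally sets, and restore the originals afterwards.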
original_env = {}
|
||||
test_env = {
|
||||
"CODEFLASH_LOOP_INDEX": "1",
|
||||
"CODEFLASH_TEST_ITERATION": "0",
|
||||
"CODEFLASH_TEST_MODULE": __name__,
|
||||
"CODEFLASH_TEST_CLASS": "TestAsyncWrapperSQLiteValidation",
|
||||
"CODEFLASH_TEST_FUNCTION": request.node.name,
|
||||
"CODEFLASH_CURRENT_LINE_ID": "test_unit",
|
||||
}
|
||||
|
||||
for key, value in test_env.items():
|
||||
original_env[key] = os.environ.get(key)
|
||||
os.environ[key] = value
|
||||
|
||||
yield test_env
|
||||
|
||||
for key, original_value in original_env.items():
|
||||
if original_value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = original_value
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db_path(self, test_env_setup):
|
||||
iteration = test_env_setup["CODEFLASH_TEST_ITERATION"]
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file
|
||||
db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
|
||||
|
||||
yield db_path
|
||||
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_behavior_async_basic_function(self, test_env_setup, temp_db_path):
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def simple_async_add(a: int, b: int) -> int:
|
||||
await asyncio.sleep(0.001)
|
||||
return a + b
|
||||
|
||||
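# Override the line-id env var for this call; the resulting iteration_id format is asserted further below.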
os.environ['CODEFLASH_CURRENT_LINE_ID'] = 'simple_async_add_59'
|
||||
result = await simple_async_add(5, 3)
|
||||
|
||||
assert result == 8
|
||||
|
||||
assert temp_db_path.exists()
|
||||
|
||||
con = sqlite3.connect(temp_db_path)
|
||||
cur = con.cursor()
|
||||
|
||||
cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test_results'")
|
||||
assert cur.fetchone() is not None
|
||||
|
||||
cur.execute("SELECT * FROM test_results")
|
||||
rows = cur.fetchall()
|
||||
|
||||
assert len(rows) == 1
|
||||
row = rows[0]
|
||||
|
||||
(test_module_path, test_class_name, test_function_name, function_getting_tested,
|
||||
loop_index, iteration_id, runtime, return_value_blob, verification_type) = row
|
||||
|
||||
assert test_module_path == __name__
|
||||
assert test_class_name == "TestAsyncWrapperSQLiteValidation"
|
||||
assert test_function_name == "test_behavior_async_basic_function"
|
||||
assert function_getting_tested == "simple_async_add"
|
||||
assert loop_index == 1
|
||||
# Line ID will be the actual line number from the source code, not a simple counter
|
||||
assert iteration_id.startswith("simple_async_add_") and iteration_id.endswith("_0")
|
||||
assert runtime > 0
|
||||
assert verification_type == VerificationType.FUNCTION_CALL.value
|
||||
|
||||
unpickled_data = pickle.loads(return_value_blob)
|
||||
args, kwargs, return_val = unpickled_data
|
||||
|
||||
assert args == (5, 3)
|
||||
assert kwargs == {}
|
||||
assert return_val == 8
|
||||
|
||||
con.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_behavior_async_exception_handling(self, test_env_setup, temp_db_path):
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_divide(a: int, b: int) -> float:
|
||||
await asyncio.sleep(0.001)
|
||||
if b == 0:
|
||||
raise ValueError("Cannot divide by zero")
|
||||
return a / b
|
||||
|
||||
result = await async_divide(10, 2)
|
||||
assert result == 5.0
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot divide by zero"):
|
||||
await async_divide(10, 0)
|
||||
|
||||
con = sqlite3.connect(temp_db_path)
|
||||
cur = con.cursor()
|
||||
cur.execute("SELECT * FROM test_results ORDER BY iteration_id")
|
||||
rows = cur.fetchall()
|
||||
|
||||
assert len(rows) == 2
|
||||
|
||||
success_row = rows[0]
|
||||
success_data = pickle.loads(success_row[7]) # return_value_blob
|
||||
args, kwargs, return_val = success_data
|
||||
assert args == (10, 2)
|
||||
assert return_val == 5.0
|
||||
|
||||
# Check exception record
|
||||
exception_row = rows[1]
|
||||
exception_data = pickle.loads(exception_row[7]) # return_value_blob
|
||||
assert isinstance(exception_data, ValueError)
|
||||
assert str(exception_data) == "Cannot divide by zero"
|
||||
|
||||
con.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_performance_async_no_database_storage(self, test_env_setup, temp_db_path, capsys):
|
||||
"""Test performance async decorator doesn't store to database."""
|
||||
|
||||
@codeflash_performance_async
|
||||
async def async_multiply(a: int, b: int) -> int:
|
||||
"""Async function for performance testing."""
|
||||
await asyncio.sleep(0.002)
|
||||
return a * b
|
||||
|
||||
result = await async_multiply(4, 7)
|
||||
|
||||
assert result == 28
|
||||
|
||||
assert not temp_db_path.exists()
|
||||
|
||||
captured = capsys.readouterr()
|
||||
output_lines = captured.out.strip().split('\n')
|
||||
|
||||
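# The performance decorator reports timing via stdout marker lines rather than the database: one opening '!$######...' line and one closing '!######...######!' line whose trailing field carries the measured runtime.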
assert len([line for line in output_lines if "!$######" in line]) == 1
|
||||
assert len([line for line in output_lines if "!######" in line and "######!" in line]) == 1
|
||||
|
||||
closing_tag = [line for line in output_lines if "!######" in line and "######!" in line][0]
|
||||
assert "async_multiply" in closing_tag
|
||||
|
||||
timing_part = closing_tag.split(":")[-1].replace("######!", "")
|
||||
timing_value = int(timing_part)
|
||||
assert timing_value > 0 # Should have positive timing
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_calls_indexing(self, test_env_setup, temp_db_path):
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_increment(value: int) -> int:
|
||||
await asyncio.sleep(0.001)
|
||||
return value + 1
|
||||
|
||||
# Call the function multiple times
|
||||
results = []
|
||||
for i in range(3):
|
||||
result = await async_increment(i)
|
||||
results.append(result)
|
||||
|
||||
assert results == [1, 2, 3]
|
||||
|
||||
con = sqlite3.connect(temp_db_path)
|
||||
cur = con.cursor()
|
||||
cur.execute("SELECT iteration_id, return_value FROM test_results ORDER BY iteration_id")
|
||||
rows = cur.fetchall()
|
||||
|
||||
assert len(rows) == 3
|
||||
|
||||
actual_ids = [row[0] for row in rows]
|
||||
assert len(actual_ids) == 3
|
||||
|
||||
base_pattern = actual_ids[0].rsplit('_', 1)[0] # e.g., "async_increment_199"
|
||||
expected_pattern = [f"{base_pattern}_{i}" for i in range(3)]
|
||||
assert actual_ids == expected_pattern
|
||||
|
||||
for i, (_, return_value_blob) in enumerate(rows):
|
||||
args, kwargs, return_val = pickle.loads(return_value_blob)
|
||||
assert args == (i,)
|
||||
assert return_val == i + 1
|
||||
|
||||
con.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complex_async_function_with_kwargs(self, test_env_setup, temp_db_path):
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def complex_async_func(
|
||||
pos_arg: str,
|
||||
*args: int,
|
||||
keyword_arg: str = "default",
|
||||
**kwargs: str
|
||||
) -> dict:
|
||||
await asyncio.sleep(0.001)
|
||||
return {
|
||||
"pos_arg": pos_arg,
|
||||
"args": args,
|
||||
"keyword_arg": keyword_arg,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
|
||||
result = await complex_async_func(
|
||||
"hello",
|
||||
1, 2, 3,
|
||||
keyword_arg="custom",
|
||||
extra1="value1",
|
||||
extra2="value2"
|
||||
)
|
||||
|
||||
expected_result = {
|
||||
"pos_arg": "hello",
|
||||
"args": (1, 2, 3),
|
||||
"keyword_arg": "custom",
|
||||
"kwargs": {"extra1": "value1", "extra2": "value2"}
|
||||
}
|
||||
|
||||
assert result == expected_result
|
||||
|
||||
con = sqlite3.connect(temp_db_path)
|
||||
cur = con.cursor()
|
||||
cur.execute("SELECT return_value FROM test_results")
|
||||
row = cur.fetchone()
|
||||
|
||||
stored_args, stored_kwargs, stored_result = pickle.loads(row[0])
|
||||
|
||||
assert stored_args == ("hello", 1, 2, 3)
|
||||
assert stored_kwargs == {"keyword_arg": "custom", "extra1": "value1", "extra2": "value2"}
|
||||
assert stored_result == expected_result
|
||||
|
||||
con.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_database_schema_validation(self, test_env_setup, temp_db_path):
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def schema_test_func() -> str:
|
||||
return "schema_test"
|
||||
|
||||
await schema_test_func()
|
||||
|
||||
con = sqlite3.connect(temp_db_path)
|
||||
cur = con.cursor()
|
||||
|
||||
cur.execute("PRAGMA table_info(test_results)")
|
||||
columns = cur.fetchall()
|
||||
|
||||
expected_columns = [
|
||||
(0, 'test_module_path', 'TEXT', 0, None, 0),
|
||||
(1, 'test_class_name', 'TEXT', 0, None, 0),
|
||||
(2, 'test_function_name', 'TEXT', 0, None, 0),
|
||||
(3, 'function_getting_tested', 'TEXT', 0, None, 0),
|
||||
(4, 'loop_index', 'INTEGER', 0, None, 0),
|
||||
(5, 'iteration_id', 'TEXT', 0, None, 0),
|
||||
(6, 'runtime', 'INTEGER', 0, None, 0),
|
||||
(7, 'return_value', 'BLOB', 0, None, 0),
|
||||
(8, 'verification_type', 'TEXT', 0, None, 0)
|
||||
]
|
||||
|
||||
assert columns == expected_columns
|
||||
con.close()
|
||||
|
||||
|
|
@ -12,7 +12,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
|||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.code_utils.code_extractor import add_global_assignments
|
||||
from codeflash.code_utils.code_extractor import add_global_assignments, GlobalAssignmentCollector
|
||||
|
||||
|
||||
class HelperClass:
|
||||
|
|
@ -1793,9 +1793,10 @@ def get_system_details():
|
|||
|
||||
# Set up the optimizer
|
||||
file_path = main_file_path.resolve()
|
||||
project_root = package_dir.resolve()
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=package_dir.resolve(),
|
||||
project_root=project_root,
|
||||
disable_telemetry=True,
|
||||
tests_root="tests",
|
||||
test_framework="pytest",
|
||||
|
|
@ -1819,6 +1820,8 @@ def get_system_details():
|
|||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
# The expected contexts
|
||||
# Resolve both paths to handle symlink issues on macOS
|
||||
relative_path = file_path.relative_to(project_root)
|
||||
expected_read_write_context = f"""
|
||||
```python:{main_file_path.resolve().relative_to(opt.args.project_root.resolve())}
|
||||
import utility_module
|
||||
|
|
@ -2038,9 +2041,10 @@ def get_system_details():
|
|||
|
||||
# Set up the optimizer
|
||||
file_path = main_file_path.resolve()
|
||||
project_root = package_dir.resolve()
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=package_dir.resolve(),
|
||||
project_root=project_root,
|
||||
disable_telemetry=True,
|
||||
tests_root="tests",
|
||||
test_framework="pytest",
|
||||
|
|
@ -2063,6 +2067,7 @@ def get_system_details():
|
|||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
# The expected contexts
|
||||
relative_path = file_path.relative_to(project_root)
|
||||
expected_read_write_context = f"""
|
||||
```python:utility_module.py
|
||||
# Function that will be used in the main code
|
||||
|
|
@ -2463,3 +2468,148 @@ def test_circular_deps():
|
|||
assert "import ApiClient" not in new_code, "Error: Circular dependency found"
|
||||
|
||||
assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"
|
||||
def test_global_assignment_collector_with_async_function():
|
||||
"""Test GlobalAssignmentCollector correctly identifies global assignments outside async functions."""
|
||||
import libcst as cst
|
||||
|
||||
source_code = """
|
||||
# Global assignment
|
||||
GLOBAL_VAR = "global_value"
|
||||
OTHER_GLOBAL = 42
|
||||
|
||||
async def async_function():
|
||||
# This should not be collected (inside async function)
|
||||
local_var = "local_value"
|
||||
INNER_ASSIGNMENT = "should_not_be_global"
|
||||
return local_var
|
||||
|
||||
# Another global assignment
|
||||
ANOTHER_GLOBAL = "another_global"
|
||||
"""
|
||||
|
||||
tree = cst.parse_module(source_code)
|
||||
collector = GlobalAssignmentCollector()
|
||||
tree.visit(collector)
|
||||
|
||||
# Should collect global assignments but not the ones inside async function
|
||||
assert len(collector.assignments) == 3
|
||||
assert "GLOBAL_VAR" in collector.assignments
|
||||
assert "OTHER_GLOBAL" in collector.assignments
|
||||
assert "ANOTHER_GLOBAL" in collector.assignments
|
||||
|
||||
# Should not collect assignments from inside async function
|
||||
assert "local_var" not in collector.assignments
|
||||
assert "INNER_ASSIGNMENT" not in collector.assignments
|
||||
|
||||
# Verify assignment order
|
||||
expected_order = ["GLOBAL_VAR", "OTHER_GLOBAL", "ANOTHER_GLOBAL"]
|
||||
assert collector.assignment_order == expected_order
|
||||
|
||||
|
||||
def test_global_assignment_collector_nested_async_functions():
|
||||
"""Test GlobalAssignmentCollector handles nested async functions correctly."""
|
||||
import libcst as cst
|
||||
|
||||
source_code = """
|
||||
# Global assignment
|
||||
CONFIG = {"key": "value"}
|
||||
|
||||
def sync_function():
|
||||
# Inside sync function - should not be collected
|
||||
sync_local = "sync"
|
||||
|
||||
async def nested_async():
|
||||
# Inside nested async function - should not be collected
|
||||
nested_var = "nested"
|
||||
return nested_var
|
||||
|
||||
return sync_local
|
||||
|
||||
async def async_function():
|
||||
# Inside async function - should not be collected
|
||||
async_local = "async"
|
||||
|
||||
def nested_sync():
|
||||
# Inside nested function - should not be collected
|
||||
deeply_nested = "deep"
|
||||
return deeply_nested
|
||||
|
||||
return async_local
|
||||
|
||||
# Another global assignment
|
||||
FINAL_GLOBAL = "final"
|
||||
"""
|
||||
|
||||
tree = cst.parse_module(source_code)
|
||||
collector = GlobalAssignmentCollector()
|
||||
tree.visit(collector)
|
||||
|
||||
# Should only collect global-level assignments
|
||||
assert len(collector.assignments) == 2
|
||||
assert "CONFIG" in collector.assignments
|
||||
assert "FINAL_GLOBAL" in collector.assignments
|
||||
|
||||
# Should not collect any assignments from inside functions
|
||||
assert "sync_local" not in collector.assignments
|
||||
assert "nested_var" not in collector.assignments
|
||||
assert "async_local" not in collector.assignments
|
||||
assert "deeply_nested" not in collector.assignments
|
||||
|
||||
|
||||
def test_global_assignment_collector_mixed_async_sync_with_classes():
|
||||
"""Test GlobalAssignmentCollector with async functions, sync functions, and classes."""
|
||||
import libcst as cst
|
||||
|
||||
source_code = """
|
||||
# Global assignments
|
||||
GLOBAL_CONSTANT = "constant"
|
||||
|
||||
class TestClass:
|
||||
# Class-level assignment - should not be collected
|
||||
class_var = "class_value"
|
||||
|
||||
def sync_method(self):
|
||||
# Method assignment - should not be collected
|
||||
method_var = "method"
|
||||
return method_var
|
||||
|
||||
async def async_method(self):
|
||||
# Async method assignment - should not be collected
|
||||
async_method_var = "async_method"
|
||||
return async_method_var
|
||||
|
||||
def sync_function():
|
||||
# Function assignment - should not be collected
|
||||
func_var = "function"
|
||||
return func_var
|
||||
|
||||
async def async_function():
|
||||
# Async function assignment - should not be collected
|
||||
async_func_var = "async_function"
|
||||
return async_func_var
|
||||
|
||||
# More global assignments
|
||||
ANOTHER_CONSTANT = 100
|
||||
FINAL_ASSIGNMENT = {"data": "value"}
|
||||
"""
|
||||
|
||||
tree = cst.parse_module(source_code)
|
||||
collector = GlobalAssignmentCollector()
|
||||
tree.visit(collector)
|
||||
|
||||
# Should only collect global-level assignments
|
||||
assert len(collector.assignments) == 3
|
||||
assert "GLOBAL_CONSTANT" in collector.assignments
|
||||
assert "ANOTHER_CONSTANT" in collector.assignments
|
||||
assert "FINAL_ASSIGNMENT" in collector.assignments
|
||||
|
||||
# Should not collect assignments from inside any scoped blocks
|
||||
assert "class_var" not in collector.assignments
|
||||
assert "method_var" not in collector.assignments
|
||||
assert "async_method_var" not in collector.assignments
|
||||
assert "func_var" not in collector.assignments
|
||||
assert "async_func_var" not in collector.assignments
|
||||
|
||||
# Verify correct order
|
||||
expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"]
|
||||
assert collector.assignment_order == expected_order
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from codeflash.code_utils.code_replacer import (
|
|||
is_zero_diff,
|
||||
replace_functions_and_add_imports,
|
||||
replace_functions_in_file,
|
||||
OptimFunctionCollector,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent
|
||||
|
|
@ -3448,156 +3449,173 @@ def hydrate_input_text_actions_with_field_names(
|
|||
|
||||
assert new_code == expected
|
||||
|
||||
def test_duplicate_global_assignments_when_reverting_helpers():
|
||||
root_dir = Path(__file__).parent.parent.resolve()
|
||||
main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()
|
||||
|
||||
original_code = '''"""Chunking objects not specific to a particular chunking strategy."""
|
||||
from __future__ import annotations
|
||||
import collections
|
||||
import copy
|
||||
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
|
||||
import regex
|
||||
from typing_extensions import Self, TypeAlias
|
||||
from unstructured.utils import lazyproperty
|
||||
from unstructured.documents.elements import Element
|
||||
# ================================================================================================
|
||||
# MODEL
|
||||
# ================================================================================================
|
||||
CHUNK_MAX_CHARS_DEFAULT: int = 500
|
||||
# ================================================================================================
|
||||
# PRE-CHUNKER
|
||||
# ================================================================================================
|
||||
class PreChunker:
|
||||
"""Gathers sequential elements into pre-chunks as length constraints allow.
|
||||
The pre-chunker's responsibilities are:
|
||||
- **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
|
||||
either side of those boundaries into different sections. In this case, the primary indicator
|
||||
of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
|
||||
semantic boundary when `multipage_sections` is `False`.
|
||||
- **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
|
||||
into sections as big as possible without exceeding the chunk window size.
|
||||
- **Minimize chunks that must be split mid-text.** Precompute the text length of each section
|
||||
and only produce a section that exceeds the chunk window size when there is a single element
|
||||
with text longer than that window.
|
||||
A Table element is placed into a section by itself. CheckBox elements are dropped.
|
||||
The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
|
||||
a new "section", hence the "by-title" designation.
|
||||
"""
|
||||
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
|
||||
self._elements = elements
|
||||
self._opts = opts
|
||||
@lazyproperty
|
||||
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
|
||||
"""The semantic-boundary detectors to be applied to break pre-chunks."""
|
||||
return self._opts.boundary_predicates
|
||||
def _is_in_new_semantic_unit(self, element: Element) -> bool:
|
||||
"""True when `element` begins a new semantic unit such as a section or page."""
|
||||
# -- all detectors need to be called to update state and avoid double counting
|
||||
# -- boundaries that happen to coincide, like Table and new section on same element.
|
||||
# -- Using `any()` would short-circuit on first True.
|
||||
semantic_boundaries = [pred(element) for pred in self._boundary_predicates]
|
||||
return any(semantic_boundaries)
|
||||
'''
|
||||
main_file.write_text(original_code, encoding="utf-8")
|
||||
optim_code = f'''```python:{main_file.relative_to(root_dir)}
|
||||
# ================================================================================================
|
||||
# PRE-CHUNKER
|
||||
# ================================================================================================
|
||||
from __future__ import annotations
|
||||
from typing import Iterable
|
||||
from unstructured.documents.elements import Element
|
||||
from unstructured.utils import lazyproperty
|
||||
class PreChunker:
|
||||
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
|
||||
self._elements = elements
|
||||
self._opts = opts
|
||||
@lazyproperty
|
||||
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
|
||||
"""The semantic-boundary detectors to be applied to break pre-chunks."""
|
||||
return self._opts.boundary_predicates
|
||||
def _is_in_new_semantic_unit(self, element: Element) -> bool:
|
||||
"""True when `element` begins a new semantic unit such as a section or page."""
|
||||
# Check predicates one at a time with an early exit, avoiding an intermediate list
|
||||
for pred in self._boundary_predicates:
|
||||
if pred(element):
|
||||
return True
|
||||
return False
|
||||
```
|
||||
'''
|
||||
# OptimFunctionCollector async function tests
|
||||
def test_optim_function_collector_with_async_functions():
|
||||
"""Test OptimFunctionCollector correctly collects async functions."""
|
||||
import libcst as cst
|
||||
|
||||
source_code = """
|
||||
def sync_function():
|
||||
return "sync"
|
||||
|
||||
func = FunctionToOptimize(function_name="_is_in_new_semantic_unit", parents=[FunctionParent("PreChunker", "ClassDef")], file_path=main_file)
|
||||
test_config = TestConfig(
|
||||
tests_root=root_dir / "tests/pytest",
|
||||
tests_project_rootdir=root_dir,
|
||||
project_root_path=root_dir,
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
async def async_function():
|
||||
return "async"
|
||||
|
||||
class TestClass:
|
||||
def sync_method(self):
|
||||
return "sync_method"
|
||||
|
||||
async def async_method(self):
|
||||
return "async_method"
|
||||
"""
|
||||
|
||||
tree = cst.parse_module(source_code)
|
||||
collector = OptimFunctionCollector(
|
||||
function_names={(None, "sync_function"), (None, "async_function"), ("TestClass", "sync_method"), ("TestClass", "async_method")},
|
||||
preexisting_objects=None
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
tree.visit(collector)
|
||||
|
||||
# Should collect both sync and async functions
|
||||
assert len(collector.modified_functions) == 4
|
||||
assert (None, "sync_function") in collector.modified_functions
|
||||
assert (None, "async_function") in collector.modified_functions
|
||||
assert ("TestClass", "sync_method") in collector.modified_functions
|
||||
assert ("TestClass", "async_method") in collector.modified_functions
|
||||
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
for helper_function_path in helper_function_paths:
|
||||
with helper_function_path.open(encoding="utf8") as f:
|
||||
helper_code = f.read()
|
||||
original_helper_code[helper_function_path] = helper_code
|
||||
|
||||
func_optimizer.args = Args()
|
||||
func_optimizer.replace_function_and_helpers_with_optimized_code(
|
||||
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code
|
||||
def test_optim_function_collector_new_async_functions():
|
||||
"""Test OptimFunctionCollector identifies new async functions not in preexisting objects."""
|
||||
import libcst as cst
|
||||
|
||||
source_code = """
|
||||
def existing_function():
|
||||
return "existing"
|
||||
|
||||
async def new_async_function():
|
||||
return "new_async"
|
||||
|
||||
def new_sync_function():
|
||||
return "new_sync"
|
||||
|
||||
class ExistingClass:
|
||||
async def new_class_async_method(self):
|
||||
return "new_class_async"
|
||||
"""
|
||||
|
||||
# Only existing_function is in preexisting objects
|
||||
preexisting_objects = {("existing_function", ())}
|
||||
|
||||
tree = cst.parse_module(source_code)
|
||||
collector = OptimFunctionCollector(
|
||||
function_names=set(), # Not looking for specific functions
|
||||
preexisting_objects=preexisting_objects
|
||||
)
|
||||
tree.visit(collector)
|
||||
|
||||
# Should identify new functions (both sync and async)
|
||||
assert len(collector.new_functions) == 2
|
||||
function_names = [func.name.value for func in collector.new_functions]
|
||||
assert "new_async_function" in function_names
|
||||
assert "new_sync_function" in function_names
|
||||
|
||||
# Should identify new class methods
|
||||
assert "ExistingClass" in collector.new_class_functions
|
||||
assert len(collector.new_class_functions["ExistingClass"]) == 1
|
||||
assert collector.new_class_functions["ExistingClass"][0].name.value == "new_class_async_method"
|
||||
|
||||
|
||||
new_code = main_file.read_text(encoding="utf-8")
|
||||
main_file.unlink(missing_ok=True)
|
||||
def test_optim_function_collector_mixed_scenarios():
|
||||
"""Test OptimFunctionCollector with complex mix of sync/async functions and classes."""
|
||||
import libcst as cst
|
||||
|
||||
source_code = """
|
||||
# Global functions
|
||||
def global_sync():
|
||||
pass
|
||||
|
||||
expected = '''"""Chunking objects not specific to a particular chunking strategy."""
|
||||
from __future__ import annotations
|
||||
import collections
|
||||
import copy
|
||||
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
|
||||
import regex
|
||||
from typing_extensions import Self, TypeAlias
|
||||
from unstructured.utils import lazyproperty
|
||||
from unstructured.documents.elements import Element
|
||||
# ================================================================================================
|
||||
# MODEL
|
||||
# ================================================================================================
|
||||
CHUNK_MAX_CHARS_DEFAULT: int = 500
|
||||
# ================================================================================================
|
||||
# PRE-CHUNKER
|
||||
# ================================================================================================
|
||||
class PreChunker:
|
||||
"""Gathers sequential elements into pre-chunks as length constraints allow.
|
||||
The pre-chunker's responsibilities are:
|
||||
- **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
|
||||
either side of those boundaries into different sections. In this case, the primary indicator
|
||||
of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
|
||||
semantic boundary when `multipage_sections` is `False`.
|
||||
- **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
|
||||
into sections as big as possible without exceeding the chunk window size.
|
||||
- **Minimize chunks that must be split mid-text.** Precompute the text length of each section
|
||||
and only produce a section that exceeds the chunk window size when there is a single element
|
||||
with text longer than that window.
|
||||
A Table element is placed into a section by itself. CheckBox elements are dropped.
|
||||
The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
|
||||
a new "section", hence the "by-title" designation.
|
||||
"""
|
||||
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
|
||||
self._elements = elements
|
||||
self._opts = opts
|
||||
@lazyproperty
|
||||
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
|
||||
"""The semantic-boundary detectors to be applied to break pre-chunks."""
|
||||
return self._opts.boundary_predicates
|
||||
def _is_in_new_semantic_unit(self, element: Element) -> bool:
|
||||
"""True when `element` begins a new semantic unit such as a section or page."""
|
||||
# Check predicates one at a time with an early exit, avoiding an intermediate list
|
||||
for pred in self._boundary_predicates:
|
||||
if pred(element):
|
||||
return True
|
||||
return False
|
||||
async def global_async():
|
||||
pass
|
||||
|
||||
class ParentClass:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def sync_method(self):
|
||||
pass
|
||||
|
||||
async def async_method(self):
|
||||
pass
|
||||
|
||||
class ChildClass:
|
||||
async def child_async_method(self):
|
||||
pass
|
||||
|
||||
def child_sync_method(self):
|
||||
pass
|
||||
"""
|
||||
|
||||
# Looking for specific functions
|
||||
function_names = {
|
||||
(None, "global_sync"),
|
||||
(None, "global_async"),
|
||||
("ParentClass", "sync_method"),
|
||||
("ParentClass", "async_method"),
|
||||
("ChildClass", "child_async_method")
|
||||
}
|
||||
|
||||
tree = cst.parse_module(source_code)
|
||||
collector = OptimFunctionCollector(
|
||||
function_names=function_names,
|
||||
preexisting_objects=None
|
||||
)
|
||||
tree.visit(collector)
|
||||
|
||||
# Should collect all specified functions (mix of sync and async)
|
||||
assert len(collector.modified_functions) == 5
|
||||
assert (None, "global_sync") in collector.modified_functions
|
||||
assert (None, "global_async") in collector.modified_functions
|
||||
assert ("ParentClass", "sync_method") in collector.modified_functions
|
||||
assert ("ParentClass", "async_method") in collector.modified_functions
|
||||
assert ("ChildClass", "child_async_method") in collector.modified_functions
|
||||
|
||||
# Should collect __init__ method
|
||||
assert "ParentClass" in collector.modified_init_functions
|
||||
|
||||
|
||||
|
||||
def test_is_zero_diff_async_sleep():
|
||||
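# Replacing time.sleep with await asyncio.sleep is a behavioral change, so is_zero_diff must report a real difference.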
original_code = '''
|
||||
import time
|
||||
|
||||
async def task():
|
||||
time.sleep(1)
|
||||
return "done"
|
||||
'''
|
||||
assert new_code == expected
|
||||
optimized_code = '''
|
||||
import asyncio
|
||||
|
||||
async def task():
|
||||
await asyncio.sleep(1)
|
||||
return "done"
|
||||
'''
|
||||
assert not is_zero_diff(original_code, optimized_code)
|
||||
|
||||
def test_is_zero_diff_with_equivalent_code():
|
||||
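# Adding only a docstring should be treated as an equivalent (zero) diff.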
original_code = '''
|
||||
import asyncio
|
||||
|
||||
async def task():
|
||||
await asyncio.sleep(1)
|
||||
return "done"
|
||||
'''
|
||||
optimized_code = '''
|
||||
import asyncio
|
||||
|
||||
async def task():
|
||||
"""A task that does something."""
|
||||
await asyncio.sleep(1)
|
||||
return "done"
|
||||
'''
|
||||
assert is_zero_diff(original_code, optimized_code)
|
||||
|
|
@ -17,10 +17,10 @@ from codeflash.code_utils.code_utils import (
|
|||
is_class_defined_in_file,
|
||||
module_name_from_file_path,
|
||||
path_belongs_to_site_packages,
|
||||
has_any_async_functions,
|
||||
validate_python_code,
|
||||
)
|
||||
from codeflash.code_utils.concolic_utils import clean_concolic_tests
|
||||
from codeflash.code_utils.coverage_utils import generate_candidates, prepare_coverage_files
|
||||
from codeflash.code_utils.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -368,6 +368,86 @@ def my_function():
|
|||
assert is_class_defined_in_file("MyClass", test_file) is False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_code_context():
|
||||
"""Mock CodeOptimizationContext for testing extract_dependent_function."""
|
||||
from unittest.mock import MagicMock
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
|
||||
context = MagicMock(spec=CodeOptimizationContext)
|
||||
context.preexisting_objects = []
|
||||
return context
|
||||
|
||||
|
||||
def test_extract_dependent_function_sync_and_async(mock_code_context):
|
||||
"""Test extract_dependent_function with both sync and async functions."""
|
||||
# Test sync function extraction
|
||||
mock_code_context.testgen_context_code = """
|
||||
def main_function():
|
||||
pass
|
||||
|
||||
def helper_function():
|
||||
pass
|
||||
"""
|
||||
assert extract_dependent_function("main_function", mock_code_context) == "helper_function"
|
||||
|
||||
# Test async function extraction
|
||||
mock_code_context.testgen_context_code = """
|
||||
def main_function():
|
||||
pass
|
||||
|
||||
async def async_helper_function():
|
||||
pass
|
||||
"""
|
||||
assert extract_dependent_function("main_function", mock_code_context) == "async_helper_function"
|
||||
|
||||
|
||||
def test_extract_dependent_function_edge_cases(mock_code_context):
|
||||
"""Test extract_dependent_function edge cases."""
|
||||
# No dependent functions
|
||||
mock_code_context.testgen_context_code = """
|
||||
def main_function():
|
||||
pass
|
||||
"""
|
||||
assert extract_dependent_function("main_function", mock_code_context) is False
|
||||
|
||||
# Multiple dependent functions
|
||||
mock_code_context.testgen_context_code = """
|
||||
def main_function():
|
||||
pass
|
||||
|
||||
def helper1():
|
||||
pass
|
||||
|
||||
async def helper2():
|
||||
pass
|
||||
"""
|
||||
assert extract_dependent_function("main_function", mock_code_context) is False
|
||||
|
||||
|
||||
def test_extract_dependent_function_mixed_scenarios(mock_code_context):
|
||||
"""Test extract_dependent_function with mixed sync/async scenarios."""
|
||||
# Async main with sync helper
|
||||
mock_code_context.testgen_context_code = """
|
||||
async def async_main():
|
||||
pass
|
||||
|
||||
def sync_helper():
|
||||
pass
|
||||
"""
|
||||
assert extract_dependent_function("async_main", mock_code_context) == "sync_helper"
|
||||
|
||||
# Only async functions
|
||||
mock_code_context.testgen_context_code = """
|
||||
async def async_main():
|
||||
pass
|
||||
|
||||
async def async_helper():
|
||||
pass
|
||||
"""
|
||||
assert extract_dependent_function("async_main", mock_code_context) == "async_helper"
|
||||
|
||||
|
||||
def test_is_class_defined_in_file_with_non_existing_file() -> None:
|
||||
non_existing_file = Path("/non/existing/file.py")
|
||||
|
||||
|
|
@ -505,25 +585,41 @@ def test_Grammar_copy():
|
|||
assert cleaned_code == expected_cleaned_code.strip()
|
||||
|
||||
|
||||
def test_has_any_async_functions_with_async_code() -> None:
|
||||
def test_validate_python_code_valid() -> None:
|
||||
code = "def hello():\n return 'world'"
|
||||
result = validate_python_code(code)
|
||||
assert result == code
|
||||
|
||||
|
||||
def test_validate_python_code_invalid() -> None:
|
||||
code = "def hello(:\n return 'world'"
|
||||
with pytest.raises(ValueError, match="Invalid Python code"):
|
||||
validate_python_code(code)
|
||||
|
||||
|
||||
def test_validate_python_code_empty() -> None:
|
||||
code = ""
|
||||
result = validate_python_code(code)
|
||||
assert result == code
|
||||
|
||||
|
||||
def test_validate_python_code_complex_invalid() -> None:
|
||||
code = "if True\n print('missing colon')"
|
||||
with pytest.raises(ValueError, match="Invalid Python code.*line 1.*column 8"):
|
||||
validate_python_code(code)
|
||||
|
||||
|
||||
def test_validate_python_code_valid_complex() -> None:
|
||||
code = """
|
||||
def normal_function():
|
||||
pass
|
||||
|
||||
async def async_function():
|
||||
pass
|
||||
def calculate(a, b):
|
||||
if a > b:
|
||||
return a + b
|
||||
else:
|
||||
return a * b
|
||||
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.value = 42
|
||||
"""
|
||||
result = has_any_async_functions(code)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_has_any_async_functions_without_async_code() -> None:
|
||||
code = """
|
||||
def normal_function():
|
||||
pass
|
||||
|
||||
def another_function():
|
||||
pass
|
||||
"""
|
||||
result = has_any_async_functions(code)
|
||||
assert result is False
|
||||
result = validate_python_code(code)
|
||||
assert result == code
|
||||
|
|
|
|||
|
|
@ -14,7 +14,13 @@ from codeflash.models.models import (
|
|||
TestResults,
|
||||
TestType,
|
||||
)
|
||||
from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
|
||||
from codeflash.result.critic import (
|
||||
coverage_critic,
|
||||
performance_gain,
|
||||
quantity_of_tests_critic,
|
||||
speedup_critic,
|
||||
throughput_gain,
|
||||
)
|
||||
|
||||
|
||||
def test_performance_gain() -> None:
|
||||
|
|
@ -429,3 +435,159 @@ def test_coverage_critic() -> None:
|
|||
)
|
||||
|
||||
assert coverage_critic(unittest_coverage, "unittest") is True
|
||||
|
||||
|
||||
def test_throughput_gain() -> None:
|
||||
"""Test throughput_gain calculation."""
|
||||
# Test basic throughput improvement
|
||||
assert throughput_gain(original_throughput=100, optimized_throughput=150) == 0.5 # 50% improvement
|
||||
|
||||
# Test no improvement
|
||||
assert throughput_gain(original_throughput=100, optimized_throughput=100) == 0.0
|
||||
|
||||
# Test regression
|
||||
assert throughput_gain(original_throughput=100, optimized_throughput=80) == -0.2 # 20% regression
|
||||
|
||||
# Test zero original throughput (edge case)
|
||||
assert throughput_gain(original_throughput=0, optimized_throughput=50) == 0.0
|
||||
|
||||
# Test large improvement
|
||||
assert throughput_gain(original_throughput=50, optimized_throughput=200) == 3.0 # 300% improvement
|
||||
|
||||
|
||||
def test_speedup_critic_with_async_throughput() -> None:
|
||||
"""Test speedup_critic with async throughput evaluation."""
|
||||
original_code_runtime = 10000 # 10 microseconds
|
||||
original_async_throughput = 100
|
||||
|
||||
# Test case 1: Both runtime and throughput improve significantly
|
||||
candidate_result = OptimizedCandidateResult(
|
||||
max_loop_count=5,
|
||||
best_test_runtime=8000, # 20% runtime improvement
|
||||
behavior_test_results=TestResults(),
|
||||
benchmarking_test_results=TestResults(),
|
||||
optimization_candidate_index=0,
|
||||
total_candidate_timing=8000,
|
||||
async_throughput=120, # 20% throughput improvement
|
||||
)
|
||||
|
||||
assert speedup_critic(
|
||||
candidate_result=candidate_result,
|
||||
original_code_runtime=original_code_runtime,
|
||||
best_runtime_until_now=None,
|
||||
original_async_throughput=original_async_throughput,
|
||||
best_throughput_until_now=None,
|
||||
disable_gh_action_noise=True
|
||||
)
|
||||
|
||||
# Test case 2: Runtime improves significantly, throughput doesn't meet threshold (should pass)
|
||||
candidate_result = OptimizedCandidateResult(
|
||||
max_loop_count=5,
|
||||
best_test_runtime=8000, # 20% runtime improvement
|
||||
behavior_test_results=TestResults(),
|
||||
benchmarking_test_results=TestResults(),
|
||||
optimization_candidate_index=0,
|
||||
total_candidate_timing=8000,
|
||||
async_throughput=105, # Only 5% throughput improvement (below 10% threshold)
|
||||
)
|
||||
|
||||
assert speedup_critic(
|
||||
candidate_result=candidate_result,
|
||||
original_code_runtime=original_code_runtime,
|
||||
best_runtime_until_now=None,
|
||||
original_async_throughput=original_async_throughput,
|
||||
best_throughput_until_now=None,
|
||||
disable_gh_action_noise=True
|
||||
)
|
||||
|
||||
# Test case 3: Throughput improves significantly, runtime doesn't meet threshold (should pass)
|
||||
candidate_result = OptimizedCandidateResult(
|
||||
max_loop_count=5,
|
||||
best_test_runtime=9800, # Only 2% runtime improvement (below 5% threshold)
|
||||
behavior_test_results=TestResults(),
|
||||
benchmarking_test_results=TestResults(),
|
||||
optimization_candidate_index=0,
|
||||
total_candidate_timing=9800,
|
||||
async_throughput=120, # 20% throughput improvement
|
||||
)
|
||||
|
||||
assert speedup_critic(
|
||||
candidate_result=candidate_result,
|
||||
original_code_runtime=original_code_runtime,
|
||||
best_runtime_until_now=None,
|
||||
original_async_throughput=original_async_throughput,
|
||||
best_throughput_until_now=None,
|
||||
disable_gh_action_noise=True
|
||||
)
|
||||
|
||||
# Test case 4: No throughput data - should fall back to runtime-only evaluation
|
||||
candidate_result = OptimizedCandidateResult(
|
||||
max_loop_count=5,
|
||||
best_test_runtime=8000, # 20% runtime improvement
|
||||
behavior_test_results=TestResults(),
|
||||
benchmarking_test_results=TestResults(),
|
||||
optimization_candidate_index=0,
|
||||
total_candidate_timing=8000,
|
||||
async_throughput=None, # No throughput data
|
||||
)
|
||||
|
||||
assert speedup_critic(
|
||||
candidate_result=candidate_result,
|
||||
original_code_runtime=original_code_runtime,
|
||||
best_runtime_until_now=None,
|
||||
original_async_throughput=None, # No original throughput data
|
||||
best_throughput_until_now=None,
|
||||
disable_gh_action_noise=True
|
||||
)
|
||||
|
||||
# Test case 5: Test best_throughput_until_now comparison
|
||||
candidate_result = OptimizedCandidateResult(
|
||||
max_loop_count=5,
|
||||
best_test_runtime=8000, # 20% runtime improvement
|
||||
behavior_test_results=TestResults(),
|
||||
benchmarking_test_results=TestResults(),
|
||||
optimization_candidate_index=0,
|
||||
total_candidate_timing=8000,
|
||||
async_throughput=115, # 15% throughput improvement
|
||||
)
|
||||
|
||||
# Should pass when no best throughput yet
|
||||
assert speedup_critic(
|
||||
candidate_result=candidate_result,
|
||||
original_code_runtime=original_code_runtime,
|
||||
best_runtime_until_now=None,
|
||||
original_async_throughput=original_async_throughput,
|
||||
best_throughput_until_now=None,
|
||||
disable_gh_action_noise=True
|
||||
)
|
||||
|
||||
# Should fail when there's a better throughput already
|
||||
assert not speedup_critic(
|
||||
candidate_result=candidate_result,
|
||||
original_code_runtime=original_code_runtime,
|
||||
best_runtime_until_now=7000, # Better runtime already exists
|
||||
original_async_throughput=original_async_throughput,
|
||||
best_throughput_until_now=120, # Better throughput already exists
|
||||
disable_gh_action_noise=True
|
||||
)
|
||||
|
||||
# Test case 6: Zero original throughput (edge case)
|
||||
candidate_result = OptimizedCandidateResult(
|
||||
max_loop_count=5,
|
||||
best_test_runtime=8000, # 20% runtime improvement
|
||||
behavior_test_results=TestResults(),
|
||||
benchmarking_test_results=TestResults(),
|
||||
optimization_candidate_index=0,
|
||||
total_candidate_timing=8000,
|
||||
async_throughput=50,
|
||||
)
|
||||
|
||||
# Should pass when original throughput is 0 (throughput evaluation skipped)
|
||||
assert speedup_critic(
|
||||
candidate_result=candidate_result,
|
||||
original_code_runtime=original_code_runtime,
|
||||
best_runtime_until_now=None,
|
||||
original_async_throughput=0, # Zero original throughput
|
||||
best_throughput_until_now=None,
|
||||
disable_gh_action_noise=True
|
||||
)
|
||||
|
|
|
|||
793
tests/test_instrument_async_tests.py
Normal file
|
|
@ -0,0 +1,793 @@
|
|||
import tempfile
|
||||
from pathlib import Path
|
||||
import uuid
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.code_utils.instrument_existing_tests import (
|
||||
add_async_decorator_to_function,
|
||||
inject_profiling_into_existing_test,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodePosition, TestingMode
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
"""Create a temporary directory for test files."""
|
||||
with tempfile.TemporaryDirectory() as temp:
|
||||
yield Path(temp)
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def unique_test_iteration():
|
||||
# """Provide a unique test iteration ID and clean up database after test."""
|
||||
# # Generate unique iteration ID
|
||||
# iteration_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# # Store original environment variable
|
||||
# original_iteration = os.environ.get("CODEFLASH_TEST_ITERATION")
|
||||
|
||||
# # Set unique iteration for this test
|
||||
# os.environ["CODEFLASH_TEST_ITERATION"] = iteration_id
|
||||
|
||||
# try:
|
||||
# yield iteration_id
|
||||
# finally:
|
||||
# # Cleanup: restore original environment and delete database file
|
||||
# if original_iteration is not None:
|
||||
# os.environ["CODEFLASH_TEST_ITERATION"] = original_iteration
|
||||
# elif "CODEFLASH_TEST_ITERATION" in os.environ:
|
||||
# del os.environ["CODEFLASH_TEST_ITERATION"]
|
||||
|
||||
# # Clean up database file
|
||||
# try:
|
||||
# from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file
|
||||
|
||||
# db_path = get_run_tmp_file(Path(f"test_return_values_{iteration_id}.sqlite"))
|
||||
# if db_path.exists():
|
||||
# db_path.unlink()
|
||||
# except Exception:
|
||||
# pass # Ignore cleanup errors
|
||||
|
||||
|
||||
def test_async_decorator_application_behavior_mode():
|
||||
async_function_code = '''
|
||||
import asyncio
|
||||
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
expected_decorated_code = '''
|
||||
import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_behavior_async
|
||||
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
|
||||
)
|
||||
|
||||
modified_code, decorator_added = add_async_decorator_to_function(async_function_code, func, TestingMode.BEHAVIOR)
|
||||
|
||||
assert decorator_added
|
||||
assert modified_code.strip() == expected_decorated_code.strip()
|
||||
|
||||
|
||||
def test_async_decorator_application_performance_mode():
|
||||
async_function_code = '''
|
||||
import asyncio
|
||||
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
expected_decorated_code = '''
|
||||
import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_performance_async
|
||||
|
||||
|
||||
@codeflash_performance_async
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
|
||||
)
|
||||
|
||||
modified_code, decorator_added = add_async_decorator_to_function(async_function_code, func, TestingMode.PERFORMANCE)
|
||||
|
||||
assert decorator_added
|
||||
assert modified_code.strip() == expected_decorated_code.strip()
|
||||
|
||||
|
||||
def test_async_class_method_decorator_application():
|
||||
async_class_code = '''
|
||||
import asyncio
|
||||
|
||||
class Calculator:
|
||||
"""Test class with async methods."""
|
||||
|
||||
async def async_method(self, a: int, b: int) -> int:
|
||||
"""Async method in class."""
|
||||
await asyncio.sleep(0.005)
|
||||
return a ** b
|
||||
|
||||
def sync_method(self, a: int, b: int) -> int:
|
||||
"""Sync method in class."""
|
||||
return a - b
|
||||
'''
|
||||
|
||||
expected_decorated_code = '''
|
||||
import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_behavior_async
|
||||
|
||||
|
||||
class Calculator:
|
||||
"""Test class with async methods."""
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_method(self, a: int, b: int) -> int:
|
||||
"""Async method in class."""
|
||||
await asyncio.sleep(0.005)
|
||||
return a ** b
|
||||
|
||||
def sync_method(self, a: int, b: int) -> int:
|
||||
"""Sync method in class."""
|
||||
return a - b
|
||||
'''
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="async_method",
|
||||
file_path=Path("test_async.py"),
|
||||
parents=[{"name": "Calculator", "type": "ClassDef"}],
|
||||
is_async=True,
|
||||
)
|
||||
|
||||
modified_code, decorator_added = add_async_decorator_to_function(async_class_code, func, TestingMode.BEHAVIOR)
|
||||
|
||||
assert decorator_added
|
||||
assert modified_code.strip() == expected_decorated_code.strip()
|
||||
|
||||
|
||||
def test_async_decorator_no_duplicate_application():
|
||||
already_decorated_code = '''
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async
|
||||
import asyncio
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Already decorated async function."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
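# When the decorator is already present it should not be added again; the code is only reformatted, so decorator_added is expected to be False.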
expected_reformatted_code = '''
|
||||
import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_behavior_async
|
||||
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Already decorated async function."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
|
||||
)
|
||||
|
||||
modified_code, decorator_added = add_async_decorator_to_function(already_decorated_code, func, TestingMode.BEHAVIOR)
|
||||
|
||||
assert not decorator_added
|
||||
assert modified_code.strip() == expected_reformatted_code.strip()
|
||||
|
||||
|
||||
def test_inject_profiling_async_function_behavior_mode(temp_dir):
    source_module_code = '''
import asyncio

async def async_function(x: int, y: int) -> int:
    """Simple async function for testing."""
    await asyncio.sleep(0.01)
    return x * y
'''

    source_file = temp_dir / "my_module.py"
    source_file.write_text(source_module_code)

    async_test_code = '''
import asyncio
import pytest
from my_module import async_function

@pytest.mark.asyncio
async def test_async_function():
    """Test async function behavior."""
    result = await async_function(5, 3)
    assert result == 15

    result2 = await async_function(2, 4)
    assert result2 == 8
'''

    test_file = temp_dir / "test_async.py"
    test_file.write_text(async_test_code)

    func = FunctionToOptimize(function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True)

    # First instrument the source module
    from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators

    source_success, instrumented_source = instrument_source_module_with_async_decorators(
        source_file, func, TestingMode.BEHAVIOR
    )

    assert source_success is True
    assert instrumented_source is not None
    assert "@codeflash_behavior_async" in instrumented_source
    assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
    assert "codeflash_behavior_async" in instrumented_source

    source_file.write_text(instrumented_source)

    success, instrumented_test_code = inject_profiling_into_existing_test(
        test_file, [CodePosition(8, 18), CodePosition(11, 19)], func, temp_dir, "pytest", mode=TestingMode.BEHAVIOR
    )

    # For async functions, once source is decorated, test injection should fail
    # This is expected behavior - async instrumentation happens at the decorator level
    assert success is False
    assert instrumented_test_code is None


def test_inject_profiling_async_function_performance_mode(temp_dir):
    source_module_code = '''
import asyncio

async def async_function(x: int, y: int) -> int:
    """Simple async function for testing."""
    await asyncio.sleep(0.01)
    return x * y
'''

    source_file = temp_dir / "my_module.py"
    source_file.write_text(source_module_code)

    # Create the test file
    async_test_code = '''
import asyncio
import pytest
from my_module import async_function

@pytest.mark.asyncio
async def test_async_function():
    """Test async function performance."""
    result = await async_function(5, 3)
    assert result == 15
'''

    test_file = temp_dir / "test_async.py"
    test_file.write_text(async_test_code)

    func = FunctionToOptimize(function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True)

    # First instrument the source module
    from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators

    source_success, instrumented_source = instrument_source_module_with_async_decorators(
        source_file, func, TestingMode.PERFORMANCE
    )

    assert source_success is True
    assert instrumented_source is not None
    assert "@codeflash_performance_async" in instrumented_source
    # Check for the import with line continuation formatting
    assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
    assert "codeflash_performance_async" in instrumented_source

    # Write the instrumented source back
    source_file.write_text(instrumented_source)

    # Now test the full pipeline with source module path
    success, instrumented_test_code = inject_profiling_into_existing_test(
        test_file, [CodePosition(8, 18)], func, temp_dir, "pytest", mode=TestingMode.PERFORMANCE
    )

    # For async functions, once source is decorated, test injection should fail
    # This is expected behavior - async instrumentation happens at the decorator level
    assert success is False
    assert instrumented_test_code is None


def test_mixed_sync_async_instrumentation(temp_dir):
    source_module_code = '''
import asyncio

def sync_function(x: int, y: int) -> int:
    """Regular sync function."""
    return x * y

async def async_function(x: int, y: int) -> int:
    """Simple async function."""
    await asyncio.sleep(0.01)
    return x * y
'''

    source_file = temp_dir / "my_module.py"
    source_file.write_text(source_module_code)

    mixed_test_code = '''
import asyncio
import pytest
from my_module import sync_function, async_function

@pytest.mark.asyncio
async def test_mixed_functions():
    """Test both sync and async functions."""
    sync_result = sync_function(10, 5)
    assert sync_result == 50

    async_result = await async_function(3, 4)
    assert async_result == 12
'''

    test_file = temp_dir / "test_mixed.py"
    test_file.write_text(mixed_test_code)

    async_func = FunctionToOptimize(
        function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True
    )

    from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators

    source_success, instrumented_source = instrument_source_module_with_async_decorators(
        source_file, async_func, TestingMode.BEHAVIOR
    )

    assert source_success
    assert instrumented_source is not None
    assert "@codeflash_behavior_async" in instrumented_source
    assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
    assert "codeflash_behavior_async" in instrumented_source
    # Sync function should remain unchanged
    assert "def sync_function(x: int, y: int) -> int:" in instrumented_source

    # Write instrumented source back
    source_file.write_text(instrumented_source)

    success, instrumented_test_code = inject_profiling_into_existing_test(
        test_file,
        [CodePosition(8, 18), CodePosition(11, 19)],
        async_func,
        temp_dir,
        "pytest",
        mode=TestingMode.BEHAVIOR,
    )

    # Async functions should not be instrumented at the test level
    assert not success
    assert instrumented_test_code is None


def test_async_function_qualified_name_handling():
    nested_async_code = '''
import asyncio

class OuterClass:
    class InnerClass:
        async def nested_async_method(self, x: int) -> int:
            """Nested async method."""
            await asyncio.sleep(0.001)
            return x * 2
'''

    func = FunctionToOptimize(
        function_name="nested_async_method",
        file_path=Path("test_nested.py"),
        parents=[{"name": "OuterClass", "type": "ClassDef"}, {"name": "InnerClass", "type": "ClassDef"}],
        is_async=True,
    )

    modified_code, decorator_added = add_async_decorator_to_function(nested_async_code, func, TestingMode.BEHAVIOR)

    expected_output = (
        """import asyncio

from codeflash.code_utils.codeflash_wrap_decorator import \\
    codeflash_behavior_async


class OuterClass:
    class InnerClass:
        @codeflash_behavior_async
        async def nested_async_method(self, x: int) -> int:
            \"\"\"Nested async method.\"\"\"
            await asyncio.sleep(0.001)
            return x * 2
"""
    )

    assert modified_code.strip() == expected_output.strip()


def test_async_decorator_with_existing_decorators():
    """Test async decorator application when function already has other decorators."""
    decorated_async_code = '''
import asyncio
from functools import wraps

def my_decorator(func):
    @wraps(func)
    async def wrapper(*args, **kwargs):
        return await func(*args, **kwargs)
    return wrapper

@my_decorator
async def async_function(x: int, y: int) -> int:
    """Async function with existing decorator."""
    await asyncio.sleep(0.01)
    return x * y
'''

    func = FunctionToOptimize(
        function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
    )

    modified_code, decorator_added = add_async_decorator_to_function(decorated_async_code, func, TestingMode.BEHAVIOR)

    assert decorator_added
    # Should add codeflash decorator above existing decorators
    assert "@codeflash_behavior_async" in modified_code
    assert "@my_decorator" in modified_code
    # Codeflash decorator should come first
    codeflash_pos = modified_code.find("@codeflash_behavior_async")
    my_decorator_pos = modified_code.find("@my_decorator")
    assert codeflash_pos < my_decorator_pos


def test_sync_function_not_affected_by_async_logic():
    sync_function_code = '''
def sync_function(x: int, y: int) -> int:
    """Regular sync function."""
    return x + y
'''

    sync_func = FunctionToOptimize(
        function_name="sync_function",
        file_path=Path("test_sync.py"),
        parents=[],
        is_async=False,
    )

    modified_code, decorator_added = add_async_decorator_to_function(
        sync_function_code, sync_func, TestingMode.BEHAVIOR
    )

    assert not decorator_added
    assert modified_code == sync_function_code

def test_inject_profiling_async_multiple_calls_same_test(temp_dir):
    """Test that multiple async function calls within the same test function get correctly numbered 0, 1, 2, etc."""
    source_module_code = '''
import asyncio

async def async_sorter(items):
    """Simple async sorter for testing."""
    await asyncio.sleep(0.001)
    return sorted(items)
'''

    source_file = temp_dir / "async_sorter.py"
    source_file.write_text(source_module_code)

    test_code_multiple_calls = """
import asyncio
import pytest
from async_sorter import async_sorter

@pytest.mark.asyncio
async def test_single_call():
    result = await async_sorter([42])
    assert result == [42]

@pytest.mark.asyncio
async def test_multiple_calls():
    result1 = await async_sorter([3, 1, 2])
    result2 = await async_sorter([5, 4])
    result3 = await async_sorter([9, 8, 7, 6])
    assert result1 == [1, 2, 3]
    assert result2 == [4, 5]
    assert result3 == [6, 7, 8, 9]
"""

    test_file = temp_dir / "test_async_sorter.py"
    test_file.write_text(test_code_multiple_calls)

    func = FunctionToOptimize(
        function_name="async_sorter", parents=[], file_path=Path("async_sorter.py"), is_async=True
    )

    # First instrument the source module with async decorators
    from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators

    source_success, instrumented_source = instrument_source_module_with_async_decorators(
        source_file, func, TestingMode.BEHAVIOR
    )

    assert source_success
    assert instrumented_source is not None
    assert "@codeflash_behavior_async" in instrumented_source

    source_file.write_text(instrumented_source)

    import ast

    tree = ast.parse(test_code_multiple_calls)
    call_positions = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Await) and isinstance(node.value, ast.Call):
            if (hasattr(node.value.func, "id") and node.value.func.id == "async_sorter") or (
                hasattr(node.value.func, "attr") and node.value.func.attr == "async_sorter"
            ):
                call_positions.append(CodePosition(node.lineno, node.col_offset))

    assert len(call_positions) == 4

    success, instrumented_test_code = inject_profiling_into_existing_test(
        test_file, call_positions, func, temp_dir, "pytest", mode=TestingMode.BEHAVIOR
    )

    assert success
    assert instrumented_test_code is not None

    assert "os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'" in instrumented_test_code

    # Count occurrences of each line_id to verify numbering
    line_id_0_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'")
    line_id_1_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '1'")
    line_id_2_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '2'")

    assert line_id_0_count == 2, f"Expected 2 occurrences of line_id '0', got {line_id_0_count}"
    assert line_id_1_count == 1, f"Expected 1 occurrence of line_id '1', got {line_id_1_count}"
    assert line_id_2_count == 1, f"Expected 1 occurrence of line_id '2', got {line_id_2_count}"


def test_async_behavior_decorator_return_values_and_test_ids():
    """Test that async behavior decorator correctly captures return values, test IDs, and stores data in database."""
    import asyncio
    import os
    import sqlite3
    from pathlib import Path

    import dill as pickle

    from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async

    @codeflash_behavior_async
    async def test_async_multiply(x: int, y: int) -> int:
        """Simple async function for testing."""
        await asyncio.sleep(0.001) # Small delay to simulate async work
        return x * y

    test_env = {
        "CODEFLASH_TEST_MODULE": "test_module",
        "CODEFLASH_TEST_CLASS": None,
        "CODEFLASH_TEST_FUNCTION": "test_async_multiply_function",
        "CODEFLASH_CURRENT_LINE_ID": "0",
        "CODEFLASH_LOOP_INDEX": "1",
        "CODEFLASH_TEST_ITERATION": "2",
    }

    original_env = {k: os.environ.get(k) for k in test_env}
    for k, v in test_env.items():
        if v is not None:
            os.environ[k] = v
        elif k in os.environ:
            del os.environ[k]

    try:
        result = asyncio.run(test_async_multiply(6, 7))

        assert result == 42, f"Expected return value 42, got {result}"

        from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file

        db_path = get_run_tmp_file(Path(f"test_return_values_2.sqlite"))

        # Verify database exists and has data
        assert db_path.exists(), f"Database file not created at {db_path}"

        # Read and verify database contents
        con = sqlite3.connect(db_path)
        cur = con.cursor()

        cur.execute("SELECT * FROM test_results")
        rows = cur.fetchall()

        assert len(rows) == 1, f"Expected 1 database row, got {len(rows)}"

        row = rows[0]
        (
            test_module,
            test_class,
            test_function,
            function_name,
            loop_index,
            iteration_id,
            runtime,
            return_value_blob,
            verification_type,
        ) = row

        assert test_module == "test_module", f"Expected test_module 'test_module', got '{test_module}'"
        assert test_class is None, f"Expected test_class None, got '{test_class}'"
        assert test_function == "test_async_multiply_function", (
            f"Expected test_function 'test_async_multiply_function', got '{test_function}'"
        )
        assert function_name == "test_async_multiply", (
            f"Expected function_name 'test_async_multiply', got '{function_name}'"
        )
        assert loop_index == 1, f"Expected loop_index 1, got {loop_index}"
        assert iteration_id == "0_0", f"Expected iteration_id '0_0', got '{iteration_id}'"
        assert verification_type == "function_call", (
            f"Expected verification_type 'function_call', got '{verification_type}'"
        )
        unpickled_data = pickle.loads(return_value_blob)
        args, kwargs, actual_return_value = unpickled_data

        assert args == (6, 7), f"Expected args (6, 7), got {args}"
        assert kwargs == {}, f"Expected empty kwargs, got {kwargs}"

        assert actual_return_value == 42, f"Expected stored return value 42, got {actual_return_value}"

        con.close()

    finally:
        for k, v in original_env.items():
            if v is not None:
                os.environ[k] = v
            elif k in os.environ:
                del os.environ[k]


def test_async_decorator_comprehensive_return_values_and_test_ids():
    import asyncio
    import os
    import sqlite3
    from pathlib import Path

    import dill as pickle

    from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async, get_run_tmp_file

    @codeflash_behavior_async
    async def async_multiply_add(x: int, y: int, z: int = 1) -> int:
        """Async function that multiplies x*y then adds z."""
        await asyncio.sleep(0.001)
        result = (x * y) + z
        return result

    test_env = {
        "CODEFLASH_TEST_MODULE": "test_comprehensive_module",
        "CODEFLASH_TEST_CLASS": "AsyncTestClass",
        "CODEFLASH_TEST_FUNCTION": "test_comprehensive_async_function",
        "CODEFLASH_CURRENT_LINE_ID": "3",
        "CODEFLASH_LOOP_INDEX": "2",
        "CODEFLASH_TEST_ITERATION": "3",
    }

    original_env = {k: os.environ.get(k) for k in test_env}
    for k, v in test_env.items():
        if v is not None:
            os.environ[k] = v
        elif k in os.environ:
            del os.environ[k]

    try:
        test_cases = [
            {"args": (5, 3), "kwargs": {}, "expected": 16}, # (5 * 3) + 1 = 16
            {"args": (2, 4), "kwargs": {"z": 10}, "expected": 18}, # (2 * 4) + 10 = 18
            {"args": (7, 6), "kwargs": {}, "expected": 43}, # (7 * 6) + 1 = 43
        ]

        results = []
        for test_case in test_cases:
            result = asyncio.run(async_multiply_add(*test_case["args"], **test_case["kwargs"]))
            results.append(result)

            # Verify each return value is exactly correct
            assert result == test_case["expected"], (
                f"Expected {test_case['expected']}, got {result} for args {test_case['args']}, kwargs {test_case['kwargs']}"
            )

        db_path = get_run_tmp_file(Path(f"test_return_values_3.sqlite"))
        assert db_path.exists(), f"Database not created at {db_path}"

        con = sqlite3.connect(db_path)
        cur = con.cursor()

        cur.execute(
            "SELECT test_module_path, test_class_name, test_function_name, function_getting_tested, loop_index, iteration_id, runtime, return_value, verification_type FROM test_results ORDER BY rowid"
        )
        rows = cur.fetchall()

        assert len(rows) == 3, f"Expected 3 database rows, got {len(rows)}"

        for i, (
            test_module,
            test_class,
            test_function,
            function_name,
            loop_index,
            iteration_id,
            runtime,
            return_value_blob,
            verification_type,
        ) in enumerate(rows):
            assert test_module == "test_comprehensive_module", (
                f"Row {i}: Expected test_module 'test_comprehensive_module', got '{test_module}'"
            )
            assert test_class == "AsyncTestClass", f"Row {i}: Expected test_class 'AsyncTestClass', got '{test_class}'"
            assert test_function == "test_comprehensive_async_function", (
                f"Row {i}: Expected test_function 'test_comprehensive_async_function', got '{test_function}'"
            )
            assert function_name == "async_multiply_add", (
                f"Row {i}: Expected function_name 'async_multiply_add', got '{function_name}'"
            )
            assert loop_index == 2, f"Row {i}: Expected loop_index 2, got {loop_index}"
            assert verification_type == "function_call", (
                f"Row {i}: Expected verification_type 'function_call', got '{verification_type}'"
            )

            expected_iteration_id = f"3_{i}"
            assert iteration_id == expected_iteration_id, (
                f"Row {i}: Expected iteration_id '{expected_iteration_id}', got '{iteration_id}'"
            )

            args, kwargs, actual_return_value = pickle.loads(return_value_blob)
            expected_args = test_cases[i]["args"]
            expected_kwargs = test_cases[i]["kwargs"]
            expected_return = test_cases[i]["expected"]

            assert args == expected_args, f"Row {i}: Expected args {expected_args}, got {args}"
            assert kwargs == expected_kwargs, f"Row {i}: Expected kwargs {expected_kwargs}, got {kwargs}"
            assert actual_return_value == expected_return, (
                f"Row {i}: Expected return value {expected_return}, got {actual_return_value}"
            )

        con.close()

    finally:
        for k, v in original_env.items():
            if v is not None:
                os.environ[k] = v
            elif k in os.environ:
                del os.environ[k]
@@ -6,9 +6,11 @@ from pathlib import Path
import pytest
from codeflash.context.unused_definition_remover import detect_unused_helper_functions
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeStringsMarkdown, FunctionParent
from codeflash.models.models import CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
from codeflash.context.unused_definition_remover import revert_unused_helper_functions



@pytest.fixture
@@ -225,6 +227,15 @@ def helper_function_2(x):
def test_no_unused_helpers_no_revert(temp_project):
    """Test that when all helpers are still used, nothing is reverted."""
    temp_dir, main_file, test_cfg = temp_project


    # Store original content to verify nothing changes
    original_content = main_file.read_text()

    revert_unused_helper_functions(temp_dir, [], {})

    # Verify the file content remains unchanged
    assert main_file.read_text() == original_content, "File should remain unchanged when no helpers to revert"

    # Optimized version that still calls both helpers
    optimized_code = """
@@ -308,17 +319,23 @@ def helper_function_1(x):
def helper_function_2(x):
    \"\"\"Second helper function.\"\"\"
    return x * 3

def helper_function_1(y): # Duplicate name to test line 575
    \"\"\"Overloaded helper function.\"\"\"
    return y + 10
""")

    # Optimized version that only calls one helper
    # Optimized version that only calls one helper with aliased import
    optimized_code = """
```python:main.py
from helpers import helper_function_1
from helpers import helper_function_1 as h1
import helpers as h_module

def entrypoint_function(n):
    \"\"\"Optimized function that only calls one helper.\"\"\"
    result1 = helper_function_1(n)
    return result1 + n * 3 # Inlined helper_function_2
    \"\"\"Optimized function that only calls one helper with aliasing.\"\"\"
    result1 = h1(n) # Using aliased import
    # Inlined helper_function_2 functionality: n * 3
    return result1 + n * 3 # Fully inlined helper_function_2
```
"""

@@ -1462,113 +1479,44 @@ class MathUtils:
        shutil.rmtree(temp_dir, ignore_errors=True)


def test_unused_helper_detection_with_duplicated_function_name_in_different_classes():
    """Test detection when helpers are called via module.function style."""
def test_async_entrypoint_with_async_helpers():
    """Test that unused async helper functions are correctly detected when entrypoint is async."""
    temp_dir = Path(tempfile.mkdtemp())

    try:
        # Main file
        # Main file with async entrypoint and async helpers
        main_file = temp_dir / "main.py"
        main_file.write_text("""from __future__ import annotations
import json
from helpers import replace_quotes_with_backticks, simplify_worktree_paths
from dataclasses import asdict, dataclass
        main_file.write_text("""
async def async_helper_1(x):
    \"\"\"First async helper function.\"\"\"
    return x * 2

@dataclass
class LspMessage:
async def async_helper_2(x):
    \"\"\"Second async helper function.\"\"\"
    return x * 3

    def serialize(self) -> str:
        data = self._loop_through(asdict(self))
        # Important: keep type as the first key, for making it easy and fast for the client to know if this is a lsp message before parsing it
        ordered = {"type": self.type(), **data}
        return (
            message_delimiter
            + json.dumps(ordered)
            + message_delimiter
        )


@dataclass
class LspMarkdownMessage(LspMessage):

    def serialize(self) -> str:
        self.markdown = simplify_worktree_paths(self.markdown)
        self.markdown = replace_quotes_with_backticks(self.markdown)
        return super().serialize()
async def async_entrypoint(n):
    \"\"\"Async entrypoint function that calls async helpers.\"\"\"
    result1 = await async_helper_1(n)
    result2 = await async_helper_2(n)
    return result1 + result2
""")

        # Helpers file
        helpers_file = temp_dir / "helpers.py"
        helpers_file.write_text("""def simplify_worktree_paths(msg: str, highlight: bool = True) -> str: # noqa: FBT001, FBT002
    path_in_msg = worktree_path_regex.search(msg)
    if path_in_msg:
        last_part_of_path = path_in_msg.group(0).split("/")[-1]
        if highlight:
            last_part_of_path = f"`{last_part_of_path}`"
        return msg.replace(path_in_msg.group(0), last_part_of_path)
    return msg


def replace_quotes_with_backticks(text: str) -> str:
    # double-quoted strings
    text = _double_quote_pat.sub(r"`\1`", text)
    # single-quoted strings
    return _single_quote_pat.sub(r"`\1`", text)
""")

        # Optimized version that only uses add_numbers
        # Optimized version that only calls one async helper
        optimized_code = """
```python:main.py
from __future__ import annotations
async def async_helper_1(x):
    \"\"\"First async helper function.\"\"\"
    return x * 2

import json
from dataclasses import asdict, dataclass
async def async_helper_2(x):
    \"\"\"Second async helper function - should be unused.\"\"\"
    return x * 3

from codeflash.lsp.helpers import (replace_quotes_with_backticks,
                                   simplify_worktree_paths)


@dataclass
class LspMessage:

    def serialize(self) -> str:
        # Use local variable to minimize lookup costs and avoid unnecessary dictionary unpacking
        data = self._loop_through(asdict(self))
        msg_type = self.type()
        ordered = {'type': msg_type}
        ordered.update(data)
        return (
            message_delimiter
            + json.dumps(ordered)
            + message_delimiter # \u241F is the message delimiter becuase it can be more than one message sent over the same message, so we need something to separate each message
        )

@dataclass
class LspMarkdownMessage(LspMessage):

    def serialize(self) -> str:
        # Side effect required, must preserve for behavioral correctness
        self.markdown = simplify_worktree_paths(self.markdown)
        self.markdown = replace_quotes_with_backticks(self.markdown)
        return super().serialize()
```
```python:helpers.py
def simplify_worktree_paths(msg: str, highlight: bool = True) -> str: # noqa: FBT001, FBT002
    m = worktree_path_regex.search(msg)
    if m:
        # More efficient way to get last path part
        last_part_of_path = m.group(0).rpartition('/')[-1]
        if highlight:
            last_part_of_path = f"`{last_part_of_path}`"
        return msg.replace(m.group(0), last_part_of_path)
    return msg

def replace_quotes_with_backticks(text: str) -> str:
    # Efficient string substitution, reduces intermediate string allocations
    return _single_quote_pat.sub(
        r"`\1`",
        _double_quote_pat.sub(r"`\1`", text),
    )
async def async_entrypoint(n):
    \"\"\"Optimized async entrypoint that only calls one helper.\"\"\"
    result1 = await async_helper_1(n)
    return result1 + n * 3 # Inlined async_helper_2
```
"""

@@ -1581,31 +1529,543 @@ def replace_quotes_with_backticks(text: str) -> str:
            pytest_cmd="pytest",
        )

        # Create FunctionToOptimize instance
        # Create FunctionToOptimize instance for async function
        function_to_optimize = FunctionToOptimize(
            file_path=main_file, function_name="serialize", qualified_name="serialize", parents=[
                FunctionParent(name="LspMarkdownMessage", type="ClassDef"),
            ]
            file_path=main_file,
            function_name="async_entrypoint",
            parents=[],
            is_async=True
        )

        # Create function optimizer
        optimizer = FunctionOptimizer(
            function_to_optimize=function_to_optimize,
            test_cfg=test_cfg,
            function_to_optimize_source_code=main_file.read_text(),
        )

        # Get original code context
        ctx_result = optimizer.get_code_optimization_context()
        assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"

        code_context = ctx_result.unwrap()

        # Test unused helper detection
        unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))

        # Should detect async_helper_2 as unused
        unused_names = {uh.qualified_name for uh in unused_helpers}
        assert len(unused_names) == 0 # no unused helpers
        expected_unused = {"async_helper_2"}

        assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"

    finally:
        # Cleanup
        import shutil
        shutil.rmtree(temp_dir, ignore_errors=True)


def test_sync_entrypoint_with_async_helpers():
    """Test that unused async helper functions are detected when entrypoint is sync."""
    temp_dir = Path(tempfile.mkdtemp())

    try:
        # Main file with sync entrypoint and async helpers
        main_file = temp_dir / "main.py"
        main_file.write_text("""
import asyncio

async def async_helper_1(x):
    \"\"\"First async helper function.\"\"\"
    return x * 2

async def async_helper_2(x):
    \"\"\"Second async helper function.\"\"\"
    return x * 3

def sync_entrypoint(n):
    \"\"\"Sync entrypoint function that calls async helpers.\"\"\"
    result1 = asyncio.run(async_helper_1(n))
    result2 = asyncio.run(async_helper_2(n))
    return result1 + result2
""")

        # Optimized version that only calls one async helper
        optimized_code = """
```python:main.py
import asyncio

async def async_helper_1(x):
    \"\"\"First async helper function.\"\"\"
    return x * 2

async def async_helper_2(x):
    \"\"\"Second async helper function - should be unused.\"\"\"
    return x * 3

def sync_entrypoint(n):
    \"\"\"Optimized sync entrypoint that only calls one async helper.\"\"\"
    result1 = asyncio.run(async_helper_1(n))
    return result1 + n * 3 # Inlined async_helper_2
```
"""

        # Create test config
        test_cfg = TestConfig(
            tests_root=temp_dir / "tests",
            tests_project_rootdir=temp_dir,
            project_root_path=temp_dir,
            test_framework="pytest",
            pytest_cmd="pytest",
        )

        # Create FunctionToOptimize instance for sync function
        function_to_optimize = FunctionToOptimize(
            file_path=main_file,
            function_name="sync_entrypoint",
            parents=[]
        )

        # Create function optimizer
        optimizer = FunctionOptimizer(
            function_to_optimize=function_to_optimize,
            test_cfg=test_cfg,
            function_to_optimize_source_code=main_file.read_text(),
        )

        # Get original code context
        ctx_result = optimizer.get_code_optimization_context()
        assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"

        code_context = ctx_result.unwrap()

        # Test unused helper detection
        unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))

        # Should detect async_helper_2 as unused
        unused_names = {uh.qualified_name for uh in unused_helpers}
        expected_unused = {"async_helper_2"}

        assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"

    finally:
        # Cleanup
        import shutil
        shutil.rmtree(temp_dir, ignore_errors=True)


def test_mixed_sync_and_async_helpers():
    """Test detection when both sync and async helpers are mixed."""
    temp_dir = Path(tempfile.mkdtemp())

    try:
        # Main file with mixed sync and async helpers
        main_file = temp_dir / "main.py"
        main_file.write_text("""
import asyncio

def sync_helper_1(x):
    \"\"\"Sync helper function.\"\"\"
    return x * 2

async def async_helper_1(x):
    \"\"\"Async helper function.\"\"\"
    return x * 3

def sync_helper_2(x):
    \"\"\"Another sync helper function.\"\"\"
    return x * 4

async def async_helper_2(x):
    \"\"\"Another async helper function.\"\"\"
    return x * 5

async def mixed_entrypoint(n):
    \"\"\"Async entrypoint function that calls both sync and async helpers.\"\"\"
    sync_result = sync_helper_1(n)
    async_result = await async_helper_1(n)
    sync_result2 = sync_helper_2(n)
    async_result2 = await async_helper_2(n)
    return sync_result + async_result + sync_result2 + async_result2
""")

        # Optimized version that only calls some helpers
        optimized_code = """
```python:main.py
import asyncio

def sync_helper_1(x):
    \"\"\"Sync helper function.\"\"\"
    return x * 2

async def async_helper_1(x):
    \"\"\"Async helper function.\"\"\"
    return x * 3

def sync_helper_2(x):
    \"\"\"Another sync helper function - should be unused.\"\"\"
    return x * 4

async def async_helper_2(x):
    \"\"\"Another async helper function - should be unused.\"\"\"
    return x * 5

async def mixed_entrypoint(n):
    \"\"\"Optimized async entrypoint that only calls some helpers.\"\"\"
    sync_result = sync_helper_1(n)
    async_result = await async_helper_1(n)
    return sync_result + async_result + n * 4 + n * 5 # Inlined both helper_2 functions
```
"""

        # Create test config
        test_cfg = TestConfig(
            tests_root=temp_dir / "tests",
            tests_project_rootdir=temp_dir,
            project_root_path=temp_dir,
            test_framework="pytest",
            pytest_cmd="pytest",
        )

        # Create FunctionToOptimize instance for async function
        function_to_optimize = FunctionToOptimize(
            file_path=main_file,
            function_name="mixed_entrypoint",
            parents=[],
            is_async=True
        )

        # Create function optimizer
        optimizer = FunctionOptimizer(
            function_to_optimize=function_to_optimize,
            test_cfg=test_cfg,
            function_to_optimize_source_code=main_file.read_text(),
        )

        # Get original code context
        ctx_result = optimizer.get_code_optimization_context()
        assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"

        code_context = ctx_result.unwrap()

        # Test unused helper detection
        unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))

        # Should detect both sync_helper_2 and async_helper_2 as unused
        unused_names = {uh.qualified_name for uh in unused_helpers}
        expected_unused = {"sync_helper_2", "async_helper_2"}

        assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"

    finally:
        # Cleanup
        import shutil
        shutil.rmtree(temp_dir, ignore_errors=True)


def test_async_class_methods():
    """Test unused async method detection in classes."""
    temp_dir = Path(tempfile.mkdtemp())

    try:
        # Main file with class containing async methods
        main_file = temp_dir / "main.py"
        main_file.write_text("""
class AsyncProcessor:
    async def entrypoint_method(self, n):
        \"\"\"Async main method that calls async helper methods.\"\"\"
        result1 = await self.async_helper_method_1(n)
        result2 = await self.async_helper_method_2(n)
        return result1 + result2

    async def async_helper_method_1(self, x):
        \"\"\"First async helper method.\"\"\"
        return x * 2

    async def async_helper_method_2(self, x):
        \"\"\"Second async helper method.\"\"\"
        return x * 3

    def sync_helper_method(self, x):
        \"\"\"Sync helper method.\"\"\"
        return x * 4
""")

        # Optimized version that only calls one async helper
        optimized_code = """
```python:main.py
class AsyncProcessor:
    async def entrypoint_method(self, n):
        \"\"\"Optimized async method that only calls one helper.\"\"\"
        result1 = await self.async_helper_method_1(n)
        return result1 + n * 3 # Inlined async_helper_method_2

    async def async_helper_method_1(self, x):
        \"\"\"First async helper method.\"\"\"
        return x * 2

    async def async_helper_method_2(self, x):
        \"\"\"Second async helper method - should be unused.\"\"\"
        return x * 3

    def sync_helper_method(self, x):
        \"\"\"Sync helper method - should be unused.\"\"\"
        return x * 4
```
"""

        # Create test config
        test_cfg = TestConfig(
            tests_root=temp_dir / "tests",
            tests_project_rootdir=temp_dir,
            project_root_path=temp_dir,
            test_framework="pytest",
            pytest_cmd="pytest",
        )

        # Create FunctionToOptimize instance for async class method
        from codeflash.models.models import FunctionParent

        function_to_optimize = FunctionToOptimize(
            file_path=main_file,
            function_name="entrypoint_method",
            parents=[FunctionParent(name="AsyncProcessor", type="ClassDef")],
            is_async=True
        )

        # Create function optimizer
        optimizer = FunctionOptimizer(
            function_to_optimize=function_to_optimize,
            test_cfg=test_cfg,
            function_to_optimize_source_code=main_file.read_text(),
        )

        # Get original code context
        ctx_result = optimizer.get_code_optimization_context()
        assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"

        code_context = ctx_result.unwrap()

        # Test unused helper detection
        unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))

        # Should detect async_helper_method_2 as unused (sync_helper_method may not be discovered as helper)
        unused_names = {uh.qualified_name for uh in unused_helpers}
        expected_unused = {"AsyncProcessor.async_helper_method_2"}

        assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"

    finally:
        # Cleanup
        import shutil
        shutil.rmtree(temp_dir, ignore_errors=True)


def test_async_helper_revert_functionality():
    """Test that unused async helper functions are correctly reverted to original definitions."""
    temp_dir = Path(tempfile.mkdtemp())

    try:
        # Main file with async functions
        main_file = temp_dir / "main.py"
        main_file.write_text("""
async def async_helper_1(x):
    \"\"\"First async helper function.\"\"\"
    return x * 2

async def async_helper_2(x):
    \"\"\"Second async helper function.\"\"\"
    return x * 3

async def async_entrypoint(n):
    \"\"\"Async entrypoint function that calls async helpers.\"\"\"
    result1 = await async_helper_1(n)
    result2 = await async_helper_2(n)
    return result1 + result2
""")

        # Optimized version that only calls one helper and modifies the unused one
        optimized_code = """
```python:main.py
async def async_helper_1(x):
    \"\"\"First async helper function.\"\"\"
    return x * 2

async def async_helper_2(x):
    \"\"\"Modified async helper function - should be reverted.\"\"\"
    return x * 10 # This change should be reverted

async def async_entrypoint(n):
    \"\"\"Optimized async entrypoint that only calls one helper.\"\"\"
    result1 = await async_helper_1(n)
    return result1 + n * 3 # Inlined async_helper_2
```
"""

        # Create test config
        test_cfg = TestConfig(
            tests_root=temp_dir / "tests",
            tests_project_rootdir=temp_dir,
            project_root_path=temp_dir,
            test_framework="pytest",
            pytest_cmd="pytest",
        )

        # Create FunctionToOptimize instance for async function
        function_to_optimize = FunctionToOptimize(
            file_path=main_file,
            function_name="async_entrypoint",
            parents=[],
            is_async=True
        )

        # Create function optimizer
        optimizer = FunctionOptimizer(
            function_to_optimize=function_to_optimize,
            test_cfg=test_cfg,
            function_to_optimize_source_code=main_file.read_text(),
        )

        # Get original code context
        ctx_result = optimizer.get_code_optimization_context()
        assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"

        code_context = ctx_result.unwrap()

        # Store original helper code
        original_helper_code = {main_file: main_file.read_text()}

        # Apply optimization and test reversion
        optimizer.replace_function_and_helpers_with_optimized_code(
            code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code
        )

        # Check final file content
        final_content = main_file.read_text()

        # The entrypoint should be optimized
        assert "result1 + n * 3" in final_content, "Async entrypoint function should be optimized"

        # async_helper_2 should be reverted to original (return x * 3, not x * 10)
        assert "return x * 3" in final_content, "async_helper_2 should be reverted to original"
        assert "return x * 10" not in final_content, "async_helper_2 should not contain the modified version"

        # async_helper_1 should remain (it's still called)
        assert "async def async_helper_1(x):" in final_content, "async_helper_1 should still exist"

    finally:
        # Cleanup
        import shutil
        shutil.rmtree(temp_dir, ignore_errors=True)


def test_async_generators_and_coroutines():
    """Test detection with async generators and coroutines."""
    temp_dir = Path(tempfile.mkdtemp())

    try:
        # Main file with async generators and coroutines
        main_file = temp_dir / "main.py"
        main_file.write_text("""
import asyncio

async def async_generator_helper(n):
    \"\"\"Async generator helper.\"\"\"
    for i in range(n):
        yield i * 2

async def coroutine_helper(x):
    \"\"\"Coroutine helper.\"\"\"
    await asyncio.sleep(0.1)
    return x * 3

async def another_coroutine_helper(x):
    \"\"\"Another coroutine helper.\"\"\"
    await asyncio.sleep(0.1)
    return x * 4

async def async_entrypoint_with_generators(n):
    \"\"\"Async entrypoint function that uses generators and coroutines.\"\"\"
    results = []
    async for value in async_generator_helper(n):
        results.append(value)

    final_result = await coroutine_helper(sum(results))
    another_result = await another_coroutine_helper(n)
    return final_result + another_result
""")

        # Optimized version that doesn't use one of the coroutines
        optimized_code = """
```python:main.py
import asyncio

async def async_generator_helper(n):
    \"\"\"Async generator helper.\"\"\"
    for i in range(n):
        yield i * 2

async def coroutine_helper(x):
    \"\"\"Coroutine helper.\"\"\"
    await asyncio.sleep(0.1)
    return x * 3

async def another_coroutine_helper(x):
    \"\"\"Another coroutine helper - should be unused.\"\"\"
    await asyncio.sleep(0.1)
    return x * 4

async def async_entrypoint_with_generators(n):
    \"\"\"Optimized async entrypoint that inlines one coroutine.\"\"\"
    results = []
    async for value in async_generator_helper(n):
        results.append(value)

    final_result = await coroutine_helper(sum(results))
    return final_result + n * 4 # Inlined another_coroutine_helper
```
"""

        # Create test config
        test_cfg = TestConfig(
            tests_root=temp_dir / "tests",
            tests_project_rootdir=temp_dir,
            project_root_path=temp_dir,
            test_framework="pytest",
            pytest_cmd="pytest",
        )

        # Create FunctionToOptimize instance for async function
        function_to_optimize = FunctionToOptimize(
            file_path=main_file,
            function_name="async_entrypoint_with_generators",
            parents=[],
            is_async=True
        )

        # Create function optimizer
        optimizer = FunctionOptimizer(
            function_to_optimize=function_to_optimize,
            test_cfg=test_cfg,
            function_to_optimize_source_code=main_file.read_text(),
        )

        # Get original code context
        ctx_result = optimizer.get_code_optimization_context()
        assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"

        code_context = ctx_result.unwrap()

        # Test unused helper detection
        unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))

        # Should detect another_coroutine_helper as unused
        unused_names = {uh.qualified_name for uh in unused_helpers}
        expected_unused = {"another_coroutine_helper"}

        assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"

    finally:
        # Cleanup
        import shutil

        shutil.rmtree(temp_dir, ignore_errors=True)