Merge branch 'main' of github.com:codeflash-ai/codeflash into testgen/multi-files

This commit is contained in:
ali 2025-10-08 01:40:56 +03:00
commit 1c7c2b88ba
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
86 changed files with 7953 additions and 2191 deletions

69
.github/workflows/e2e-async.yaml vendored Normal file
View file

@ -0,0 +1,69 @@
name: E2E - Async
on:
pull_request:
paths:
- '**' # Trigger for all paths
workflow_dispatch:
jobs:
async-optimization:
# Dynamically determine if environment is needed only when workflow files change and contributor is external
environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }}
runs-on: ubuntu-latest
env:
CODEFLASH_AIS_SERVER: prod
POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }}
CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
COLUMNS: 110
MAX_RETRIES: 3
RETRY_DELAY: 5
EXPECTED_IMPROVEMENT_PCT: 10
CODEFLASH_END_TO_END: 1
steps:
- name: 🛎️ Checkout
uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.ref }}
repository: ${{ github.event.pull_request.head.repo.full_name }}
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Validate PR
run: |
# Check for any workflow changes
if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" | grep -q "^.github/workflows/"; then
echo "⚠️ Workflow changes detected."
# Get the PR author
AUTHOR="${{ github.event.pull_request.user.login }}"
echo "PR Author: $AUTHOR"
# Allowlist check
if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
echo "✅ Authorized user ($AUTHOR). Proceeding."
elif [[ "${{ github.event.pull_request.state }}" == "open" ]]; then
echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
else
echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
exit 1
fi
else
echo "✅ No workflow file changes detected. Proceeding."
fi
- name: Set up Python 3.11 for CLI
uses: astral-sh/setup-uv@v5
with:
python-version: 3.11.6
- name: Install dependencies (CLI)
run: |
uv sync
- name: Run Codeflash to optimize async code
id: optimize_async_code
run: |
uv run python tests/scripts/end_to_end_test_async.py

View file

@ -20,7 +20,7 @@ jobs:
COLUMNS: 110
MAX_RETRIES: 3
RETRY_DELAY: 5
EXPECTED_IMPROVEMENT_PCT: 300
EXPECTED_IMPROVEMENT_PCT: 70
CODEFLASH_END_TO_END: 1
steps:
- name: 🛎️ Checkout

View file

@ -20,7 +20,7 @@ jobs:
COLUMNS: 110
MAX_RETRIES: 3
RETRY_DELAY: 5
EXPECTED_IMPROVEMENT_PCT: 300
EXPECTED_IMPROVEMENT_PCT: 40
CODEFLASH_END_TO_END: 1
steps:
- name: 🛎️ Checkout

View file

@ -24,7 +24,6 @@ jobs:
uses: astral-sh/setup-uv@v5
with:
python-version: ${{ matrix.python-version }}
version: "0.5.30"
- name: install dependencies
run: uv sync

View file

@ -0,0 +1,30 @@
name: windows-unit-tests
on:
push:
branches: [main]
pull_request:
workflow_dispatch:
jobs:
windows-unit-tests:
continue-on-error: true
runs-on: windows-latest
env:
PYTHONIOENCODING: utf-8
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Install uv
uses: astral-sh/setup-uv@v5
with:
python-version: "3.13"
- name: install dependencies
run: uv sync
- name: Unit tests
run: uv run pytest tests/

View file

@ -0,0 +1,43 @@
import asyncio
from typing import List, Union
async def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:
"""
Async bubble sort implementation for testing.
"""
print("codeflash stdout: Async sorting list")
await asyncio.sleep(0.01)
n = len(lst)
for i in range(n):
for j in range(0, n - i - 1):
if lst[j] > lst[j + 1]:
lst[j], lst[j + 1] = lst[j + 1], lst[j]
result = lst.copy()
print(f"result: {result}")
return result
class AsyncBubbleSorter:
"""Class with async sorting method for testing."""
async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:
"""
Async bubble sort implementation within a class.
"""
print("codeflash stdout: AsyncBubbleSorter.sorter() called")
# Add some async delay
await asyncio.sleep(0.005)
n = len(lst)
for i in range(n):
for j in range(0, n - i - 1):
if lst[j] > lst[j + 1]:
lst[j], lst[j + 1] = lst[j + 1], lst[j]
result = lst.copy()
return result

View file

@ -0,0 +1,16 @@
import time
import asyncio
async def retry_with_backoff(func, max_retries=3):
if max_retries < 1:
raise ValueError("max_retries must be at least 1")
last_exception = None
for attempt in range(max_retries):
try:
return await func()
except Exception as e:
last_exception = e
if attempt < max_retries - 1:
time.sleep(0.0001 * attempt)
raise last_exception

View file

@ -0,0 +1,6 @@
[tool.codeflash]
disable-telemetry = true
formatter-cmds = ["ruff check --exit-zero --fix $file", "ruff format $file"]
module-root = "."
test-framework = "pytest"
tests-root = "tests"

View file

@ -0,0 +1,8 @@
[tool.codeflash]
# All paths are relative to this pyproject.toml's directory.
module-root = "src/app"
tests-root = "src/tests"
test-framework = "pytest"
ignore-paths = []
disable-telemetry = true
formatter-cmds = ["disabled"]

View file

@ -0,0 +1,10 @@
def sorter(arr):
print("codeflash stdout: Sorting list")
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
print(f"result: {arr}")
return arr

View file

@ -102,6 +102,8 @@ class AiServiceClient:
trace_id: str,
num_candidates: int = 10,
experiment_metadata: ExperimentMetadata | None = None,
*,
is_async: bool = False,
) -> list[OptimizedCandidate]:
"""Optimize the given python code for performance by making a request to the Django endpoint.
@ -133,6 +135,7 @@ class AiServiceClient:
"repo_owner": git_repo_owner,
"repo_name": git_repo_name,
"n_candidates": N_CANDIDATES_EFFECTIVE,
"is_async": is_async,
}
logger.info("!lsp|Generating optimized candidates…")
@ -299,6 +302,9 @@ class AiServiceClient:
annotated_tests: str,
optimization_id: str,
original_explanation: str,
original_throughput: str | None = None,
optimized_throughput: str | None = None,
throughput_improvement: str | None = None,
) -> str:
"""Optimize the given python code for performance by making a request to the Django endpoint.
@ -315,6 +321,9 @@ class AiServiceClient:
- annotated_tests: str - test functions annotated with runtime
- optimization_id: str - unique id of opt candidate
- original_explanation: str - original_explanation generated for the opt candidate
- original_throughput: str | None - throughput for the baseline code (operations per second)
- optimized_throughput: str | None - throughput for the optimized code (operations per second)
- throughput_improvement: str | None - throughput improvement percentage
Returns
-------
@ -334,6 +343,9 @@ class AiServiceClient:
"optimization_id": optimization_id,
"original_explanation": original_explanation,
"dependency_code": dependency_code,
"original_throughput": original_throughput,
"optimized_throughput": optimized_throughput,
"throughput_improvement": throughput_improvement,
}
logger.info("loading|Generating explanation")
console.rule()
@ -488,6 +500,7 @@ class AiServiceClient:
"test_index": test_index,
"python_version": platform.python_version(),
"codeflash_version": codeflash_version,
"is_async": function_to_optimize.is_async,
}
try:
response = self.make_ai_service_request("/testgen", payload=payload, timeout=600)

View file

@ -26,7 +26,7 @@ if TYPE_CHECKING:
from packaging import version
if os.environ.get("CODEFLASH_CFAPI_SERVER", default="prod").lower() == "local":
if os.environ.get("CODEFLASH_CFAPI_SERVER", "prod").lower() == "local":
CFAPI_BASE_URL = "http://localhost:3001"
logger.info(f"Using local CF API at {CFAPI_BASE_URL}.")
console.rule()

View file

@ -4,6 +4,7 @@ import pickle
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any, Callable
from codeflash.picklepatch.pickle_patcher import PicklePatcher
@ -143,12 +144,13 @@ class CodeflashTrace:
print("Pickle limit reached")
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
normalized_file_path = Path(func.__code__.co_filename).as_posix()
self.function_calls_data.append(
(
func.__name__,
class_name,
func.__module__,
func.__code__.co_filename,
normalized_file_path,
benchmark_function_name,
benchmark_module_path,
benchmark_line_number,
@ -169,12 +171,13 @@ class CodeflashTrace:
# Add to the list of function calls without pickled args. Used for timing info only
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
normalized_file_path = Path(func.__code__.co_filename).as_posix()
self.function_calls_data.append(
(
func.__name__,
class_name,
func.__module__,
func.__code__.co_filename,
normalized_file_path,
benchmark_function_name,
benchmark_module_path,
benchmark_line_number,
@ -192,12 +195,13 @@ class CodeflashTrace:
# Add to the list of function calls with pickled args, to be used for replay tests
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
normalized_file_path = Path(func.__code__.co_filename).as_posix()
self.function_calls_data.append(
(
func.__name__,
class_name,
func.__module__,
func.__code__.co_filename,
normalized_file_path,
benchmark_function_name,
benchmark_module_path,
benchmark_line_number,

View file

@ -30,19 +30,24 @@ def get_next_arg_and_return(
cur = db.cursor()
limit = num_to_get
normalized_file_path = Path(file_path).as_posix()
if class_name is not None:
cursor = cur.execute(
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? LIMIT ?",
(benchmark_function_name, function_name, file_path, class_name, limit),
(benchmark_function_name, function_name, normalized_file_path, class_name, limit),
)
else:
cursor = cur.execute(
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?",
(benchmark_function_name, function_name, file_path, limit),
(benchmark_function_name, function_name, normalized_file_path, limit),
)
while (val := cursor.fetchone()) is not None:
yield val[9], val[10] # pickled_args, pickled_kwargs
try:
while (val := cursor.fetchone()) is not None:
yield val[9], val[10] # pickled_args, pickled_kwargs
finally:
db.close()
def get_function_alias(module: str, function_name: str) -> str:
@ -166,7 +171,7 @@ trace_file_path = r"{trace_file}"
module_name = func.get("module_name")
function_name = func.get("function_name")
class_name = func.get("class_name")
file_path = func.get("file_path")
file_path = Path(func.get("file_path")).as_posix()
benchmark_function_name = func.get("benchmark_function_name")
function_properties = func.get("function_properties")
if not class_name:

View file

@ -1,3 +1,4 @@
import importlib.util
import logging
import sys
from argparse import SUPPRESS, ArgumentParser, Namespace
@ -96,6 +97,12 @@ def parse_args() -> Namespace:
)
parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization")
parser.add_argument(
"--async",
default=False,
action="store_true",
help="Enable optimization of async functions. By default, async functions are excluded from optimization.",
)
args, unknown_args = parser.parse_known_args()
sys.argv[:] = [sys.argv[0], *unknown_args]
@ -139,6 +146,14 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
if env_utils.is_ci():
args.no_pr = True
if getattr(args, "async", False) and importlib.util.find_spec("pytest_asyncio") is None:
logger.warning(
"Warning: The --async flag requires pytest-asyncio to be installed.\n"
"Please install it using:\n"
' pip install "codeflash[asyncio]"'
)
raise SystemExit(1)
return args

View file

@ -2,7 +2,6 @@ from __future__ import annotations
import datetime
import json
import sys
import time
import uuid
from pathlib import Path
@ -11,13 +10,16 @@ from typing import TYPE_CHECKING, Any, Optional
from rich.prompt import Confirm
from codeflash.cli_cmds.console import console
from codeflash.code_utils.compat import codeflash_temp_dir
if TYPE_CHECKING:
import argparse
class CodeflashRunCheckpoint:
def __init__(self, module_root: Path, checkpoint_dir: Path = Path("/tmp")) -> None: # noqa: S108
def __init__(self, module_root: Path, checkpoint_dir: Path | None = None) -> None:
if checkpoint_dir is None:
checkpoint_dir = codeflash_temp_dir
self.module_root = module_root
self.checkpoint_dir = Path(checkpoint_dir)
# Create a unique checkpoint file name
@ -37,7 +39,7 @@ class CodeflashRunCheckpoint:
"last_updated": time.time(),
}
with self.checkpoint_path.open("w") as f:
with self.checkpoint_path.open("w", encoding="utf-8") as f:
f.write(json.dumps(metadata) + "\n")
def add_function_to_checkpoint(
@ -66,7 +68,7 @@ class CodeflashRunCheckpoint:
**additional_info,
}
with self.checkpoint_path.open("a") as f:
with self.checkpoint_path.open("a", encoding="utf-8") as f:
f.write(json.dumps(function_data) + "\n")
# Update the metadata last_updated timestamp
@ -75,7 +77,7 @@ class CodeflashRunCheckpoint:
def _update_metadata_timestamp(self) -> None:
"""Update the last_updated timestamp in the metadata."""
# Read the first line (metadata)
with self.checkpoint_path.open() as f:
with self.checkpoint_path.open(encoding="utf-8") as f:
metadata = json.loads(f.readline())
rest_content = f.read()
@ -84,7 +86,7 @@ class CodeflashRunCheckpoint:
# Write all lines to a temporary file
with self.checkpoint_path.open("w") as f:
with self.checkpoint_path.open("w", encoding="utf-8") as f:
f.write(json.dumps(metadata) + "\n")
f.write(rest_content)
@ -94,7 +96,7 @@ class CodeflashRunCheckpoint:
self.checkpoint_path.unlink(missing_ok=True)
for file in self.checkpoint_dir.glob("codeflash_checkpoint_*.jsonl"):
with file.open() as f:
with file.open(encoding="utf-8") as f:
# Skip the first line (metadata)
first_line = next(f)
metadata = json.loads(first_line)
@ -116,7 +118,7 @@ def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dic
to_delete = []
for file in checkpoint_dir.glob("codeflash_checkpoint_*.jsonl"):
with file.open() as f:
with file.open(encoding="utf-8") as f:
# Skip the first line (metadata)
first_line = next(f)
metadata = json.loads(first_line)
@ -139,8 +141,8 @@ def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dic
def ask_should_use_checkpoint_get_functions(args: argparse.Namespace) -> Optional[dict[str, dict[str, str]]]:
previous_checkpoint_functions = None
if args.all and (sys.platform == "linux" or sys.platform == "darwin") and Path("/tmp").is_dir(): # noqa: S108 #TODO: use the temp dir from codeutils-compat.py
previous_checkpoint_functions = get_all_historical_functions(args.module_root, Path("/tmp")) # noqa: S108
if args.all and codeflash_temp_dir.is_dir():
previous_checkpoint_functions = get_all_historical_functions(args.module_root, codeflash_temp_dir)
if previous_checkpoint_functions and Confirm.ask(
"Previous Checkpoint detected from an incomplete optimization run, shall I continue the optimization from that point?",
default=True,

View file

@ -272,6 +272,8 @@ class DottedImportCollector(cst.CSTVisitor):
if child.module is None:
continue
module = self.get_full_dotted_name(child.module)
if isinstance(child.names, cst.ImportStar):
continue
for alias in child.names:
if isinstance(alias, cst.ImportAlias):
name = alias.name.value
@ -414,6 +416,73 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
return transformed_module.code
def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
try:
module_path = module_name.replace(".", "/")
possible_paths = [project_root / f"{module_path}.py", project_root / f"{module_path}/__init__.py"]
module_file = None
for path in possible_paths:
if path.exists():
module_file = path
break
if module_file is None:
logger.warning(f"Could not find module file for {module_name}, skipping star import resolution")
return set()
with module_file.open(encoding="utf8") as f:
module_code = f.read()
tree = ast.parse(module_code)
all_names = None
for node in ast.walk(tree):
if (
isinstance(node, ast.Assign)
and len(node.targets) == 1
and isinstance(node.targets[0], ast.Name)
and node.targets[0].id == "__all__"
):
if isinstance(node.value, (ast.List, ast.Tuple)):
all_names = []
for elt in node.value.elts:
if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
all_names.append(elt.value)
elif isinstance(elt, ast.Str): # Python < 3.8 compatibility
all_names.append(elt.s)
break
if all_names is not None:
return set(all_names)
public_names = set()
for node in tree.body:
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
if not node.name.startswith("_"):
public_names.add(node.name)
elif isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Name) and not target.id.startswith("_"):
public_names.add(target.id)
elif isinstance(node, ast.AnnAssign):
if isinstance(node.target, ast.Name) and not node.target.id.startswith("_"):
public_names.add(node.target.id)
elif isinstance(node, ast.Import) or (
isinstance(node, ast.ImportFrom) and not any(alias.name == "*" for alias in node.names)
):
for alias in node.names:
name = alias.asname or alias.name
if not name.startswith("_"):
public_names.add(name)
return public_names # noqa: TRY300
except Exception as e:
logger.warning(f"Error resolving star import for {module_name}: {e}")
return set()
def add_needed_imports_from_module(
src_module_code: str,
dst_module_code: str,
@ -468,9 +537,23 @@ def add_needed_imports_from_module(
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
):
continue # Skip adding imports for helper functions already in the context
if f"{mod}.{obj}" not in dotted_import_collector.imports:
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
# Handle star imports by resolving them to actual symbol names
if obj == "*":
resolved_symbols = resolve_star_import(mod, project_root)
logger.debug(f"Resolved star import from {mod}: {resolved_symbols}")
for symbol in resolved_symbols:
if (
f"{mod}.{symbol}" not in helper_functions_fqn
and f"{mod}.{symbol}" not in dotted_import_collector.imports
):
AddImportsVisitor.add_needed_import(dst_context, mod, symbol)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, symbol)
else:
if f"{mod}.{obj}" not in dotted_import_collector.imports:
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
except Exception as e:
logger.exception(f"Error adding imports to destination module code: {e}")
return dst_module_code

View file

@ -173,14 +173,14 @@ def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
def module_name_from_file_path(file_path: Path, project_root_path: Path, *, traverse_up: bool = False) -> str:
try:
relative_path = file_path.relative_to(project_root_path)
relative_path = file_path.resolve().relative_to(project_root_path.resolve())
return relative_path.with_suffix("").as_posix().replace("/", ".")
except ValueError:
if traverse_up:
parent = file_path.parent
while parent not in (project_root_path, parent.parent):
try:
relative_path = file_path.relative_to(parent)
relative_path = file_path.resolve().relative_to(parent.resolve())
return relative_path.with_suffix("").as_posix().replace("/", ".")
except ValueError:
parent = parent.parent
@ -245,8 +245,9 @@ def get_run_tmp_file(file_path: Path) -> Path:
def path_belongs_to_site_packages(file_path: Path) -> bool:
site_packages = [Path(p) for p in site.getsitepackages()]
return any(file_path.resolve().is_relative_to(site_package_path) for site_package_path in site_packages)
file_path_resolved = file_path.resolve()
site_packages = [Path(p).resolve() for p in site.getsitepackages()]
return any(file_path_resolved.is_relative_to(site_package_path) for site_package_path in site_packages)
def is_class_defined_in_file(class_name: str, file_path: Path) -> bool:
@ -268,14 +269,6 @@ def validate_python_code(code: str) -> str:
return code
def has_any_async_functions(code: str) -> bool:
try:
module = ast.parse(code)
except SyntaxError:
return False
return any(isinstance(node, ast.AsyncFunctionDef) for node in ast.walk(module))
def cleanup_paths(paths: list[Path]) -> None:
for path in paths:
if path and path.exists():

View file

@ -0,0 +1,167 @@
from __future__ import annotations
import asyncio
import gc
import os
import sqlite3
from enum import Enum
from functools import wraps
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Callable, TypeVar
import dill as pickle
class VerificationType(str, Enum): # moved from codeflash/verification/codeflash_capture.py
FUNCTION_CALL = (
"function_call" # Correctness verification for a test function, checks input values and output values)
)
INIT_STATE_FTO = "init_state_fto" # Correctness verification for fto class instance attributes after init
INIT_STATE_HELPER = "init_state_helper" # Correctness verification for helper class instance attributes after init
F = TypeVar("F", bound=Callable[..., Any])
def get_run_tmp_file(file_path: Path) -> Path: # moved from codeflash/code_utils/code_utils.py
if not hasattr(get_run_tmp_file, "tmpdir"):
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
return Path(get_run_tmp_file.tmpdir.name) / file_path
def extract_test_context_from_env() -> tuple[str, str | None, str]:
test_module = os.environ["CODEFLASH_TEST_MODULE"]
test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
if test_module and test_function:
return (test_module, test_class if test_class else None, test_function)
raise RuntimeError(
"Test context environment variables not set - ensure tests are run through codeflash test runner"
)
def codeflash_behavior_async(func: F) -> F:
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
loop = asyncio.get_running_loop()
function_name = func.__name__
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
test_module_name, test_class_name, test_name = extract_test_context_from_env()
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
if not hasattr(async_wrapper, "index"):
async_wrapper.index = {}
if test_id in async_wrapper.index:
async_wrapper.index[test_id] += 1
else:
async_wrapper.index[test_id] = 0
codeflash_test_index = async_wrapper.index[test_id]
invocation_id = f"{line_id}_{codeflash_test_index}"
test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
print(f"!$######{test_stdout_tag}######$!")
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
codeflash_con = sqlite3.connect(db_path)
codeflash_cur = codeflash_con.cursor()
codeflash_cur.execute(
"CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, "
"test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, "
"runtime INTEGER, return_value BLOB, verification_type TEXT)"
)
exception = None
counter = loop.time()
gc.disable()
try:
ret = func(*args, **kwargs) # coroutine creation has some overhead, though it is very small
counter = loop.time()
return_value = await ret # let's measure the actual execution time of the code
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
except Exception as e:
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
exception = e
finally:
gc.enable()
print(f"!######{test_stdout_tag}######!")
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps((args, kwargs, return_value))
codeflash_cur.execute(
"INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
test_module_name,
test_class_name,
test_name,
function_name,
loop_index,
invocation_id,
codeflash_duration,
pickled_return_value,
VerificationType.FUNCTION_CALL.value,
),
)
codeflash_con.commit()
codeflash_con.close()
if exception:
raise exception
return return_value
return async_wrapper
def codeflash_performance_async(func: F) -> F:
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
loop = asyncio.get_running_loop()
function_name = func.__name__
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
test_module_name, test_class_name, test_name = extract_test_context_from_env()
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
if not hasattr(async_wrapper, "index"):
async_wrapper.index = {}
if test_id in async_wrapper.index:
async_wrapper.index[test_id] += 1
else:
async_wrapper.index[test_id] = 0
codeflash_test_index = async_wrapper.index[test_id]
invocation_id = f"{line_id}_{codeflash_test_index}"
test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
print(f"!$######{test_stdout_tag}######$!")
exception = None
counter = loop.time()
gc.disable()
try:
ret = func(*args, **kwargs)
counter = loop.time()
return_value = await ret
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
except Exception as e:
codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
exception = e
finally:
gc.enable()
print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
if exception:
raise exception
return return_value
return async_wrapper

View file

@ -3,6 +3,7 @@ INDIVIDUAL_TESTCASE_TIMEOUT = 15
MAX_FUNCTION_TEST_SECONDS = 60
N_CANDIDATES = 5
MIN_IMPROVEMENT_THRESHOLD = 0.05
MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10 # 10% minimum improvement for async throughput
MAX_TEST_FUNCTION_RUNS = 50
MAX_CUMULATIVE_TEST_RUNTIME_NANOSECONDS = 100e6 # 100ms
N_TESTS_TO_GENERATE = 2

View file

@ -15,7 +15,9 @@ def extract_dependent_function(main_function: str, code_context: CodeOptimizatio
dependent_functions = set()
for code_string in code_context.testgen_context.code_strings:
ast_tree = ast.parse(code_string.code)
dependent_functions.update({node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)})
dependent_functions.update(
{node.name for node in ast_tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))}
)
if main_function in dependent_functions:
dependent_functions.discard(main_function)
@ -43,17 +45,25 @@ def build_fully_qualified_name(function_name: str, code_context: CodeOptimizatio
def generate_candidates(source_code_path: Path) -> set[str]:
"""Generate all the possible candidates for coverage data based on the source code path."""
candidates = set()
candidates.add(source_code_path.name)
current_path = source_code_path.parent
# Add the filename as a candidate
name = source_code_path.name
candidates.add(name)
last_added = source_code_path.name
while current_path != current_path.parent:
candidate_path = str(Path(current_path.name) / last_added)
# Precompute parts for efficient candidate path construction
parts = source_code_path.parts
n = len(parts)
# Walk up the directory structure without creating Path objects or repeatedly converting to posix
last_added = name
# Start from the last parent and move up to the root, exclusive (skip the root itself)
for i in range(n - 2, 0, -1):
# Combine the ith part with the accumulated path (last_added)
candidate_path = f"{parts[i]}/{last_added}"
candidates.add(candidate_path)
last_added = candidate_path
current_path = current_path.parent
candidates.add(str(source_code_path))
# Add the absolute posix path as a candidate
candidates.add(source_code_path.as_posix())
return candidates

View file

@ -32,9 +32,11 @@ class CommentMapper(ast.NodeVisitor):
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
self.context_stack.append(node.name)
for inner_node in ast.walk(node):
for inner_node in node.body:
if isinstance(inner_node, ast.FunctionDef):
self.visit_FunctionDef(inner_node)
elif isinstance(inner_node, ast.AsyncFunctionDef):
self.visit_AsyncFunctionDef(inner_node)
self.context_stack.pop()
return node
@ -50,6 +52,14 @@ class CommentMapper(ast.NodeVisitor):
return f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
self._process_function_def_common(node)
return node
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
self._process_function_def_common(node)
return node
def _process_function_def_common(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
self.context_stack.append(node.name)
i = len(node.body) - 1
test_qualified_name = ".".join(self.context_stack)
@ -60,8 +70,9 @@ class CommentMapper(ast.NodeVisitor):
j = len(line_node.body) - 1
while j >= 0:
compound_line_node: ast.stmt = line_node.body[j]
internal_node: ast.AST
for internal_node in ast.walk(compound_line_node):
nodes_to_check = [compound_line_node]
nodes_to_check.extend(getattr(compound_line_node, "body", []))
for internal_node in nodes_to_check:
if isinstance(internal_node, (ast.stmt, ast.Assign)):
inv_id = str(i) + "_" + str(j)
match_key = key + "#" + inv_id
@ -75,7 +86,6 @@ class CommentMapper(ast.NodeVisitor):
self.results[line_node.lineno] = self.get_comment(match_key)
i -= 1
self.context_stack.pop()
return node
def get_fn_call_linenos(
@ -197,23 +207,41 @@ def add_runtime_comments_to_generated_tests(
def remove_functions_from_generated_tests(
generated_tests: GeneratedTestsList, test_functions_to_remove: list[str]
) -> GeneratedTestsList:
# Pre-compile patterns for all function names to remove
function_patterns = _compile_function_patterns(test_functions_to_remove)
new_generated_tests = []
for generated_test in generated_tests.generated_tests:
for test_function in test_functions_to_remove:
function_pattern = re.compile(
rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\ndef\s|$)",
re.DOTALL,
)
source = generated_test.generated_original_test_source
match = function_pattern.search(generated_test.generated_original_test_source)
if match is None or "@pytest.mark.parametrize" in match.group(0):
continue
generated_test.generated_original_test_source = function_pattern.sub(
"", generated_test.generated_original_test_source
)
# Apply all patterns without redundant searches
for pattern in function_patterns:
# Use finditer and sub only if necessary to avoid unnecessary .search()/.sub() calls
for match in pattern.finditer(source):
# Skip if "@pytest.mark.parametrize" present
# Only the matched function's code is targeted
if "@pytest.mark.parametrize" in match.group(0):
continue
# Remove function from source
# If match, remove the function by substitution in the source
# Replace using start/end indices for efficiency
start, end = match.span()
source = source[:start] + source[end:]
# After removal, break since .finditer() is from left to right, and only one match expected per function in source
break
generated_test.generated_original_test_source = source
new_generated_tests.append(generated_test)
return GeneratedTestsList(generated_tests=new_generated_tests)
# Pre-compile all function removal regexes upfront for efficiency.
def _compile_function_patterns(test_functions_to_remove: list[str]) -> list[re.Pattern[str]]:
return [
re.compile(
rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?(async\s+)?def\s+{re.escape(func)}\(.*?\):.*?(?=\n(async\s+)?def\s|$)",
re.DOTALL,
)
for func in test_functions_to_remove
]

View file

@ -18,10 +18,9 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
if formatter_cmds[0] == "disabled":
return return_code
tmp_code = """print("hello world")"""
with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", suffix=".py") as f:
f.write(tmp_code)
f.flush()
tmp_file = Path(f.name)
with tempfile.TemporaryDirectory() as tmpdir:
tmp_file = Path(tmpdir) / "test_codeflash_formatter.py"
tmp_file.write_text(tmp_code, encoding="utf-8")
try:
format_code(formatter_cmds, tmp_file, print_status=False, exit_on_failure=exit_on_failure)
except Exception:
@ -29,7 +28,7 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
"⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again.",
error_on_exit=True,
)
return return_code
return return_code
@lru_cache(maxsize=1)
@ -121,7 +120,7 @@ def get_cached_gh_event_data() -> dict[str, Any]:
event_path = os.getenv("GITHUB_EVENT_PATH")
if not event_path:
return {}
with Path(event_path).open() as f:
with Path(event_path).open(encoding="utf-8") as f:
return json.load(f) # type: ignore # noqa

View file

@ -1,6 +1,5 @@
from __future__ import annotations
import json
import subprocess
import tempfile
import time
@ -9,15 +8,12 @@ from pathlib import Path
from typing import TYPE_CHECKING, Optional
import git
from filelock import FileLock
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.compat import codeflash_cache_dir
from codeflash.code_utils.git_utils import check_running_in_git_repo, git_root_dir
if TYPE_CHECKING:
from typing import Any
from git import Repo
@ -100,56 +96,15 @@ def get_patches_dir_for_project() -> Path:
return Path(patches_dir / project_id)
def get_patches_metadata() -> dict[str, Any]:
project_patches_dir = get_patches_dir_for_project()
meta_file = project_patches_dir / "metadata.json"
if meta_file.exists():
with meta_file.open("r", encoding="utf-8") as f:
return json.load(f)
return {"id": get_git_project_id() or "", "patches": []}
def save_patches_metadata(patch_metadata: dict) -> dict:
project_patches_dir = get_patches_dir_for_project()
meta_file = project_patches_dir / "metadata.json"
lock_file = project_patches_dir / "metadata.json.lock"
# we are not supporting multiple concurrent optimizations within the same process, but keep that in case we decide to do so in the future.
with FileLock(lock_file, timeout=10):
metadata = get_patches_metadata()
patch_metadata["id"] = time.strftime("%Y%m%d-%H%M%S")
metadata["patches"].append(patch_metadata)
meta_file.write_text(json.dumps(metadata, indent=2))
return patch_metadata
def overwrite_patch_metadata(patches: list[dict]) -> bool:
project_patches_dir = get_patches_dir_for_project()
meta_file = project_patches_dir / "metadata.json"
lock_file = project_patches_dir / "metadata.json.lock"
with FileLock(lock_file, timeout=10):
metadata = get_patches_metadata()
metadata["patches"] = patches
meta_file.write_text(json.dumps(metadata, indent=2))
return True
def create_diff_patch_from_worktree(
worktree_dir: Path,
files: list[str],
fto_name: Optional[str] = None,
metadata_input: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
worktree_dir: Path, files: list[str], fto_name: Optional[str] = None
) -> Optional[Path]:
repository = git.Repo(worktree_dir, search_parent_directories=True)
uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)
if not uni_diff_text:
logger.warning("No changes found in worktree.")
return {}
return None
if not uni_diff_text.endswith("\n"):
uni_diff_text += "\n"
@ -157,14 +112,8 @@ def create_diff_patch_from_worktree(
project_patches_dir = get_patches_dir_for_project()
project_patches_dir.mkdir(parents=True, exist_ok=True)
final_function_name = fto_name or metadata_input.get("fto_name", "unknown")
patch_path = project_patches_dir / f"{worktree_dir.name}.{final_function_name}.patch"
patch_path = project_patches_dir / f"{worktree_dir.name}.{fto_name}.patch"
with patch_path.open("w", encoding="utf8") as f:
f.write(uni_diff_text)
final_metadata = {"patch_path": str(patch_path)}
if metadata_input:
final_metadata.update(metadata_input)
final_metadata = save_patches_metadata(final_metadata)
return final_metadata
return patch_path

View file

@ -1,10 +1,12 @@
from __future__ import annotations
import ast
import platform
from pathlib import Path
from typing import TYPE_CHECKING
import isort
import libcst as cst
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
@ -76,6 +78,10 @@ class InjectPerfOnly(ast.NodeTransformer):
call_node = node
if isinstance(node.func, ast.Name):
function_name = node.func.id
if self.function_object.is_async:
return [test_node]
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
node.args = [
ast.Name(id=function_name, ctx=ast.Load()),
@ -97,6 +103,9 @@ class InjectPerfOnly(ast.NodeTransformer):
if isinstance(node.func, ast.Attribute):
function_to_test = node.func.attr
if function_to_test == self.function_object.function_name:
if self.function_object.is_async:
return [test_node]
function_name = ast.unparse(node.func)
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
node.args = [
@ -135,7 +144,10 @@ class InjectPerfOnly(ast.NodeTransformer):
def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
if node.name.startswith("test_"):
did_update = False
if self.test_framework == "unittest":
if self.test_framework == "unittest" and platform.system() != "Windows":
# Only add timeout decorator on non-Windows platforms
# Windows doesn't support SIGALRM signal required by timeout_decorator
node.decorator_list.append(
ast.Call(
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
@ -212,7 +224,9 @@ class InjectPerfOnly(ast.NodeTransformer):
args=[
ast.JoinedStr(
values=[
ast.Constant(value=f"{get_run_tmp_file(Path('test_return_values_'))}"),
ast.Constant(
value=f"{get_run_tmp_file(Path('test_return_values_')).as_posix()}"
),
ast.FormattedValue(
value=ast.Name(id="codeflash_iteration", ctx=ast.Load()),
conversion=-1,
@ -283,6 +297,168 @@ class InjectPerfOnly(ast.NodeTransformer):
return node
class AsyncCallInstrumenter(ast.NodeTransformer):
def __init__(
self,
function: FunctionToOptimize,
module_path: str,
test_framework: str,
call_positions: list[CodePosition],
mode: TestingMode = TestingMode.BEHAVIOR,
) -> None:
self.mode = mode
self.function_object = function
self.class_name = None
self.only_function_name = function.function_name
self.module_path = module_path
self.test_framework = test_framework
self.call_positions = call_positions
self.did_instrument = False
# Track function call count per test function
self.async_call_counter: dict[str, int] = {}
if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
self.class_name = function.top_level_parent_name
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
# Add timeout decorator for unittest test classes if needed
if self.test_framework == "unittest":
timeout_decorator = ast.Call(
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
args=[ast.Constant(value=15)],
keywords=[],
)
for item in node.body:
if (
isinstance(item, ast.FunctionDef)
and item.name.startswith("test_")
and not any(
isinstance(d, ast.Call)
and isinstance(d.func, ast.Name)
and d.func.id == "timeout_decorator.timeout"
for d in item.decorator_list
)
):
item.decorator_list.append(timeout_decorator)
return self.generic_visit(node)
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
if not node.name.startswith("test_"):
return node
return self._process_test_function(node)
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
# Only process test functions
if not node.name.startswith("test_"):
return node
return self._process_test_function(node)
def _process_test_function(
self, node: ast.AsyncFunctionDef | ast.FunctionDef
) -> ast.AsyncFunctionDef | ast.FunctionDef:
# Optimize the search for decorator presence
if self.test_framework == "unittest":
found_timeout = False
for d in node.decorator_list:
# Avoid isinstance(d.func, ast.Name) if d is not ast.Call
if isinstance(d, ast.Call):
f = d.func
# Avoid attribute lookup if f is not ast.Name
if isinstance(f, ast.Name) and f.id == "timeout_decorator.timeout":
found_timeout = True
break
if not found_timeout:
timeout_decorator = ast.Call(
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
args=[ast.Constant(value=15)],
keywords=[],
)
node.decorator_list.append(timeout_decorator)
# Initialize counter for this test function
if node.name not in self.async_call_counter:
self.async_call_counter[node.name] = 0
new_body = []
# Optimize ast.walk calls inside _instrument_statement, by scanning only relevant nodes
for _i, stmt in enumerate(node.body):
transformed_stmt, added_env_assignment = self._optimized_instrument_statement(stmt)
if added_env_assignment:
current_call_index = self.async_call_counter[node.name]
self.async_call_counter[node.name] += 1
env_assignment = ast.Assign(
targets=[
ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
),
slice=ast.Constant(value="CODEFLASH_CURRENT_LINE_ID"),
ctx=ast.Store(),
)
],
value=ast.Constant(value=f"{current_call_index}"),
lineno=stmt.lineno if hasattr(stmt, "lineno") else 1,
)
new_body.append(env_assignment)
self.did_instrument = True
new_body.append(transformed_stmt)
node.body = new_body
return node
def _instrument_statement(self, stmt: ast.stmt, _node_name: str) -> tuple[ast.stmt, bool]:
for node in ast.walk(stmt):
if (
isinstance(node, ast.Await)
and isinstance(node.value, ast.Call)
and self._is_target_call(node.value)
and self._call_in_positions(node.value)
):
# Check if this call is in one of our target positions
return stmt, True # Return original statement but signal we added env var
return stmt, False
def _is_target_call(self, call_node: ast.Call) -> bool:
"""Check if this call node is calling our target async function."""
if isinstance(call_node.func, ast.Name):
return call_node.func.id == self.function_object.function_name
if isinstance(call_node.func, ast.Attribute):
return call_node.func.attr == self.function_object.function_name
return False
def _call_in_positions(self, call_node: ast.Call) -> bool:
if not hasattr(call_node, "lineno") or not hasattr(call_node, "col_offset"):
return False
return node_in_call_position(call_node, self.call_positions)
# Optimized version: only walk child nodes for Await
def _optimized_instrument_statement(self, stmt: ast.stmt) -> tuple[ast.stmt, bool]:
# Stack-based DFS, manual for relevant Await nodes
stack = [stmt]
while stack:
node = stack.pop()
# Favor direct ast.Await detection
if isinstance(node, ast.Await):
val = node.value
if isinstance(val, ast.Call) and self._is_target_call(val) and self._call_in_positions(val):
return stmt, True
# Use _fields instead of ast.walk for less allocations
for fname in getattr(node, "_fields", ()):
child = getattr(node, fname, None)
if isinstance(child, list):
stack.extend(child)
elif isinstance(child, ast.AST):
stack.append(child)
return stmt, False
class FunctionImportedAsVisitor(ast.NodeVisitor):
"""Checks if a function has been imported as an alias. We only care about the alias then.
@ -310,6 +486,7 @@ class FunctionImportedAsVisitor(ast.NodeVisitor):
file_path=self.function.file_path,
starting_line=self.function.starting_line,
ending_line=self.function.ending_line,
is_async=self.function.is_async,
)
else:
self.imported_as = FunctionToOptimize(
@ -318,9 +495,69 @@ class FunctionImportedAsVisitor(ast.NodeVisitor):
file_path=self.function.file_path,
starting_line=self.function.starting_line,
ending_line=self.function.ending_line,
is_async=self.function.is_async,
)
def instrument_source_module_with_async_decorators(
source_path: Path, function_to_optimize: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
) -> tuple[bool, str | None]:
if not function_to_optimize.is_async:
return False, None
try:
with source_path.open(encoding="utf8") as f:
source_code = f.read()
modified_code, decorator_added = add_async_decorator_to_function(source_code, function_to_optimize, mode)
if decorator_added:
return True, modified_code
except Exception:
return False, None
else:
return False, None
def inject_async_profiling_into_existing_test(
test_path: Path,
call_positions: list[CodePosition],
function_to_optimize: FunctionToOptimize,
tests_project_root: Path,
test_framework: str,
mode: TestingMode = TestingMode.BEHAVIOR,
) -> tuple[bool, str | None]:
"""Inject profiling for async function calls by setting environment variables before each call."""
with test_path.open(encoding="utf8") as f:
test_code = f.read()
try:
tree = ast.parse(test_code)
except SyntaxError:
logger.exception(f"Syntax error in code in file - {test_path}")
return False, None
# TODO: Pass the full name of function here, otherwise we can run into namespace clashes
test_module_path = module_name_from_file_path(test_path, tests_project_root)
import_visitor = FunctionImportedAsVisitor(function_to_optimize)
import_visitor.visit(tree)
func = import_visitor.imported_as
async_instrumenter = AsyncCallInstrumenter(func, test_module_path, test_framework, call_positions, mode=mode)
tree = async_instrumenter.visit(tree)
if not async_instrumenter.did_instrument:
return False, None
# Add necessary imports
new_imports = [ast.Import(names=[ast.alias(name="os")])]
if test_framework == "unittest":
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
tree.body = [*new_imports, *tree.body]
return True, isort.code(ast.unparse(tree), float_to_top=True)
def inject_profiling_into_existing_test(
test_path: Path,
call_positions: list[CodePosition],
@ -329,6 +566,11 @@ def inject_profiling_into_existing_test(
test_framework: str,
mode: TestingMode = TestingMode.BEHAVIOR,
) -> tuple[bool, str | None]:
if function_to_optimize.is_async:
return inject_async_profiling_into_existing_test(
test_path, call_positions, function_to_optimize, tests_project_root, test_framework, mode
)
with test_path.open(encoding="utf8") as f:
test_code = f.read()
try:
@ -336,7 +578,7 @@ def inject_profiling_into_existing_test(
except SyntaxError:
logger.exception(f"Syntax error in code in file - {test_path}")
return False, None
# TODO: Pass the full name of function here, otherwise we can run into namespace clashes
test_module_path = module_name_from_file_path(test_path, tests_project_root)
import_visitor = FunctionImportedAsVisitor(function_to_optimize)
import_visitor.visit(tree)
@ -352,9 +594,11 @@ def inject_profiling_into_existing_test(
new_imports.extend(
[ast.Import(names=[ast.alias(name="sqlite3")]), ast.Import(names=[ast.alias(name="dill", asname="pickle")])]
)
if test_framework == "unittest":
if test_framework == "unittest" and platform.system() != "Windows":
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
tree.body = [*new_imports, create_wrapper_function(mode), *tree.body]
additional_functions = [create_wrapper_function(mode)]
tree.body = [*new_imports, *additional_functions, *tree.body]
return True, isort.code(ast.unparse(tree), float_to_top=True)
@ -735,3 +979,162 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
decorator_list=[],
returns=None,
)
class AsyncDecoratorAdder(cst.CSTTransformer):
"""Transformer that adds async decorator to async function definitions."""
def __init__(self, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR) -> None:
"""Initialize the transformer.
Args:
----
function: The FunctionToOptimize object representing the target async function.
mode: The testing mode to determine which decorator to apply.
"""
super().__init__()
self.function = function
self.mode = mode
self.qualified_name_parts = function.qualified_name.split(".")
self.context_stack = []
self.added_decorator = False
# Choose decorator based on mode
self.decorator_name = (
"codeflash_behavior_async" if mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
)
def visit_ClassDef(self, node: cst.ClassDef) -> None:
# Track when we enter a class
self.context_stack.append(node.name.value)
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
# Pop the context when we leave a class
self.context_stack.pop()
return updated_node
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
# Track when we enter a function
self.context_stack.append(node.name.value)
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
# Check if this is an async function and matches our target
if original_node.asynchronous is not None and self.context_stack == self.qualified_name_parts:
# Check if the decorator is already present
has_decorator = any(
self._is_target_decorator(decorator.decorator) for decorator in original_node.decorators
)
# Only add the decorator if it's not already there
if not has_decorator:
new_decorator = cst.Decorator(decorator=cst.Name(value=self.decorator_name))
# Add our new decorator to the existing decorators
updated_decorators = [new_decorator, *list(updated_node.decorators)]
updated_node = updated_node.with_changes(decorators=tuple(updated_decorators))
self.added_decorator = True
# Pop the context when we leave a function
self.context_stack.pop()
return updated_node
def _is_target_decorator(self, decorator_node: cst.Name | cst.Attribute | cst.Call) -> bool:
"""Check if a decorator matches our target decorator name."""
if isinstance(decorator_node, cst.Name):
return decorator_node.value in {
"codeflash_trace_async",
"codeflash_behavior_async",
"codeflash_performance_async",
}
if isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name):
return decorator_node.func.value in {
"codeflash_trace_async",
"codeflash_behavior_async",
"codeflash_performance_async",
}
return False
class AsyncDecoratorImportAdder(cst.CSTTransformer):
"""Transformer that adds the import for async decorators."""
def __init__(self, mode: TestingMode = TestingMode.BEHAVIOR) -> None:
self.mode = mode
self.has_import = False
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
# Check if the async decorator import is already present
if (
isinstance(node.module, cst.Attribute)
and isinstance(node.module.value, cst.Attribute)
and isinstance(node.module.value.value, cst.Name)
and node.module.value.value.value == "codeflash"
and node.module.value.attr.value == "code_utils"
and node.module.attr.value == "codeflash_wrap_decorator"
and not isinstance(node.names, cst.ImportStar)
):
decorator_name = (
"codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
)
for import_alias in node.names:
if import_alias.name.value == decorator_name:
self.has_import = True
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
# If the import is already there, don't add it again
if self.has_import:
return updated_node
# Choose import based on mode
decorator_name = (
"codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
)
# Parse the import statement into a CST node
import_node = cst.parse_statement(f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}")
# Add the import to the module's body
return updated_node.with_changes(body=[import_node, *list(updated_node.body)])
def add_async_decorator_to_function(
source_code: str, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
) -> tuple[str, bool]:
"""Add async decorator to an async function definition.
Args:
----
source_code: The source code to modify.
function: The FunctionToOptimize object representing the target async function.
mode: The testing mode to determine which decorator to apply.
Returns:
-------
Tuple of (modified_source_code, was_decorator_added).
"""
if not function.is_async:
return source_code, False
try:
module = cst.parse_module(source_code)
# Add the decorator to the function
decorator_transformer = AsyncDecoratorAdder(function, mode)
module = module.visit(decorator_transformer)
# Add the import if decorator was added
if decorator_transformer.added_decorator:
import_transformer = AsyncDecoratorImportAdder(mode)
module = module.visit(import_transformer)
return isort.code(module.code, float_to_top=True), decorator_transformer.added_decorator
except Exception as e:
logger.exception(f"Error adding async decorator to function {function.qualified_name}: {e}")
return source_code, False
def create_instrumented_source_module_path(source_path: Path, temp_dir: Path) -> Path:
instrumented_filename = f"instrumented_{source_path.name}"
return temp_dir / instrumented_filename

View file

@ -219,6 +219,6 @@ def add_decorator_imports(function_to_optimize: FunctionToOptimize, code_context
file.write(modified_code)
# Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
file_contents = function_to_optimize.file_path.read_text("utf-8")
modified_code = add_profile_enable(file_contents, str(line_profile_output_file))
modified_code = add_profile_enable(file_contents, line_profile_output_file.as_posix())
function_to_optimize.file_path.write_text(modified_code, "utf-8")
return line_profile_output_file

View file

@ -128,13 +128,19 @@ def get_first_top_level_object_def_ast(
def get_first_top_level_function_or_method_ast(
function_name: str, parents: list[FunctionParent], node: ast.AST
) -> ast.FunctionDef | None:
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
if not parents:
return get_first_top_level_object_def_ast(function_name, ast.FunctionDef, node)
result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, node)
if result is not None:
return result
return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, node)
if parents[0].type == "ClassDef" and (
class_node := get_first_top_level_object_def_ast(parents[0].name, ast.ClassDef, node)
):
return get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node)
result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node)
if result is not None:
return result
return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, class_node)
return None

View file

@ -334,7 +334,9 @@ def extract_code_markdown_context_from_files(
helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())
),
)
code_string_context = CodeString(code=code_context, file_path=file_path.relative_to(project_root_path))
code_string_context = CodeString(
code=code_context, file_path=file_path.resolve().relative_to(project_root_path.resolve())
)
code_context_markdown.code_strings.append(code_string_context)
# Extract code from file paths containing helpers of helpers
for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items():
@ -365,7 +367,9 @@ def extract_code_markdown_context_from_files(
project_root=project_root_path,
helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())),
)
code_string_context = CodeString(code=code_context, file_path=file_path.relative_to(project_root_path))
code_string_context = CodeString(
code=code_context, file_path=file_path.resolve().relative_to(project_root_path.resolve())
)
code_context_markdown.code_strings.append(code_string_context)
return code_context_markdown

View file

@ -279,7 +279,7 @@ class QualifiedFunctionUsageMarker:
# Find class methods and add their containing classes and dunder methods
for qualified_name in list(self.qualified_function_names):
if "." in qualified_name:
class_name, method_name = qualified_name.split(".", 1)
class_name, _method_name = qualified_name.split(".", 1)
# Add the class itself
expanded.add(class_name)
@ -511,7 +511,7 @@ def revert_unused_helper_functions(
if not unused_helpers:
return
logger.info(f"Reverting {len(unused_helpers)} unused helper function(s) to original definitions")
logger.debug(f"Reverting {len(unused_helpers)} unused helper function(s) to original definitions")
# Group unused helpers by file path
unused_helpers_by_file = defaultdict(list)
@ -612,6 +612,34 @@ def _analyze_imports_in_optimized_code(
return dict(imported_names_map)
def find_target_node(
root: ast.AST, function_to_optimize: FunctionToOptimize
) -> Optional[ast.FunctionDef | ast.AsyncFunctionDef]:
parents = function_to_optimize.parents
node = root
for parent in parents:
# Fast loop: directly look for the matching ClassDef in node.body
body = getattr(node, "body", None)
if not body:
return None
for child in body:
if isinstance(child, ast.ClassDef) and child.name == parent.name:
node = child
break
else:
return None
# Now node is either the root or the target parent class; look for function
body = getattr(node, "body", None)
if not body:
return None
target_name = function_to_optimize.function_name
for child in body:
if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) and child.name == target_name:
return child
return None
def detect_unused_helper_functions(
function_to_optimize: FunctionToOptimize,
code_context: CodeOptimizationContext,
@ -641,11 +669,7 @@ def detect_unused_helper_functions(
optimized_ast = ast.parse(optimized_code)
# Find the optimized entrypoint function
entrypoint_function_ast = None
for node in ast.walk(optimized_ast):
if isinstance(node, ast.FunctionDef) and node.name == function_to_optimize.function_name:
entrypoint_function_ast = node
break
entrypoint_function_ast = find_target_node(optimized_ast, function_to_optimize)
if not entrypoint_function_ast:
logger.debug(f"Could not find entrypoint function {function_to_optimize.function_name} in optimized code")

View file

@ -2,6 +2,7 @@
from __future__ import annotations
import ast
import enum
import hashlib
import os
import pickle
@ -11,12 +12,11 @@ import subprocess
import unittest
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional, final
if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
import pytest
from pydantic.dataclasses import dataclass
from rich.panel import Panel
from rich.text import Text
@ -35,6 +35,22 @@ if TYPE_CHECKING:
from codeflash.verification.verification_utils import TestConfig
@final
class PytestExitCode(enum.IntEnum): # don't need to import entire pytest just for this
#: Tests passed.
OK = 0
#: Tests failed.
TESTS_FAILED = 1
#: pytest was interrupted.
INTERRUPTED = 2
#: An internal error got in the way.
INTERNAL_ERROR = 3
#: pytest was misused.
USAGE_ERROR = 4
#: pytest couldn't find tests.
NO_TESTS_COLLECTED = 5
@dataclass(frozen=True)
class TestFunction:
function_name: str
@ -412,7 +428,7 @@ def discover_tests_pytest(
error_section = match.group(1) if match else result.stdout
logger.warning(
f"Failed to collect tests. Pytest Exit code: {exitcode}={pytest.ExitCode(exitcode).name}\n {error_section}"
f"Failed to collect tests. Pytest Exit code: {exitcode}={PytestExitCode(exitcode).name}\n {error_section}"
)
if "ModuleNotFoundError" in result.stdout:
match = ImportErrorPattern.search(result.stdout).group()
@ -420,7 +436,7 @@ def discover_tests_pytest(
console.print(panel)
elif 0 <= exitcode <= 5:
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={pytest.ExitCode(exitcode).name}")
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={PytestExitCode(exitcode).name}")
else:
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}")
console.rule()

View file

@ -86,6 +86,7 @@ class FunctionVisitor(cst.CSTVisitor):
parents=list(reversed(ast_parents)),
starting_line=pos.start.line,
ending_line=pos.end.line,
is_async=bool(node.asynchronous),
)
)
@ -103,6 +104,15 @@ class FunctionWithReturnStatement(ast.NodeVisitor):
FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
)
def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
# Check if the async function has a return statement and add it to the list
if function_has_return_statement(node) and not function_is_a_property(node):
self.functions.append(
FunctionToOptimize(
function_name=node.name, file_path=self.file_path, parents=self.ast_path[:], is_async=True
)
)
def generic_visit(self, node: ast.AST) -> None:
if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):
self.ast_path.append(FunctionParent(node.name, node.__class__.__name__))
@ -122,6 +132,7 @@ class FunctionToOptimize:
parents: A list of parent scopes, which could be classes or functions.
starting_line: The starting line number of the function in the file.
ending_line: The ending line number of the function in the file.
is_async: Whether this function is defined as async.
The qualified_name property provides the full name of the function, including
any parent class or function names. The qualified_name_with_modules_from_root
@ -134,6 +145,7 @@ class FunctionToOptimize:
parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef]
starting_line: Optional[int] = None
ending_line: Optional[int] = None
is_async: bool = False
@property
def top_level_parent_name(self) -> str:
@ -147,7 +159,11 @@ class FunctionToOptimize:
@property
def qualified_name(self) -> str:
return self.function_name if self.parents == [] else f"{self.parents[0].name}.{self.function_name}"
if not self.parents:
return self.function_name
# Join all parent names with dots to handle nested classes properly
parent_path = ".".join(parent.name for parent in self.parents)
return f"{parent_path}.{self.function_name}"
def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
@ -163,6 +179,8 @@ def get_functions_to_optimize(
project_root: Path,
module_root: Path,
previous_checkpoint_functions: dict[str, dict[str, str]] | None = None,
*,
enable_async: bool = False,
) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, (
"Only one of optimize_all, replay_test, or file should be provided"
@ -216,7 +234,13 @@ def get_functions_to_optimize(
ph("cli-optimizing-git-diff")
functions = get_functions_within_git_diff(uncommitted_changes=False)
filtered_modified_functions, functions_count = filter_functions(
functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions
functions,
test_cfg.tests_root,
ignore_paths,
project_root,
module_root,
previous_checkpoint_functions,
enable_async=enable_async,
)
logger.info(f"!lsp|Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
@ -411,11 +435,27 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
)
)
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
if self.class_name is None and node.name == self.function_name:
self.is_top_level = True
self.function_has_args = any(
(
bool(node.args.args),
bool(node.args.kwonlyargs),
bool(node.args.kwarg),
bool(node.args.posonlyargs),
bool(node.args.vararg),
)
)
def visit_ClassDef(self, node: ast.ClassDef) -> None:
# iterate over the class methods
if node.name == self.class_name:
for body_node in node.body:
if isinstance(body_node, ast.FunctionDef) and body_node.name == self.function_name:
if (
isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
and body_node.name == self.function_name
):
self.is_top_level = True
if any(
isinstance(decorator, ast.Name) and decorator.id == "classmethod"
@ -433,7 +473,7 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
# This way, if we don't have the class name, we can still find the static method
for body_node in node.body:
if (
isinstance(body_node, ast.FunctionDef)
isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
and body_node.name == self.function_name
and body_node.lineno in {self.line_no, self.line_no + 1}
and any(
@ -535,7 +575,9 @@ def filter_functions(
project_root: Path,
module_root: Path,
previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None,
disable_logs: bool = False, # noqa: FBT001, FBT002
*,
disable_logs: bool = False,
enable_async: bool = False,
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
filtered_modified_functions: dict[str, list[FunctionToOptimize]] = {}
blocklist_funcs = get_blocklisted_functions()
@ -555,6 +597,7 @@ def filter_functions(
submodule_ignored_paths_count: int = 0
blocklist_funcs_removed_count: int = 0
previous_checkpoint_functions_removed_count: int = 0
async_functions_removed_count: int = 0
tests_root_str = str(tests_root)
module_root_str = str(module_root)
@ -610,6 +653,15 @@ def filter_functions(
functions_tmp.append(function)
_functions = functions_tmp
if not enable_async:
functions_tmp = []
for function in _functions:
if function.is_async:
async_functions_removed_count += 1
continue
functions_tmp.append(function)
_functions = functions_tmp
filtered_modified_functions[file_path] = _functions
functions_count += len(_functions)
@ -623,6 +675,7 @@ def filter_functions(
"Files from ignored submodules": (submodule_ignored_paths_count, "bright_black"),
"Blocklisted functions removed": (blocklist_funcs_removed_count, "bright_red"),
"Functions skipped from checkpoint": (previous_checkpoint_functions_removed_count, "green"),
"Async functions removed": (async_functions_removed_count, "bright_magenta"),
}
tree = Tree(Text("Ignored functions and files", style="bold"))
for label, (count, color) in log_info.items():

View file

@ -12,11 +12,8 @@ from pygls import uris
from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
from codeflash.cli_cmds.cli import process_pyproject_config
from codeflash.cli_cmds.console import code_print
from codeflash.code_utils.git_worktree_utils import (
create_diff_patch_from_worktree,
get_patches_metadata,
overwrite_patch_metadata,
)
from codeflash.code_utils.git_utils import git_root_dir
from codeflash.code_utils.git_worktree_utils import create_diff_patch_from_worktree
from codeflash.code_utils.shell_utils import save_api_key_to_rc
from codeflash.discovery.functions_to_optimize import (
filter_functions,
@ -39,10 +36,17 @@ class OptimizableFunctionsParams:
textDocument: types.TextDocumentIdentifier # noqa: N815
@dataclass
class FunctionOptimizationInitParams:
textDocument: types.TextDocumentIdentifier # noqa: N815
functionName: str # noqa: N815
@dataclass
class FunctionOptimizationParams:
textDocument: types.TextDocumentIdentifier # noqa: N815
functionName: str # noqa: N815
task_id: str
@dataclass
@ -59,7 +63,7 @@ class ValidateProjectParams:
@dataclass
class OnPatchAppliedParams:
patch_id: str
task_id: str
@dataclass
@ -111,7 +115,6 @@ def get_optimizable_functions(
server: CodeflashLanguageServer, params: OptimizableFunctionsParams
) -> dict[str, list[str]]:
file_path = Path(uris.to_fs_path(params.textDocument.uri))
server.show_message_log(f"Getting optimizable functions for: {file_path}", "Info")
if not server.optimizer:
return {"status": "error", "message": "optimizer not initialized"}
@ -119,55 +122,15 @@ def get_optimizable_functions(
server.optimizer.args.function = None # Always get ALL functions, not just one
server.optimizer.args.previous_checkpoint_functions = False
server.show_message_log(f"Calling get_optimizable_functions for {server.optimizer.args.file}...", "Info")
optimizable_funcs, _, _ = server.optimizer.get_optimizable_functions()
path_to_qualified_names = {}
for functions in optimizable_funcs.values():
path_to_qualified_names[file_path] = [func.qualified_name for func in functions]
server.show_message_log(
f"Found {len(path_to_qualified_names)} files with functions: {path_to_qualified_names}", "Info"
)
return path_to_qualified_names
@server.feature("initializeFunctionOptimization")
def initialize_function_optimization(
server: CodeflashLanguageServer, params: FunctionOptimizationParams
) -> dict[str, str]:
file_path = Path(uris.to_fs_path(params.textDocument.uri))
server.show_message_log(f"Initializing optimization for function: {params.functionName} in {file_path}", "Info")
if server.optimizer is None:
_initialize_optimizer_if_api_key_is_valid(server)
server.optimizer.worktree_mode()
original_args, _ = server.optimizer.original_args_and_test_cfg
server.optimizer.args.function = params.functionName
original_relative_file_path = file_path.relative_to(original_args.project_root)
server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
server.optimizer.args.previous_checkpoint_functions = False
server.show_message_log(
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
)
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
if count == 0:
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
server.cleanup_the_optimizer()
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
fto = optimizable_funcs.popitem()[1][0]
server.optimizer.current_function_being_optimized = fto
server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
return {"functionName": params.functionName, "status": "success"}
def _find_pyproject_toml(workspace_path: str) -> Path | None:
workspace_path_obj = Path(workspace_path)
max_depth = 2
@ -207,13 +170,18 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
if pyproject_toml_path:
server.prepare_optimizer_arguments(pyproject_toml_path)
else:
return {
"status": "error",
"message": "No pyproject.toml found in workspace.",
} # TODO: enhancec this message to say there is not tool.codeflash in pyproject.toml or smth
return {"status": "error", "message": "No pyproject.toml found in workspace."}
# since we are using worktrees, optimization diffs are generated with respect to the root of the repo.
root = str(git_root_dir())
if getattr(params, "skip_validation", False):
return {"status": "success", "moduleRoot": server.args.module_root, "pyprojectPath": pyproject_toml_path}
return {
"status": "success",
"moduleRoot": server.args.module_root,
"pyprojectPath": pyproject_toml_path,
"root": root,
}
server.show_message_log("Validating project...", "Info")
config = is_valid_pyproject_toml(pyproject_toml_path)
@ -234,7 +202,7 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
except Exception:
return {"status": "error", "message": "Repository has no commits (unborn HEAD)"}
return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path}
return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path, "root": root}
def _initialize_optimizer_if_api_key_is_valid(
@ -296,78 +264,85 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
return {"status": "error", "message": "something went wrong while saving the api key"}
@server.feature("retrieveSuccessfulOptimizations")
def retrieve_successful_optimizations(_server: CodeflashLanguageServer, _params: any) -> dict[str, str]:
metadata = get_patches_metadata()
return {"status": "success", "patches": metadata["patches"]}
@server.feature("initializeFunctionOptimization")
def initialize_function_optimization(
server: CodeflashLanguageServer, params: FunctionOptimizationInitParams
) -> dict[str, str]:
file_path = Path(uris.to_fs_path(params.textDocument.uri))
server.show_message_log(f"Initializing optimization for function: {params.functionName} in {file_path}", "Info")
if server.optimizer is None:
_initialize_optimizer_if_api_key_is_valid(server)
@server.feature("onPatchApplied")
def on_patch_applied(_server: CodeflashLanguageServer, params: OnPatchAppliedParams) -> dict[str, str]:
# first remove the patch from the metadata
metadata = get_patches_metadata()
server.optimizer.worktree_mode()
deleted_patch_file = None
new_patches = []
for patch in metadata["patches"]:
if patch["id"] == params.patch_id:
deleted_patch_file = patch["patch_path"]
continue
new_patches.append(patch)
original_args, _ = server.optimizer.original_args_and_test_cfg
# then remove the patch file
if deleted_patch_file:
overwrite_patch_metadata(new_patches)
patch_path = Path(deleted_patch_file)
patch_path.unlink(missing_ok=True)
return {"status": "success"}
return {"status": "error", "message": "Patch not found"}
server.optimizer.args.function = params.functionName
original_relative_file_path = file_path.relative_to(original_args.project_root)
server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
server.optimizer.args.previous_checkpoint_functions = False
server.show_message_log(
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
)
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
if count == 0:
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
server.cleanup_the_optimizer()
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
fto = optimizable_funcs.popitem()[1][0]
module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
if not module_prep_result:
return {
"functionName": params.functionName,
"status": "error",
"message": "Failed to prepare module for optimization",
}
validated_original_code, original_module_ast = module_prep_result
function_optimizer = server.optimizer.create_function_optimizer(
fto,
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
original_module_ast=original_module_ast,
original_module_path=fto.file_path,
function_to_tests={},
)
server.optimizer.current_function_optimizer = function_optimizer
if not function_optimizer:
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
initialization_result = function_optimizer.can_be_optimized()
if not is_successful(initialization_result):
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
server.current_optimization_init_result = initialization_result.unwrap()
server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
files = [function_optimizer.function_to_optimize.file_path]
_, _, original_helpers = server.current_optimization_init_result
files.extend([str(helper_path) for helper_path in original_helpers])
return {"functionName": params.functionName, "status": "success", "files_inside_context": files}
@server.feature("performFunctionOptimization")
@server.thread()
def perform_function_optimization( # noqa: PLR0911
def perform_function_optimization(
server: CodeflashLanguageServer, params: FunctionOptimizationParams
) -> dict[str, str]:
try:
server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
current_function = server.optimizer.current_function_being_optimized
if not current_function:
server.show_message_log(f"No current function being optimized for {params.functionName}", "Error")
return {
"functionName": params.functionName,
"status": "error",
"message": "No function currently being optimized",
}
module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)
if not module_prep_result:
return {
"functionName": params.functionName,
"status": "error",
"message": "Failed to prepare module for optimization",
}
validated_original_code, original_module_ast = module_prep_result
function_optimizer = server.optimizer.create_function_optimizer(
current_function,
function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code,
original_module_ast=original_module_ast,
original_module_path=current_function.file_path,
function_to_tests={},
)
server.optimizer.current_function_optimizer = function_optimizer
if not function_optimizer:
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
initialization_result = function_optimizer.can_be_optimized()
if not is_successful(initialization_result):
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
should_run_experiment, code_context, original_helper_code = server.current_optimization_init_result
function_optimizer = server.optimizer.current_function_optimizer
current_function = function_optimizer.function_to_optimize
code_print(
code_context.read_writable_code.flat,
@ -447,22 +422,17 @@ def perform_function_optimization( # noqa: PLR0911
speedup = original_code_baseline.runtime / best_optimization.runtime
# get the original file path in the actual project (not in the worktree)
original_args, _ = server.optimizer.original_args_and_test_cfg
relative_file_path = current_function.file_path.relative_to(server.optimizer.current_worktree)
original_file_path = Path(original_args.project_root / relative_file_path).resolve()
metadata = create_diff_patch_from_worktree(
server.optimizer.current_worktree,
relative_file_paths,
metadata_input={
"fto_name": function_to_optimize_qualified_name,
"explanation": best_optimization.explanation_v2,
"file_path": str(original_file_path),
"speedup": speedup,
},
patch_path = create_diff_patch_from_worktree(
server.optimizer.current_worktree, relative_file_paths, function_to_optimize_qualified_name
)
if not patch_path:
return {
"functionName": params.functionName,
"status": "error",
"message": "Failed to create a patch for optimization",
}
server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")
return {
@ -470,8 +440,8 @@ def perform_function_optimization( # noqa: PLR0911
"status": "success",
"message": "Optimization completed successfully",
"extra": f"Speedup: {speedup:.2f}x faster",
"patch_file": metadata["patch_path"],
"patch_id": metadata["id"],
"patch_file": str(patch_path),
"task_id": params.task_id,
"explanation": best_optimization.explanation_v2,
}
finally:

View file

@ -29,15 +29,15 @@ def tree_to_markdown(tree: Tree, level: int = 0) -> str:
def report_to_markdown_table(report: dict[TestType, dict[str, int]], title: str) -> str:
lines = ["| Test Type | Passed ✅ | Failed ❌ |", "|-----------|--------|--------|"]
lines = ["| Test Type | Passed ✅ |", "|-----------|--------|"]
for test_type in TestType:
if test_type is TestType.INIT_STATE_TEST:
continue
passed = report[test_type]["passed"]
failed = report[test_type]["failed"]
if passed == 0 and failed == 0:
# failed = report[test_type]["failed"]
if passed == 0:
continue
lines.append(f"| {test_type.to_name()} | {passed} | {failed} |")
lines.append(f"| {test_type.to_name()} | {passed} |")
table = "\n".join(lines)
if title:
return f"### {title}\n{table}"

View file

@ -3,10 +3,10 @@ from __future__ import annotations
import logging
import sys
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Any, Callable
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.lsp.lsp_message import LspTextMessage
from codeflash.lsp.lsp_message import LspTextMessage, message_delimiter
root_logger = None
@ -43,16 +43,19 @@ def add_heading_tags(msg: str, tags: LspMessageTags) -> str:
return msg
def extract_tags(msg: str) -> tuple[Optional[LspMessageTags], str]:
def extract_tags(msg: str) -> tuple[LspMessageTags, str]:
delimiter = "|"
parts = msg.split(delimiter)
if len(parts) == 2:
first_delim_idx = msg.find(delimiter)
if first_delim_idx != -1 and msg.count(delimiter) == 1:
tags_str = msg[:first_delim_idx]
content = msg[first_delim_idx + 1 :]
tags = {tag.strip() for tag in tags_str.split(",")}
message_tags = LspMessageTags()
tags = {tag.strip() for tag in parts[0].split(",")}
if "!lsp" in tags:
message_tags.not_lsp = True
# manually check and set to avoid repeated membership tests
if "lsp" in tags:
message_tags.lsp = True
if "!lsp" in tags:
message_tags.not_lsp = True
if "force_lsp" in tags:
message_tags.force_lsp = True
if "loading" in tags:
@ -67,9 +70,9 @@ def extract_tags(msg: str) -> tuple[Optional[LspMessageTags], str]:
message_tags.h3 = True
if "h4" in tags:
message_tags.h4 = True
return message_tags, delimiter.join(parts[1:])
return message_tags, content
return None, msg
return LspMessageTags(), msg
supported_lsp_log_levels = ("info", "debug")
@ -86,31 +89,32 @@ def enhanced_log(
actual_log_fn(msg, *args, **kwargs)
return
is_lsp_json_message = msg.startswith('{"type"')
is_lsp_json_message = msg.startswith(message_delimiter) and msg.endswith(message_delimiter)
is_normal_text_message = not is_lsp_json_message
# extract tags only from the text messages (not the json ones)
tags, clean_msg = extract_tags(msg) if is_normal_text_message else (None, msg)
# Extract tags only from text messages
tags, clean_msg = extract_tags(msg) if is_normal_text_message else (LspMessageTags(), msg)
lsp_enabled = is_LSP_enabled()
lsp_only = tags and tags.lsp
if not lsp_enabled and not lsp_only:
# normal logging
actual_log_fn(clean_msg, *args, **kwargs)
return
#### LSP mode ####
final_tags = tags if tags else LspMessageTags()
unsupported_level = level not in supported_lsp_log_levels
if not final_tags.force_lsp and (final_tags.not_lsp or unsupported_level):
return
# ---- Normal logging path ----
if not tags.lsp:
if not lsp_enabled: # LSP disabled
actual_log_fn(clean_msg, *args, **kwargs)
return
if tags.not_lsp: # explicitly marked as not for LSP
actual_log_fn(clean_msg, *args, **kwargs)
return
if unsupported_level and not tags.force_lsp: # unsupported level
actual_log_fn(clean_msg, *args, **kwargs)
return
# ---- LSP logging path ----
if is_normal_text_message:
clean_msg = add_heading_tags(clean_msg, final_tags)
clean_msg = add_highlight_tags(clean_msg, final_tags)
clean_msg = LspTextMessage(text=clean_msg, takes_time=final_tags.loading).serialize()
clean_msg = add_heading_tags(clean_msg, tags)
clean_msg = add_highlight_tags(clean_msg, tags)
clean_msg = LspTextMessage(text=clean_msg, takes_time=tags.loading).serialize()
actual_log_fn(clean_msg, *args, **kwargs)

View file

@ -10,6 +10,9 @@ from codeflash.lsp.helpers import replace_quotes_with_backticks, simplify_worktr
json_primitive_types = (str, float, int, bool)
max_code_lines_before_collapse = 45
# \u241F is the message delimiter becuase it can be more than one message sent over the same message, so we need something to separate each message
message_delimiter = "\u241f"
@dataclass
class LspMessage:
@ -32,12 +35,8 @@ class LspMessage:
def serialize(self) -> str:
data = self._loop_through(asdict(self))
# Important: keep type as the first key, for making it easy and fast for the client to know if this is a lsp message before parsing it
ordered = {"type": self.type(), **data}
return (
json.dumps(ordered)
+ "\u241f" # \u241F is the message delimiter becuase it can be more than one message sent over the same message, so we need something to separate each message
)
return message_delimiter + json.dumps(ordered) + message_delimiter
@dataclass

View file

@ -9,6 +9,7 @@ from pygls.server import LanguageServer
if TYPE_CHECKING:
from pathlib import Path
from codeflash.models.models import CodeOptimizationContext
from codeflash.optimization.optimizer import Optimizer
@ -22,6 +23,7 @@ class CodeflashLanguageServer(LanguageServer):
self.optimizer: Optimizer | None = None
self.args_processed_before: bool = False
self.args = None
self.current_optimization_init_result: tuple[bool, CodeOptimizationContext, dict[Path, str]] | None = None
def prepare_optimizer_arguments(self, config_file: Path) -> None:
from codeflash.cli_cmds.cli import parse_args
@ -57,6 +59,7 @@ class CodeflashLanguageServer(LanguageServer):
self.lsp.notify("window/logMessage", log_params)
def cleanup_the_optimizer(self) -> None:
self.current_optimization_init_result = None
if not self.optimizer:
return
try:

View file

@ -103,6 +103,7 @@ class BestOptimization(BaseModel):
winning_benchmarking_test_results: TestResults
winning_replay_benchmarking_test_results: Optional[TestResults] = None
line_profiler_test_results: dict
async_throughput: Optional[int] = None
@dataclass(frozen=True)
@ -208,7 +209,7 @@ class CodeStringsMarkdown(BaseModel):
"""
return "\n".join(
[
f"```python{':' + str(code_string.file_path) if code_string.file_path else ''}\n{code_string.code.strip()}\n```"
f"```python{':' + code_string.file_path.as_posix() if code_string.file_path else ''}\n{code_string.code.strip()}\n```"
for code_string in self.code_strings
]
)
@ -277,6 +278,7 @@ class OptimizedCandidateResult(BaseModel):
replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None
optimization_candidate_index: int
total_candidate_timing: int
async_throughput: Optional[int] = None
class GeneratedTests(BaseModel):
@ -383,6 +385,7 @@ class OriginalCodeBaseline(BaseModel):
line_profile_results: dict
runtime: int
coverage_results: Optional[CoverageData]
async_throughput: Optional[int] = None
class CoverageStatus(Enum):
@ -545,6 +548,7 @@ class TestResults(BaseModel): # noqa: PLW1641
# also we don't support deletion of test results elements - caution is advised
test_results: list[FunctionTestInvocation] = []
test_result_idx: dict[str, int] = {}
perf_stdout: Optional[str] = None
def add(self, function_test_invocation: FunctionTestInvocation) -> None:
unique_id = function_test_invocation.unique_invocation_loop_id

View file

@ -36,7 +36,6 @@ from codeflash.code_utils.code_utils import (
diff_length,
file_name_from_test_module_name,
get_run_tmp_file,
has_any_async_functions,
module_name_from_file_path,
restore_conftest,
unified_diff_strings,
@ -85,14 +84,20 @@ from codeflash.models.models import (
TestType,
)
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
from codeflash.result.critic import (
coverage_critic,
performance_gain,
quantity_of_tests_critic,
speedup_critic,
throughput_gain,
)
from codeflash.result.explanation import Explanation
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.concolic_testing import generate_concolic_tests
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
from codeflash.verification.parse_test_output import parse_test_results
from codeflash.verification.parse_test_output import calculate_function_throughput_from_test_results, parse_test_results
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_line_profile_tests
from codeflash.verification.verification_utils import get_test_file_path
from codeflash.verification.verifier import generate_tests
@ -199,7 +204,7 @@ class FunctionOptimizer:
test_cfg: TestConfig,
function_to_optimize_source_code: str = "",
function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
function_to_optimize_ast: ast.FunctionDef | None = None,
function_to_optimize_ast: ast.FunctionDef | ast.AsyncFunctionDef | None = None,
aiservice_client: AiServiceClient | None = None,
function_benchmark_timings: dict[BenchmarkKey, int] | None = None,
total_benchmark_timings: dict[BenchmarkKey, int] | None = None,
@ -259,11 +264,6 @@ class FunctionOptimizer:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
async_code = any(
has_any_async_functions(code_string.code) for code_string in code_context.read_writable_code.code_strings
)
if async_code:
return Failure("Codeflash does not support async functions in the code to optimize.")
# Random here means that we still attempt optimization with a fractional chance to see if
# last time we could not find an optimization, maybe this time we do.
# Random is before as a performance optimization, swapping the two 'and' statements has the same effect
@ -304,7 +304,7 @@ class FunctionOptimizer:
]
with progress_bar(
f"Generating new tests and optimizations for function {self.function_to_optimize.function_name}",
f"Generating new tests and optimizations for function '{self.function_to_optimize.function_name}'",
transient=True,
revert_to_print=bool(get_pr_number()),
):
@ -587,7 +587,11 @@ class FunctionOptimizer:
tree = Tree(f"Candidate #{candidate_index} - Runtime Information ⌛")
benchmark_tree = None
if speedup_critic(
candidate_result, original_code_baseline.runtime, best_runtime_until_now=None
candidate_result,
original_code_baseline.runtime,
best_runtime_until_now=None,
original_async_throughput=original_code_baseline.async_throughput,
best_throughput_until_now=None,
) and quantity_of_tests_critic(candidate_result):
tree.add("This candidate is faster than the original code. 🚀") # TODO: Change this description
tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")
@ -598,6 +602,17 @@ class FunctionOptimizer:
)
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
if (
original_code_baseline.async_throughput is not None
and candidate_result.async_throughput is not None
):
throughput_gain_value = throughput_gain(
original_throughput=original_code_baseline.async_throughput,
optimized_throughput=candidate_result.async_throughput,
)
tree.add(f"Original async throughput: {original_code_baseline.async_throughput} executions")
tree.add(f"Optimized async throughput: {candidate_result.async_throughput} executions")
tree.add(f"Throughput improvement: {throughput_gain_value * 100:.1f}%")
line_profile_test_results = self.line_profiler_step(
code_context=code_context,
original_helper_code=original_helper_code,
@ -633,6 +648,7 @@ class FunctionOptimizer:
replay_performance_gain=replay_perf_gain if self.args.benchmark else None,
winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results,
async_throughput=candidate_result.async_throughput,
)
valid_optimizations.append(best_optimization)
# queue corresponding refined optimization for best optimization
@ -657,6 +673,15 @@ class FunctionOptimizer:
)
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
if (
original_code_baseline.async_throughput is not None
and candidate_result.async_throughput is not None
):
throughput_gain_value = throughput_gain(
original_throughput=original_code_baseline.async_throughput,
optimized_throughput=candidate_result.async_throughput,
)
tree.add(f"Throughput gain: {throughput_gain_value * 100:.1f}%")
if is_LSP_enabled():
lsp_log(LspMarkdownMessage(markdown=tree_to_markdown(tree)))
@ -700,6 +725,7 @@ class FunctionOptimizer:
replay_performance_gain=valid_opt.replay_performance_gain,
winning_benchmarking_test_results=valid_opt.winning_benchmarking_test_results,
winning_replay_benchmarking_test_results=valid_opt.winning_replay_benchmarking_test_results,
async_throughput=valid_opt.async_throughput,
)
valid_candidates_with_shorter_code.append(new_best_opt)
diff_lens_list.append(
@ -1079,6 +1105,7 @@ class FunctionOptimizer:
self.function_trace_id[:-4] + "EXP0" if run_experiment else self.function_trace_id,
n_candidates,
ExperimentMetadata(id=self.experiment_id, group="control") if run_experiment else None,
is_async=self.function_to_optimize.is_async,
)
future_candidates_exp = None
@ -1094,6 +1121,7 @@ class FunctionOptimizer:
self.function_trace_id[:-4] + "EXP1",
n_candidates,
ExperimentMetadata(id=self.experiment_id, group="experiment"),
is_async=self.function_to_optimize.is_async,
)
futures.append(future_candidates_exp)
@ -1280,6 +1308,8 @@ class FunctionOptimizer:
function_name=function_to_optimize_qualified_name,
file_path=self.function_to_optimize.file_path,
benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None,
original_async_throughput=original_code_baseline.async_throughput,
best_async_throughput=best_optimization.async_throughput,
)
self.replace_function_and_helpers_with_optimized_code(
@ -1362,6 +1392,23 @@ class FunctionOptimizer:
original_runtimes_all=original_runtime_by_test,
optimized_runtimes_all=optimized_runtime_by_test,
)
original_throughput_str = None
optimized_throughput_str = None
throughput_improvement_str = None
if (
self.function_to_optimize.is_async
and original_code_baseline.async_throughput is not None
and best_optimization.async_throughput is not None
):
original_throughput_str = f"{original_code_baseline.async_throughput} operations/second"
optimized_throughput_str = f"{best_optimization.async_throughput} operations/second"
throughput_improvement_value = throughput_gain(
original_throughput=original_code_baseline.async_throughput,
optimized_throughput=best_optimization.async_throughput,
)
throughput_improvement_str = f"{throughput_improvement_value * 100:.1f}%"
new_explanation_raw_str = self.aiservice_client.get_new_explanation(
source_code=code_context.read_writable_code.flat,
dependency_code=code_context.read_only_context_code,
@ -1375,6 +1422,9 @@ class FunctionOptimizer:
annotated_tests=generated_tests_str,
optimization_id=best_optimization.candidate.optimization_id,
original_explanation=best_optimization.candidate.explanation,
original_throughput=original_throughput_str,
optimized_throughput=optimized_throughput_str,
throughput_improvement=throughput_improvement_str,
)
new_explanation = Explanation(
raw_explanation_message=new_explanation_raw_str or explanation.raw_explanation_message,
@ -1385,6 +1435,8 @@ class FunctionOptimizer:
function_name=explanation.function_name,
file_path=explanation.file_path,
benchmark_details=explanation.benchmark_details,
original_async_throughput=explanation.original_async_throughput,
best_async_throughput=explanation.best_async_throughput,
)
self.log_successful_optimization(new_explanation, generated_tests, exp_type)
@ -1475,6 +1527,17 @@ class FunctionOptimizer:
test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)
if self.function_to_optimize.is_async:
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
success, instrumented_source = instrument_source_module_with_async_decorators(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
)
if success and instrumented_source:
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
f.write(instrumented_source)
logger.debug(f"Applied async instrumentation to {self.function_to_optimize.file_path}")
# Instrument codeflash capture
with progress_bar("Running tests to establish original code behavior..."):
try:
@ -1514,15 +1577,38 @@ class FunctionOptimizer:
)
console.rule()
with progress_bar("Running performance benchmarks..."):
benchmarking_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=total_looping_time,
enable_coverage=False,
code_context=code_context,
)
if self.function_to_optimize.is_async:
from codeflash.code_utils.instrument_existing_tests import (
instrument_source_module_with_async_decorators,
)
success, instrumented_source = instrument_source_module_with_async_decorators(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
)
if success and instrumented_source:
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
f.write(instrumented_source)
logger.debug(
f"Applied async performance instrumentation to {self.function_to_optimize.file_path}"
)
try:
benchmarking_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=total_looping_time,
enable_coverage=False,
code_context=code_context,
)
finally:
if self.function_to_optimize.is_async:
self.write_code_and_helpers(
self.function_to_optimize_source_code,
original_helper_code,
self.function_to_optimize.file_path,
)
else:
benchmarking_results = TestResults()
start_time: float = time.time()
@ -1570,12 +1656,20 @@ class FunctionOptimizer:
loop_count = max([int(result.loop_index) for result in benchmarking_results.test_results])
logger.info(
f"h2|⌚ Original code summed runtime measured over {loop_count} loop{'s' if loop_count > 1 else ''}: "
f"{humanize_runtime(total_timing)} per full loop"
f"h3|⌚ Original code summed runtime measured over '{loop_count}' loop{'s' if loop_count > 1 else ''}: "
f"'{humanize_runtime(total_timing)}' per full loop"
)
console.rule()
logger.debug(f"Total original code runtime (ns): {total_timing}")
async_throughput = None
if self.function_to_optimize.is_async:
async_throughput = calculate_function_throughput_from_test_results(
benchmarking_results, self.function_to_optimize.function_name
)
logger.debug(f"Original async function throughput: {async_throughput} calls/second")
console.rule()
if self.args.benchmark:
replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
@ -1589,6 +1683,7 @@ class FunctionOptimizer:
runtime=total_timing,
coverage_results=coverage_results,
line_profile_results=line_profile_results,
async_throughput=async_throughput,
),
functions_to_remove,
)
@ -1617,6 +1712,21 @@ class FunctionOptimizer:
candidate_helper_code = {}
for module_abspath in original_helper_code:
candidate_helper_code[module_abspath] = Path(module_abspath).read_text("utf-8")
if self.function_to_optimize.is_async:
from codeflash.code_utils.instrument_existing_tests import (
instrument_source_module_with_async_decorators,
)
success, instrumented_source = instrument_source_module_with_async_decorators(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
)
if success and instrumented_source:
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
f.write(instrumented_source)
logger.debug(
f"Applied async behavioral instrumentation to {self.function_to_optimize.file_path} for candidate {optimization_candidate_index}"
)
try:
instrument_codeflash_capture(
self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
@ -1654,14 +1764,37 @@ class FunctionOptimizer:
logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")
if test_framework == "pytest":
candidate_benchmarking_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=optimization_candidate_index,
testing_time=total_looping_time,
enable_coverage=False,
)
# For async functions, instrument at definition site for performance benchmarking
if self.function_to_optimize.is_async:
from codeflash.code_utils.instrument_existing_tests import (
instrument_source_module_with_async_decorators,
)
success, instrumented_source = instrument_source_module_with_async_decorators(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
)
if success and instrumented_source:
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
f.write(instrumented_source)
logger.debug(
f"Applied async performance instrumentation to {self.function_to_optimize.file_path} for candidate {optimization_candidate_index}"
)
try:
candidate_benchmarking_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=optimization_candidate_index,
testing_time=total_looping_time,
enable_coverage=False,
)
finally:
# Restore original source if we instrumented it
if self.function_to_optimize.is_async:
self.write_code_and_helpers(
candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
)
loop_count = (
max(all_loop_indices)
if (
@ -1697,6 +1830,14 @@ class FunctionOptimizer:
console.rule()
logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
candidate_async_throughput = None
if self.function_to_optimize.is_async:
candidate_async_throughput = calculate_function_throughput_from_test_results(
candidate_benchmarking_results, self.function_to_optimize.function_name
)
logger.debug(f"Candidate async function throughput: {candidate_async_throughput} calls/second")
if self.args.benchmark:
candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(
self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
@ -1716,6 +1857,7 @@ class FunctionOptimizer:
else None,
optimization_candidate_index=optimization_candidate_index,
total_candidate_timing=total_candidate_timing,
async_throughput=candidate_async_throughput,
)
)
@ -1807,8 +1949,10 @@ class FunctionOptimizer:
coverage_database_file=coverage_database_file,
coverage_config_file=coverage_config_file,
)
else:
results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
if testing_type == TestingMode.PERFORMANCE:
results.perf_stdout = run_result.stdout
return results, coverage_results
results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
return results, coverage_results
def submit_test_generation_tasks(
@ -1864,7 +2008,6 @@ class FunctionOptimizer:
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
) -> dict:
try:
logger.info("Running line profiling to identify performance bottlenecks…")
console.rule()
test_env = self.get_test_env(

View file

@ -15,7 +15,7 @@ from codeflash.cli_cmds.console import console, logger, progress_bar
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file
from codeflash.code_utils.env_utils import get_pr_number, is_pr_draft
from codeflash.code_utils.git_utils import check_running_in_git_repo
from codeflash.code_utils.git_utils import check_running_in_git_repo, git_root_dir
from codeflash.code_utils.git_worktree_utils import (
create_detached_worktree,
create_diff_patch_from_worktree,
@ -134,12 +134,13 @@ class Optimizer:
project_root=self.args.project_root,
module_root=self.args.module_root,
previous_checkpoint_functions=self.args.previous_checkpoint_functions,
enable_async=getattr(self.args, "async", False),
)
def create_function_optimizer(
self,
function_to_optimize: FunctionToOptimize,
function_to_optimize_ast: ast.FunctionDef | None = None,
function_to_optimize_ast: ast.FunctionDef | ast.AsyncFunctionDef | None = None,
function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
function_to_optimize_source_code: str | None = "",
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,
@ -349,13 +350,12 @@ class Optimizer:
relative_file_paths = [
code_string.file_path for code_string in read_writable_code.code_strings
]
metadata = create_diff_patch_from_worktree(
patch_path = create_diff_patch_from_worktree(
self.current_worktree,
relative_file_paths,
fto_name=function_to_optimize.qualified_name,
metadata_input={},
)
self.patch_files.append(metadata["patch_path"])
self.patch_files.append(patch_path)
if i < len(functions_to_optimize) - 1:
create_worktree_snapshot_commit(
self.current_worktree,
@ -442,39 +442,50 @@ class Optimizer:
logger.warning("Failed to create worktree. Skipping optimization.")
return
self.current_worktree = worktree_dir
self.mutate_args_for_worktree_mode(worktree_dir)
self.mirror_paths_for_worktree_mode(worktree_dir)
# make sure the tests dir is created in the worktree, this can happen if the original tests dir is empty
Path(self.args.tests_root).mkdir(parents=True, exist_ok=True)
def mutate_args_for_worktree_mode(self, worktree_dir: Path) -> None:
saved_args = copy.deepcopy(self.args)
saved_test_cfg = copy.deepcopy(self.test_cfg)
self.original_args_and_test_cfg = (saved_args, saved_test_cfg)
def mirror_paths_for_worktree_mode(self, worktree_dir: Path) -> None:
original_args = copy.deepcopy(self.args)
original_test_cfg = copy.deepcopy(self.test_cfg)
self.original_args_and_test_cfg = (original_args, original_test_cfg)
project_root = self.args.project_root
module_root = self.args.module_root
relative_module_root = module_root.relative_to(project_root)
relative_optimized_file = self.args.file.relative_to(project_root) if self.args.file else None
relative_tests_root = self.test_cfg.tests_root.relative_to(project_root)
relative_benchmarks_root = (
self.args.benchmarks_root.relative_to(project_root) if self.args.benchmarks_root else None
original_git_root = git_root_dir()
# mirror project_root
self.args.project_root = mirror_path(self.args.project_root, original_git_root, worktree_dir)
self.test_cfg.project_root_path = mirror_path(self.test_cfg.project_root_path, original_git_root, worktree_dir)
# mirror module_root
self.args.module_root = mirror_path(self.args.module_root, original_git_root, worktree_dir)
# mirror target file
if self.args.file:
self.args.file = mirror_path(self.args.file, original_git_root, worktree_dir)
# mirror tests root
self.args.tests_root = mirror_path(self.args.tests_root, original_git_root, worktree_dir)
self.test_cfg.tests_root = mirror_path(self.test_cfg.tests_root, original_git_root, worktree_dir)
# mirror tests project root
self.args.test_project_root = mirror_path(self.args.test_project_root, original_git_root, worktree_dir)
self.test_cfg.tests_project_rootdir = mirror_path(
self.test_cfg.tests_project_rootdir, original_git_root, worktree_dir
)
self.args.module_root = worktree_dir / relative_module_root
self.args.project_root = worktree_dir
self.args.test_project_root = worktree_dir
self.args.tests_root = worktree_dir / relative_tests_root
if relative_benchmarks_root:
self.args.benchmarks_root = worktree_dir / relative_benchmarks_root
# mirror benchmarks root paths
if self.args.benchmarks_root:
self.args.benchmarks_root = mirror_path(self.args.benchmarks_root, original_git_root, worktree_dir)
if self.test_cfg.benchmark_tests_root:
self.test_cfg.benchmark_tests_root = mirror_path(
self.test_cfg.benchmark_tests_root, original_git_root, worktree_dir
)
self.test_cfg.project_root_path = worktree_dir
self.test_cfg.tests_project_rootdir = worktree_dir
self.test_cfg.tests_root = worktree_dir / relative_tests_root
if relative_benchmarks_root:
self.test_cfg.benchmark_tests_root = worktree_dir / relative_benchmarks_root
if relative_optimized_file is not None:
self.args.file = worktree_dir / relative_optimized_file
def mirror_path(path: Path, src_root: Path, dest_root: Path) -> Path:
relative_path = path.relative_to(src_root)
return dest_root / relative_path
def run_with_args(args: Namespace) -> None:

View file

@ -85,7 +85,7 @@ def existing_tests_source_for(
):
print_optimized_runtime = format_time(optimized_tests_to_runtimes[filename][qualified_name])
print_original_runtime = format_time(original_tests_to_runtimes[filename][qualified_name])
print_filename = filename.relative_to(tests_root)
print_filename = filename.resolve().relative_to(tests_root.resolve()).as_posix()
greater = (
optimized_tests_to_runtimes[filename][qualified_name]
> original_tests_to_runtimes[filename][qualified_name]
@ -192,9 +192,9 @@ def check_create_pr(
if pr_number is not None:
logger.info(f"Suggesting changes to PR #{pr_number} ...")
owner, repo = get_repo_owner_and_name(git_repo)
relative_path = explanation.file_path.relative_to(root_dir).as_posix()
relative_path = explanation.file_path.resolve().relative_to(root_dir.resolve()).as_posix()
build_file_changes = {
Path(p).relative_to(root_dir).as_posix(): FileDiffContent(
Path(p).resolve().relative_to(root_dir.resolve()).as_posix(): FileDiffContent(
oldContent=original_code[p], newContent=new_code[p]
)
for p in original_code
@ -243,10 +243,10 @@ def check_create_pr(
if not check_and_push_branch(git_repo, git_remote, wait_for_push=True):
logger.warning("⏭️ Branch is not pushed, skipping PR creation...")
return
relative_path = explanation.file_path.relative_to(root_dir).as_posix()
relative_path = explanation.file_path.resolve().relative_to(root_dir.resolve()).as_posix()
base_branch = get_current_branch()
build_file_changes = {
Path(p).relative_to(root_dir).as_posix(): FileDiffContent(
Path(p).resolve().relative_to(root_dir.resolve()).as_posix(): FileDiffContent(
oldContent=original_code[p], newContent=new_code[p]
)
for p in original_code

View file

@ -8,8 +8,9 @@ from codeflash.code_utils.config_consts import (
COVERAGE_THRESHOLD,
MIN_IMPROVEMENT_THRESHOLD,
MIN_TESTCASE_PASSED_THRESHOLD,
MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD,
)
from codeflash.models.test_type import TestType
from codeflash.models import models
if TYPE_CHECKING:
from codeflash.models.models import CoverageData, OptimizedCandidateResult, OriginalCodeBaseline
@ -25,20 +26,41 @@ def performance_gain(*, original_runtime_ns: int, optimized_runtime_ns: int) ->
return (original_runtime_ns - optimized_runtime_ns) / optimized_runtime_ns
def throughput_gain(*, original_throughput: int, optimized_throughput: int) -> float:
"""Calculate the throughput gain of an optimized code over the original code.
This value multiplied by 100 gives the percentage improvement in throughput.
For throughput, higher values are better (more executions per time period).
"""
if original_throughput == 0:
return 0.0
return (optimized_throughput - original_throughput) / original_throughput
def speedup_critic(
candidate_result: OptimizedCandidateResult,
original_code_runtime: int,
best_runtime_until_now: int | None,
*,
disable_gh_action_noise: bool = False,
original_async_throughput: int | None = None,
best_throughput_until_now: int | None = None,
) -> bool:
"""Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.
Ensure that the optimization is actually faster than the original code, above the noise floor.
The noise floor is a function of the original code runtime. Currently, the noise floor is 2xMIN_IMPROVEMENT_THRESHOLD
when the original runtime is less than 10 microseconds, and becomes MIN_IMPROVEMENT_THRESHOLD for any higher runtime.
The noise floor is doubled when benchmarking on a (noisy) GitHub Action virtual instance, also we want to be more confident there.
Evaluates both runtime performance and async throughput improvements.
For runtime performance:
- Ensures the optimization is actually faster than the original code, above the noise floor.
- The noise floor is a function of the original code runtime. Currently, the noise floor is 2xMIN_IMPROVEMENT_THRESHOLD
when the original runtime is less than 10 microseconds, and becomes MIN_IMPROVEMENT_THRESHOLD for any higher runtime.
- The noise floor is doubled when benchmarking on a (noisy) GitHub Action virtual instance.
For async throughput (when available):
- Evaluates throughput improvements using MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD
- Throughput improvements complement runtime improvements for async functions
"""
# Runtime performance evaluation
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
if not disable_gh_action_noise and env_utils.is_ci():
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode
@ -46,10 +68,31 @@ def speedup_critic(
perf_gain = performance_gain(
original_runtime_ns=original_code_runtime, optimized_runtime_ns=candidate_result.best_test_runtime
)
if best_runtime_until_now is None:
# collect all optimizations with this
return bool(perf_gain > noise_floor)
return bool(perf_gain > noise_floor and candidate_result.best_test_runtime < best_runtime_until_now)
runtime_improved = perf_gain > noise_floor
# Check runtime comparison with best so far
runtime_is_best = best_runtime_until_now is None or candidate_result.best_test_runtime < best_runtime_until_now
throughput_improved = True # Default to True if no throughput data
throughput_is_best = True # Default to True if no throughput data
if original_async_throughput is not None and candidate_result.async_throughput is not None:
if original_async_throughput > 0:
throughput_gain_value = throughput_gain(
original_throughput=original_async_throughput, optimized_throughput=candidate_result.async_throughput
)
throughput_improved = throughput_gain_value > MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD
throughput_is_best = (
best_throughput_until_now is None or candidate_result.async_throughput > best_throughput_until_now
)
if original_async_throughput is not None and candidate_result.async_throughput is not None:
# When throughput data is available, accept if EITHER throughput OR runtime improves significantly
throughput_acceptance = throughput_improved and throughput_is_best
runtime_acceptance = runtime_improved and runtime_is_best
return throughput_acceptance or runtime_acceptance
return runtime_improved and runtime_is_best
def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | OriginalCodeBaseline) -> bool:
@ -63,7 +106,7 @@ def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | Origin
if pass_count >= MIN_TESTCASE_PASSED_THRESHOLD:
return True
# If one or more tests passed, check if least one of them was a successful REPLAY_TEST
return bool(pass_count >= 1 and report[TestType.REPLAY_TEST]["passed"] >= 1)
return bool(pass_count >= 1 and report[models.TestType.REPLAY_TEST]["passed"] >= 1) # type: ignore # noqa: PGH003
def coverage_critic(original_code_coverage: CoverageData | None, test_framework: str) -> bool:

View file

@ -12,6 +12,7 @@ from rich.table import Table
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.models.models import BenchmarkDetail, TestResults
from codeflash.result.critic import throughput_gain
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
@ -24,9 +25,28 @@ class Explanation:
function_name: str
file_path: Path
benchmark_details: Optional[list[BenchmarkDetail]] = None
original_async_throughput: Optional[int] = None
best_async_throughput: Optional[int] = None
@property
def perf_improvement_line(self) -> str:
runtime_improvement = self.speedup
if (
self.original_async_throughput is not None
and self.best_async_throughput is not None
and self.original_async_throughput > 0
):
throughput_improvement = throughput_gain(
original_throughput=self.original_async_throughput, optimized_throughput=self.best_async_throughput
)
# Use throughput metrics if throughput improvement is better or runtime got worse
if throughput_improvement > runtime_improvement or runtime_improvement <= 0:
throughput_pct = f"{throughput_improvement * 100:,.0f}%"
throughput_x = f"{throughput_improvement + 1:,.2f}x"
return f"{throughput_pct} improvement ({throughput_x} faster)."
return f"{self.speedup_pct} improvement ({self.speedup_x} faster)."
@property
@ -46,6 +66,23 @@ class Explanation:
# TODO: Sometimes the explanation says something similar to "This is the code that was optimized", remove such parts
original_runtime_human = humanize_runtime(self.original_runtime_ns)
best_runtime_human = humanize_runtime(self.best_runtime_ns)
# Determine if we're showing throughput or runtime improvements
runtime_improvement = self.speedup
is_using_throughput_metric = False
if (
self.original_async_throughput is not None
and self.best_async_throughput is not None
and self.original_async_throughput > 0
):
throughput_improvement = throughput_gain(
original_throughput=self.original_async_throughput, optimized_throughput=self.best_async_throughput
)
if throughput_improvement > runtime_improvement or runtime_improvement <= 0:
is_using_throughput_metric = True
benchmark_info = ""
if self.benchmark_details:
@ -86,13 +123,18 @@ class Explanation:
console.print(table)
benchmark_info = cast("StringIO", console.file).getvalue() + "\n" # Cast for mypy
test_report = self.winning_behavior_test_results.get_test_pass_fail_report_by_type()
test_report_str = TestResults.report_to_string(test_report)
if is_using_throughput_metric:
performance_description = (
f"Throughput improved from {self.original_async_throughput} to {self.best_async_throughput} operations/second "
f"(runtime: {original_runtime_human}{best_runtime_human})\n\n"
)
else:
performance_description = f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n"
return (
f"Optimized {self.function_name} in {self.file_path}\n"
f"{self.perf_improvement_line}\n"
f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n"
+ performance_description
+ (benchmark_info if benchmark_info else "")
+ self.raw_explanation_message
+ " \n\n"
@ -101,7 +143,7 @@ class Explanation:
""
if is_LSP_enabled()
else "The new optimized code was tested for correctness. The results are listed below.\n"
+ test_report_str
f"{TestResults.report_to_string(self.winning_behavior_test_results.get_test_pass_fail_report_by_type())}\n"
)
)

View file

@ -13,6 +13,7 @@ import sys
import threading
import time
from collections import defaultdict
from importlib.util import find_spec
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ClassVar
@ -47,6 +48,17 @@ class FakeFrame:
self.f_locals: dict = {}
def patch_ap_scheduler() -> None:
if find_spec("apscheduler"):
import apscheduler.schedulers.background as bg
import apscheduler.schedulers.blocking as bb
from apscheduler.schedulers import base
bg.BackgroundScheduler.start = lambda _, *_a, **_k: None
bb.BlockingScheduler.start = lambda _, *_a, **_k: None
base.BaseScheduler.add_job = lambda _, *_a, **_k: None
# Debug this file by simply adding print statements. This file is not meant to be debugged by the debugger.
class Tracer:
"""Use this class as a 'with' context manager to trace a function call.
@ -820,6 +832,7 @@ class Tracer:
if __name__ == "__main__":
args_dict = json.loads(sys.argv[-1])
sys.argv = sys.argv[1:-1]
patch_ap_scheduler()
if args_dict["module"]:
import runpy

View file

@ -1,4 +1,3 @@
# ruff: noqa: PGH003
import array
import ast
import datetime
@ -8,6 +7,7 @@ import math
import re
import types
from collections import ChainMap, OrderedDict, deque
from importlib.util import find_spec
from typing import Any
import sentry_sdk
@ -15,51 +15,14 @@ import sentry_sdk
from codeflash.cli_cmds.console import logger
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError
try:
import numpy as np
HAS_NUMPY = True
except ImportError:
HAS_NUMPY = False
try:
import sqlalchemy # type: ignore
HAS_SQLALCHEMY = True
except ImportError:
HAS_SQLALCHEMY = False
try:
import scipy # type: ignore
HAS_SCIPY = True
except ImportError:
HAS_SCIPY = False
try:
import pandas # type: ignore # noqa: ICN001
HAS_PANDAS = True
except ImportError:
HAS_PANDAS = False
try:
import pyrsistent # type: ignore
HAS_PYRSISTENT = True
except ImportError:
HAS_PYRSISTENT = False
try:
import torch # type: ignore
HAS_TORCH = True
except ImportError:
HAS_TORCH = False
try:
import jax # type: ignore
import jax.numpy as jnp # type: ignore
HAS_JAX = True
except ImportError:
HAS_JAX = False
HAS_NUMPY = find_spec("numpy") is not None
HAS_SQLALCHEMY = find_spec("sqlalchemy") is not None
HAS_SCIPY = find_spec("scipy") is not None
HAS_PANDAS = find_spec("pandas") is not None
HAS_PYRSISTENT = find_spec("pyrsistent") is not None
HAS_TORCH = find_spec("torch") is not None
HAS_JAX = find_spec("jax") is not None
HAS_XARRAY = find_spec("xarray") is not None
def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
@ -115,15 +78,28 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")}
return comparator(orig_dict, new_dict, superset_obj)
# Handle JAX arrays first to avoid boolean context errors in other conditions
if HAS_JAX and isinstance(orig, jax.Array):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
return False
return bool(jnp.allclose(orig, new, equal_nan=True))
if HAS_JAX:
import jax # type: ignore # noqa: PGH003
import jax.numpy as jnp # type: ignore # noqa: PGH003
# Handle JAX arrays first to avoid boolean context errors in other conditions
if isinstance(orig, jax.Array):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
return False
return bool(jnp.allclose(orig, new, equal_nan=True))
# Handle xarray objects before numpy to avoid boolean context errors
if HAS_XARRAY:
import xarray # type: ignore # noqa: PGH003
if isinstance(orig, (xarray.Dataset, xarray.DataArray)):
return orig.identical(new)
if HAS_SQLALCHEMY:
import sqlalchemy # type: ignore # noqa: PGH003
try:
insp = sqlalchemy.inspection.inspect(orig)
insp = sqlalchemy.inspection.inspect(new) # noqa: F841
@ -138,6 +114,9 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
except sqlalchemy.exc.NoInspectionAvailable:
pass
if HAS_SCIPY:
import scipy # type: ignore # noqa: PGH003
# scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it
if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)):
if superset_obj:
@ -151,27 +130,30 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
return False
return True
if HAS_NUMPY and isinstance(orig, np.ndarray):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
return False
try:
return np.allclose(orig, new, equal_nan=True)
except Exception:
# fails at "ufunc 'isfinite' not supported for the input types"
return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)])
if HAS_NUMPY:
import numpy as np # type: ignore # noqa: PGH003
if HAS_NUMPY and isinstance(orig, (np.floating, np.complex64, np.complex128)):
return np.isclose(orig, new)
if isinstance(orig, np.ndarray):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
return False
try:
return np.allclose(orig, new, equal_nan=True)
except Exception:
# fails at "ufunc 'isfinite' not supported for the input types"
return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)])
if HAS_NUMPY and isinstance(orig, (np.integer, np.bool_, np.byte)):
return orig == new
if isinstance(orig, (np.floating, np.complex64, np.complex128)):
return np.isclose(orig, new)
if HAS_NUMPY and isinstance(orig, np.void):
if orig.dtype != new.dtype:
return False
return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields)
if isinstance(orig, (np.integer, np.bool_, np.byte)):
return orig == new
if isinstance(orig, np.void):
if orig.dtype != new.dtype:
return False
return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields)
if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix):
if orig.dtype != new.dtype:
@ -180,15 +162,18 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
return False
return (orig != new).nnz == 0
if HAS_PANDAS and isinstance(
orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
):
return orig.equals(new)
if HAS_PANDAS:
import pandas # type: ignore # noqa: ICN001, PGH003
if HAS_PANDAS and isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)):
return orig == new
if HAS_PANDAS and pandas.isna(orig) and pandas.isna(new):
return True
if isinstance(
orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
):
return orig.equals(new)
if isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)):
return orig == new
if pandas.isna(orig) and pandas.isna(new):
return True
if isinstance(orig, array.array):
if orig.typecode != new.typecode:
@ -209,31 +194,58 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
except Exception: # noqa: S110
pass
if HAS_TORCH and isinstance(orig, torch.Tensor):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
return False
if orig.requires_grad != new.requires_grad:
return False
if orig.device != new.device:
return False
return torch.allclose(orig, new, equal_nan=True)
if HAS_TORCH:
import torch # type: ignore # noqa: PGH003
if HAS_PYRSISTENT and isinstance(
orig,
(
pyrsistent.PMap,
pyrsistent.PVector,
pyrsistent.PSet,
pyrsistent.PRecord,
pyrsistent.PClass,
pyrsistent.PBag,
pyrsistent.PList,
pyrsistent.PDeque,
),
):
return orig == new
if isinstance(orig, torch.Tensor):
if orig.dtype != new.dtype:
return False
if orig.shape != new.shape:
return False
if orig.requires_grad != new.requires_grad:
return False
if orig.device != new.device:
return False
return torch.allclose(orig, new, equal_nan=True)
if HAS_PYRSISTENT:
import pyrsistent # type: ignore # noqa: PGH003
if isinstance(
orig,
(
pyrsistent.PMap,
pyrsistent.PVector,
pyrsistent.PSet,
pyrsistent.PRecord,
pyrsistent.PClass,
pyrsistent.PBag,
pyrsistent.PList,
pyrsistent.PDeque,
),
):
return orig == new
if hasattr(orig, "__attrs_attrs__") and hasattr(new, "__attrs_attrs__"):
orig_dict = {}
new_dict = {}
for attr in orig.__attrs_attrs__:
if attr.eq:
attr_name = attr.name
orig_dict[attr_name] = getattr(orig, attr_name, None)
new_dict[attr_name] = getattr(new, attr_name, None)
if superset_obj:
new_attrs_dict = {}
for attr in new.__attrs_attrs__:
if attr.eq:
attr_name = attr.name
new_attrs_dict[attr_name] = getattr(new, attr_name, None)
return all(
k in new_attrs_dict and comparator(v, new_attrs_dict[k], superset_obj) for k, v in orig_dict.items()
)
return comparator(orig_dict, new_dict, superset_obj)
# re.Pattern can be made better by DFA Minimization and then comparing
if isinstance(

View file

@ -38,20 +38,20 @@ class CoverageUtils:
cov = Coverage(data_file=database_path, config_file=config_path, data_suffix=True, auto_data=True, branch=True)
if not database_path.stat().st_size or not database_path.exists():
if not database_path.exists() or not database_path.stat().st_size:
logger.debug(f"Coverage database {database_path} is empty or does not exist")
sentry_sdk.capture_message(f"Coverage database {database_path} is empty or does not exist")
return CoverageUtils.create_empty(source_code_path, function_name, code_context)
return CoverageData.create_empty(source_code_path, function_name, code_context)
cov.load()
reporter = JsonReporter(cov)
temp_json_file = database_path.with_suffix(".report.json")
with temp_json_file.open("w") as f:
with temp_json_file.open("w", encoding="utf-8") as f:
try:
reporter.report(morfs=[source_code_path.as_posix()], outfile=f)
except NoDataError:
sentry_sdk.capture_message(f"No coverage data found for {function_name} in {source_code_path}")
return CoverageUtils.create_empty(source_code_path, function_name, code_context)
return CoverageData.create_empty(source_code_path, function_name, code_context)
with temp_json_file.open() as f:
original_coverage_data = json.load(f)
@ -92,7 +92,7 @@ class CoverageUtils:
def _parse_coverage_file(
coverage_file_path: Path, source_code_path: Path
) -> tuple[dict[str, dict[str, Any]], CoverageStatus]:
with coverage_file_path.open() as f:
with coverage_file_path.open(encoding="utf-8") as f:
coverage_data = json.load(f)
candidates = generate_candidates(source_code_path)

View file

@ -33,7 +33,7 @@ def instrument_codeflash_capture(
modified_code = add_codeflash_capture_to_init(
target_classes={class_parent.name},
fto_name=function_to_optimize.function_name,
tmp_dir_path=str(get_run_tmp_file(Path("test_return_values"))),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
code=original_code,
tests_root=tests_root,
is_fto=True,
@ -46,7 +46,7 @@ def instrument_codeflash_capture(
modified_code = add_codeflash_capture_to_init(
target_classes=helper_classes,
fto_name=function_to_optimize.function_name,
tmp_dir_path=str(get_run_tmp_file(Path("test_return_values"))),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
code=original_code,
tests_root=tests_root,
is_fto=False,
@ -92,6 +92,27 @@ class InitDecorator(ast.NodeTransformer):
self.tests_root = tests_root
self.inserted_decorator = False
# Precompute decorator components to avoid reconstructing on every node visit
# Only the `function_name` field changes per class
self._base_decorator_keywords = [
ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)),
ast.keyword(arg="tests_root", value=ast.Constant(value=self.tests_root.as_posix())),
ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)),
]
self._base_decorator_func = ast.Name(id="codeflash_capture", ctx=ast.Load())
# Preconstruct starred/kwargs for super init injection for perf
self._super_starred = ast.Starred(value=ast.Name(id="args", ctx=ast.Load()))
self._super_kwarg = ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))
self._super_func = ast.Attribute(
value=ast.Call(func=ast.Name(id="super", ctx=ast.Load()), args=[], keywords=[]),
attr="__init__",
ctx=ast.Load(),
)
self._init_vararg = ast.arg(arg="args")
self._init_kwarg = ast.arg(arg="kwargs")
self._init_self_arg = ast.arg(arg="self", annotation=None)
def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
# Check if our import already exists
if node.module == "codeflash.verification.codeflash_capture" and any(
@ -114,21 +135,18 @@ class InitDecorator(ast.NodeTransformer):
if node.name not in self.target_classes:
return node
# Look for __init__ method
has_init = False
# Create the decorator
# Build decorator node ONCE for each class, not per loop iteration
decorator = ast.Call(
func=ast.Name(id="codeflash_capture", ctx=ast.Load()),
func=self._base_decorator_func,
args=[],
keywords=[
ast.keyword(arg="function_name", value=ast.Constant(value=f"{node.name}.__init__")),
ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)),
ast.keyword(arg="tests_root", value=ast.Constant(value=str(self.tests_root))),
ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)),
*self._base_decorator_keywords,
],
)
# Only scan node.body once for both __init__ and decorator check
for item in node.body:
if (
isinstance(item, ast.FunctionDef)
@ -139,35 +157,28 @@ class InitDecorator(ast.NodeTransformer):
):
has_init = True
# Add decorator at the start of the list if not already present
if not any(
isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "codeflash_capture"
for d in item.decorator_list
):
# Check for existing decorator in-place, stop after finding one
for d in item.decorator_list:
if isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "codeflash_capture":
break
else:
# No decorator found
item.decorator_list.insert(0, decorator)
self.inserted_decorator = True
if not has_init:
# Create super().__init__(*args, **kwargs) call
# Create super().__init__(*args, **kwargs) call (use prebuilt AST fragments)
super_call = ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Call(func=ast.Name(id="super", ctx=ast.Load()), args=[], keywords=[]),
attr="__init__",
ctx=ast.Load(),
),
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()))],
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
)
value=ast.Call(func=self._super_func, args=[self._super_starred], keywords=[self._super_kwarg])
)
# Create function arguments: self, *args, **kwargs
# Create function arguments: self, *args, **kwargs (reuse arg nodes)
arguments = ast.arguments(
posonlyargs=[],
args=[ast.arg(arg="self", annotation=None)],
vararg=ast.arg(arg="args"),
args=[self._init_self_arg],
vararg=self._init_vararg,
kwonlyargs=[],
kw_defaults=[],
kwarg=ast.arg(arg="kwargs"),
kwarg=self._init_kwarg,
defaults=[],
)

View file

@ -40,6 +40,30 @@ matches_re_start = re.compile(r"!\$######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)
matches_re_end = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!")
start_pattern = re.compile(r"!\$######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+)######\$!")
end_pattern = re.compile(r"!######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+):([^:]+)######!")
def calculate_function_throughput_from_test_results(test_results: TestResults, function_name: str) -> int:
"""Calculate function throughput from TestResults by extracting performance stdout.
A completed execution is defined as having both a start tag and matching end tag from performance wrappers.
Start: !$######test_module:test_function:function_name:loop_index:iteration_id######$!
End: !######test_module:test_function:function_name:loop_index:iteration_id:duration######!
"""
start_matches = start_pattern.findall(test_results.perf_stdout or "")
end_matches = end_pattern.findall(test_results.perf_stdout or "")
end_matches_truncated = [end_match[:5] for end_match in end_matches]
end_matches_set = set(end_matches_truncated)
function_throughput = 0
for start_match in start_matches:
if start_match in end_matches_set and len(start_match) > 2 and start_match[2] == function_name:
function_throughput += 1
return function_throughput
def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
test_results = TestResults()
if not file_location.exists():

View file

@ -450,3 +450,26 @@ class PytestLoops:
metafunc.parametrize(
"__pytest_loop_step_number", range(count), indirect=True, ids=make_progress_id, scope=scope
)
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_setup(self, item: pytest.Item) -> None:
"""Set test context environment variables before each test."""
test_module_name = item.module.__name__ if item.module else "unknown_module"
test_class_name = None
if item.cls:
test_class_name = item.cls.__name__
test_function_name = item.name
if "[" in test_function_name:
test_function_name = test_function_name.split("[", 1)[0]
os.environ["CODEFLASH_TEST_MODULE"] = test_module_name
os.environ["CODEFLASH_TEST_CLASS"] = test_class_name or ""
os.environ["CODEFLASH_TEST_FUNCTION"] = test_function_name
@pytest.hookimpl(trylast=True)
def pytest_runtest_teardown(self, item: pytest.Item) -> None: # noqa: ARG002
"""Clean up test context environment variables after each test."""
for var in ["CODEFLASH_TEST_MODULE", "CODEFLASH_TEST_CLASS", "CODEFLASH_TEST_FUNCTION"]:
os.environ.pop(var, None)

View file

@ -1,2 +1,2 @@
# These version placeholders will be replaced by uv-dynamic-versioning during build.
__version__ = "0.17.0"
__version__ = "0.17.2"

View file

@ -52,8 +52,14 @@ Homepage = "https://codeflash.ai"
[project.scripts]
codeflash = "codeflash.main:main"
[project.optional-dependencies]
asyncio = [
"pytest-asyncio>=1.2.0",
]
[dependency-groups]
dev = [
{include-group = "asyncio"},
"ipython>=8.12.0",
"mypy>=1.13",
"ruff>=0.7.0",
@ -76,6 +82,9 @@ dev = [
"uv>=0.6.2",
"pre-commit>=4.2.0,<5",
]
asyncio = [
"pytest-asyncio>=1.2.0",
]
[tool.hatch.build.targets.sdist]
include = ["codeflash"]

View file

@ -0,0 +1,28 @@
import os
import pathlib
from end_to_end_test_utilities import CoverageExpectation, TestConfig, run_codeflash_command, run_with_retries
def run_test(expected_improvement_pct: int) -> bool:
config = TestConfig(
file_path="main.py",
expected_unit_tests=0,
min_improvement_x=0.1,
enable_async=True,
coverage_expectations=[
CoverageExpectation(
function_name="retry_with_backoff",
expected_coverage=100.0,
expected_lines=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
)
],
)
cwd = (
pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "code_directories" / "async_e2e"
).resolve()
return run_codeflash_command(cwd, config, expected_improvement_pct)
if __name__ == "__main__":
exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10))))

View file

@ -11,7 +11,7 @@ def run_test(expected_improvement_pct: int) -> bool:
function_name="sorter",
benchmarks_root=cwd / "tests" / "pytest" / "benchmarks",
test_framework="pytest",
min_improvement_x=1.0,
min_improvement_x=0.70,
coverage_expectations=[
CoverageExpectation(
function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8, 9, 10]

View file

@ -9,7 +9,7 @@ def run_test(expected_improvement_pct: int) -> bool:
file_path="bubble_sort.py",
function_name="sorter",
test_framework="pytest",
min_improvement_x=1.0,
min_improvement_x=0.70,
coverage_expectations=[
CoverageExpectation(
function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8, 9, 10]

View file

@ -6,7 +6,7 @@ from end_to_end_test_utilities import TestConfig, run_codeflash_command, run_wit
def run_test(expected_improvement_pct: int) -> bool:
config = TestConfig(
file_path="bubble_sort.py", function_name="sorter", test_framework="unittest", min_improvement_x=3.0
file_path="bubble_sort.py", function_name="sorter", test_framework="unittest", min_improvement_x=0.30
)
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
return run_codeflash_command(cwd, config, expected_improvement_pct)

View file

@ -27,6 +27,7 @@ class TestConfig:
trace_mode: bool = False
coverage_expectations: list[CoverageExpectation] = field(default_factory=list)
benchmarks_root: Optional[pathlib.Path] = None
enable_async: bool = False
def clear_directory(directory_path: str | pathlib.Path) -> None:
@ -88,8 +89,10 @@ def run_codeflash_command(
test_root = cwd / "tests" / (config.test_framework or "")
command = build_command(cwd, config, test_root, config.benchmarks_root if config.benchmarks_root else None)
env = os.environ.copy()
env['PYTHONIOENCODING'] = 'utf-8'
process = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=os.environ.copy()
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=env, encoding='utf-8'
)
output = []
@ -122,7 +125,7 @@ def build_command(
) -> list[str]:
python_path = "../../../codeflash/main.py" if "code_directories" in str(cwd) else "../codeflash/main.py"
base_command = ["python", python_path, "--file", config.file_path, "--no-pr"]
base_command = ["uv", "run", "--no-project", python_path, "--file", config.file_path, "--no-pr"]
if config.function_name:
base_command.extend(["--function", config.function_name])
@ -132,6 +135,8 @@ def build_command(
)
if benchmarks_root:
base_command.extend(["--benchmark", "--benchmarks-root", str(benchmarks_root)])
if config.enable_async:
base_command.append("--async")
return base_command
@ -187,9 +192,11 @@ def validate_stdout_in_candidate(stdout: str, expected_in_stdout: list[str]) ->
def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int) -> bool:
test_root = cwd / "tests" / (config.test_framework or "")
clear_directory(test_root)
command = ["python", "-m", "codeflash.main", "optimize", "workload.py"]
command = ["uv", "run", "--no-project", "-m", "codeflash.main", "optimize", "workload.py"]
env = os.environ.copy()
env['PYTHONIOENCODING'] = 'utf-8'
process = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=os.environ.copy()
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=env, encoding='utf-8'
)
output = []

View file

@ -3,6 +3,10 @@ from pathlib import Path
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
import tempfile
from codeflash.code_utils.code_extractor import resolve_star_import, DottedImportCollector
import libcst as cst
from codeflash.models.models import FunctionParent
def test_add_needed_imports_from_module0() -> None:
src_module = '''import ast
@ -349,3 +353,141 @@ class DbtAdapter(BaseAdapter):
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
def test_resolve_star_import_with_all_defined():
"""Test resolve_star_import when __all__ is explicitly defined."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)
test_module = project_root / 'test_module.py'
# Create a test module with __all__ definition
test_module.write_text('''
__all__ = ['public_function', 'PublicClass']
def public_function():
pass
def _private_function():
pass
class PublicClass:
pass
class AnotherPublicClass:
"""Not in __all__ so should be excluded."""
pass
''')
symbols = resolve_star_import('test_module', project_root)
expected_symbols = {'public_function', 'PublicClass'}
assert symbols == expected_symbols
def test_resolve_star_import_without_all_defined():
"""Test resolve_star_import when __all__ is not defined - should include all public symbols."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)
test_module = project_root / 'test_module.py'
# Create a test module without __all__ definition
test_module.write_text('''
def public_func():
pass
def _private_func():
pass
class PublicClass:
pass
PUBLIC_VAR = 42
_private_var = 'secret'
''')
symbols = resolve_star_import('test_module', project_root)
expected_symbols = {'public_func', 'PublicClass', 'PUBLIC_VAR'}
assert symbols == expected_symbols
def test_resolve_star_import_nonexistent_module():
"""Test resolve_star_import with non-existent module - should return empty set."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)
symbols = resolve_star_import('nonexistent_module', project_root)
assert symbols == set()
def test_dotted_import_collector_skips_star_imports():
"""Test that DottedImportCollector correctly skips star imports."""
code_with_star_import = '''
from typing import *
from pathlib import Path
from collections import defaultdict
import os
'''
module = cst.parse_module(code_with_star_import)
collector = DottedImportCollector()
module.visit(collector)
# Should collect regular imports but skip the star import
expected_imports = {'collections.defaultdict', 'os', 'pathlib.Path'}
assert collector.imports == expected_imports
def test_add_needed_imports_with_star_import_resolution():
"""Test add_needed_imports_from_module correctly handles star imports by resolving them."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)
# Create a source module that exports symbols
src_module = project_root / 'source_module.py'
src_module.write_text('''
__all__ = ['UtilFunction', 'HelperClass']
def UtilFunction():
pass
class HelperClass:
pass
''')
# Create source code that uses star import
src_code = '''
from source_module import *
def my_function():
helper = HelperClass()
UtilFunction()
return helper
'''
# Destination code that needs the imports resolved
dst_code = '''
def my_function():
helper = HelperClass()
UtilFunction()
return helper
'''
src_path = project_root / 'src.py'
dst_path = project_root / 'dst.py'
src_path.write_text(src_code)
result = add_needed_imports_from_module(
src_code, dst_code, src_path, dst_path, project_root
)
# The result should have individual imports instead of star import
expected_result = '''from source_module import HelperClass, UtilFunction
def my_function():
helper = HelperClass()
UtilFunction()
return helper
'''
assert result == expected_result

View file

@ -1902,4 +1902,210 @@ def test_bubble_sort(input, expected_output):
# Check that comments were added
modified_source = result.generated_tests[0].generated_original_test_source
assert modified_source == expected
assert modified_source == expected
def test_async_basic_runtime_comment_addition(self, test_config):
"""Test basic functionality of adding runtime comments to async test functions."""
os.chdir(test_config.project_root_path)
test_source = """async def test_async_bubble_sort():
codeflash_output = await async_bubble_sort([3, 1, 2])
assert codeflash_output == [1, 2, 3]
"""
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
perf_file_path=test_config.tests_root / "test_perf.py",
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
original_test_results = TestResults()
optimized_test_results = TestResults()
original_invocation = self.create_test_invocation("test_async_bubble_sort", 500_000, iteration_id='0') # 500μs
optimized_invocation = self.create_test_invocation("test_async_bubble_sort", 300_000, iteration_id='0') # 300μs
original_test_results.add(original_invocation)
optimized_test_results.add(optimized_invocation)
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
modified_source = result.generated_tests[0].generated_original_test_source
assert "# 500μs -> 300μs" in modified_source
assert "codeflash_output = await async_bubble_sort([3, 1, 2]) # 500μs -> 300μs" in modified_source
def test_async_multiple_test_functions(self, test_config):
os.chdir(test_config.project_root_path)
test_source = """async def test_async_bubble_sort():
codeflash_output = await async_quick_sort([3, 1, 2])
assert codeflash_output == [1, 2, 3]
async def test_async_quick_sort():
codeflash_output = await async_quick_sort([5, 2, 8])
assert codeflash_output == [2, 5, 8]
def helper_function():
return "not a test"
"""
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
perf_file_path=test_config.tests_root / "test_perf.py"
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
original_test_results = TestResults()
optimized_test_results = TestResults()
original_test_results.add(self.create_test_invocation("test_async_bubble_sort", 500_000, iteration_id='0'))
original_test_results.add(self.create_test_invocation("test_async_quick_sort", 800_000, iteration_id='0'))
optimized_test_results.add(self.create_test_invocation("test_async_bubble_sort", 300_000, iteration_id='0'))
optimized_test_results.add(self.create_test_invocation("test_async_quick_sort", 600_000, iteration_id='0'))
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
modified_source = result.generated_tests[0].generated_original_test_source
assert "# 500μs -> 300μs" in modified_source
assert "# 800μs -> 600μs" in modified_source
assert (
"helper_function():" in modified_source
and "# " not in modified_source.split("helper_function():")[1].split("\n")[0]
)
def test_async_class_method(self, test_config):
os.chdir(test_config.project_root_path)
test_source = '''class TestAsyncClass:
async def test_async_function(self):
codeflash_output = await some_async_function()
assert codeflash_output == expected
'''
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
perf_file_path=test_config.tests_root / "test_perf.py"
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
invocation_id = InvocationId(
test_module_path="tests.test_module__unit_test_0",
test_class_name="TestAsyncClass",
test_function_name="test_async_function",
function_getting_tested="some_async_function",
iteration_id="0",
)
original_runtimes = {invocation_id: [2000000000]} # 2s in nanoseconds
optimized_runtimes = {invocation_id: [1000000000]} # 1s in nanoseconds
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
expected_source = '''class TestAsyncClass:
async def test_async_function(self):
codeflash_output = await some_async_function() # 2.00s -> 1.00s (100% faster)
assert codeflash_output == expected
'''
assert len(result.generated_tests) == 1
assert result.generated_tests[0].generated_original_test_source == expected_source
def test_async_mixed_sync_and_async_functions(self, test_config):
os.chdir(test_config.project_root_path)
test_source = """def test_sync_function():
codeflash_output = sync_function([1, 2, 3])
assert codeflash_output == [1, 2, 3]
async def test_async_function():
codeflash_output = await async_function([4, 5, 6])
assert codeflash_output == [4, 5, 6]
def test_another_sync():
result = another_sync_func()
assert result is True
"""
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
perf_file_path=test_config.tests_root / "test_perf.py"
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
original_test_results = TestResults()
optimized_test_results = TestResults()
# Add test invocations for all test functions
original_test_results.add(self.create_test_invocation("test_sync_function", 400_000, iteration_id='0'))
original_test_results.add(self.create_test_invocation("test_async_function", 600_000, iteration_id='0'))
original_test_results.add(self.create_test_invocation("test_another_sync", 200_000, iteration_id='0'))
optimized_test_results.add(self.create_test_invocation("test_sync_function", 200_000, iteration_id='0'))
optimized_test_results.add(self.create_test_invocation("test_async_function", 300_000, iteration_id='0'))
optimized_test_results.add(self.create_test_invocation("test_another_sync", 100_000, iteration_id='0'))
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
modified_source = result.generated_tests[0].generated_original_test_source
assert "# 400μs -> 200μs" in modified_source
assert "# 600μs -> 300μs" in modified_source
assert "# 200μs -> 100μs" in modified_source
assert "async def test_async_function():" in modified_source
assert "await async_function([4, 5, 6])" in modified_source
def test_async_complex_await_patterns(self, test_config):
os.chdir(test_config.project_root_path)
test_source = """async def test_complex_async():
# Multiple await calls
result1 = await async_func1()
codeflash_output = await async_func2(result1)
result3 = await async_func3(codeflash_output)
assert result3 == expected
# Await in context manager
async with async_context() as ctx:
final_result = await ctx.process()
assert final_result is not None
"""
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
perf_file_path=test_config.tests_root / "test_perf.py"
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
original_test_results = TestResults()
optimized_test_results = TestResults()
original_test_results.add(self.create_test_invocation("test_complex_async", 750_000, iteration_id='1')) # 750μs
optimized_test_results.add(self.create_test_invocation("test_complex_async", 450_000, iteration_id='1')) # 450μs
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
modified_source = result.generated_tests[0].generated_original_test_source
assert "# 750μs -> 450μs" in modified_source

View file

@ -0,0 +1,347 @@
import tempfile
from pathlib import Path
import sys
import pytest
from codeflash.discovery.functions_to_optimize import (
find_all_functions_in_file,
get_functions_to_optimize,
inspect_top_level_functions_or_methods,
)
from codeflash.verification.verification_utils import TestConfig
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as temp:
yield Path(temp)
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_function_detection(temp_dir):
async_function = """
async def async_function_with_return():
await some_async_operation()
return 42
async def async_function_without_return():
await some_async_operation()
print("No return")
def regular_function():
return 10
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(async_function)
functions_found = find_all_functions_in_file(file_path)
function_names = [fn.function_name for fn in functions_found[file_path]]
assert "async_function_with_return" in function_names
assert "regular_function" in function_names
assert "async_function_without_return" not in function_names
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_method_in_class(temp_dir):
code_with_async_method = """
class AsyncClass:
async def async_method(self):
await self.do_something()
return "result"
async def async_method_no_return(self):
await self.do_something()
pass
def sync_method(self):
return "sync result"
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(code_with_async_method)
functions_found = find_all_functions_in_file(file_path)
found_functions = functions_found[file_path]
function_names = [fn.function_name for fn in found_functions]
qualified_names = [fn.qualified_name for fn in found_functions]
assert "async_method" in function_names
assert "AsyncClass.async_method" in qualified_names
assert "sync_method" in function_names
assert "AsyncClass.sync_method" in qualified_names
assert "async_method_no_return" not in function_names
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_nested_async_functions(temp_dir):
nested_async = """
async def outer_async():
async def inner_async():
return "inner"
result = await inner_async()
return result
def outer_sync():
async def inner_async():
return "inner from sync"
return inner_async
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(nested_async)
functions_found = find_all_functions_in_file(file_path)
function_names = [fn.function_name for fn in functions_found[file_path]]
assert "outer_async" in function_names
assert "outer_sync" in function_names
assert "inner_async" not in function_names
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_staticmethod_and_classmethod(temp_dir):
async_decorators = """
class MyClass:
@staticmethod
async def async_static_method():
await some_operation()
return "static result"
@classmethod
async def async_class_method(cls):
await cls.some_operation()
return "class result"
@property
async def async_property(self):
return await self.get_value()
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(async_decorators)
functions_found = find_all_functions_in_file(file_path)
function_names = [fn.function_name for fn in functions_found[file_path]]
assert "async_static_method" in function_names
assert "async_class_method" in function_names
assert "async_property" not in function_names
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_generator_functions(temp_dir):
async_generators = """
async def async_generator_with_return():
for i in range(10):
yield i
return "done"
async def async_generator_no_return():
for i in range(10):
yield i
async def regular_async_with_return():
result = await compute()
return result
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(async_generators)
functions_found = find_all_functions_in_file(file_path)
function_names = [fn.function_name for fn in functions_found[file_path]]
assert "async_generator_with_return" in function_names
assert "regular_async_with_return" in function_names
assert "async_generator_no_return" not in function_names
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_inspect_async_top_level_functions(temp_dir):
code = """
async def top_level_async():
return 42
class AsyncContainer:
async def async_method(self):
async def nested_async():
return 1
return await nested_async()
@staticmethod
async def async_static():
return "static"
@classmethod
async def async_classmethod(cls):
return "classmethod"
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(code)
result = inspect_top_level_functions_or_methods(file_path, "top_level_async")
assert result.is_top_level
result = inspect_top_level_functions_or_methods(file_path, "async_method", class_name="AsyncContainer")
assert result.is_top_level
result = inspect_top_level_functions_or_methods(file_path, "nested_async", class_name="AsyncContainer")
assert not result.is_top_level
result = inspect_top_level_functions_or_methods(file_path, "async_static", class_name="AsyncContainer")
assert result.is_top_level
assert result.is_staticmethod
result = inspect_top_level_functions_or_methods(file_path, "async_classmethod", class_name="AsyncContainer")
assert result.is_top_level
assert result.is_classmethod
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_get_functions_to_optimize_with_async(temp_dir):
mixed_code = """
async def async_func_one():
return await operation_one()
def sync_func_one():
return operation_one()
async def async_func_two():
print("no return")
class MixedClass:
async def async_method(self):
return await self.operation()
def sync_method(self):
return self.operation()
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(mixed_code)
test_config = TestConfig(
tests_root="tests",
project_root_path=".",
test_framework="pytest",
tests_project_rootdir=Path()
)
functions, functions_count, _ = get_functions_to_optimize(
optimize_all=None,
replay_test=None,
file=file_path,
only_get_this_function=None,
test_cfg=test_config,
ignore_paths=[],
project_root=file_path.parent,
module_root=file_path.parent,
enable_async=True,
)
assert functions_count == 4
function_names = [fn.function_name for fn in functions[file_path]]
assert "async_func_one" in function_names
assert "sync_func_one" in function_names
assert "async_method" in function_names
assert "sync_method" in function_names
assert "async_func_two" not in function_names
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_no_async_functions_finding(temp_dir):
mixed_code = """
async def async_func_one():
return await operation_one()
def sync_func_one():
return operation_one()
async def async_func_two():
print("no return")
class MixedClass:
async def async_method(self):
return await self.operation()
def sync_method(self):
return self.operation()
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(mixed_code)
test_config = TestConfig(
tests_root="tests",
project_root_path=".",
test_framework="pytest",
tests_project_rootdir=Path()
)
functions, functions_count, _ = get_functions_to_optimize(
optimize_all=None,
replay_test=None,
file=file_path,
only_get_this_function=None,
test_cfg=test_config,
ignore_paths=[],
project_root=file_path.parent,
module_root=file_path.parent,
enable_async=False,
)
assert functions_count == 2
function_names = [fn.function_name for fn in functions[file_path]]
assert "sync_func_one" in function_names
assert "sync_method" in function_names
assert "async_func_one" not in function_names
assert "async_method" not in function_names
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_function_parents(temp_dir):
complex_structure = """
class OuterClass:
async def outer_method(self):
return 1
class InnerClass:
async def inner_method(self):
return 2
async def module_level_async():
class LocalClass:
async def local_method(self):
return 3
return LocalClass()
"""
file_path = temp_dir / "test_file.py"
file_path.write_text(complex_structure)
functions_found = find_all_functions_in_file(file_path)
found_functions = functions_found[file_path]
for fn in found_functions:
if fn.function_name == "outer_method":
assert len(fn.parents) == 1
assert fn.parents[0].name == "OuterClass"
assert fn.qualified_name == "OuterClass.outer_method"
elif fn.function_name == "inner_method":
assert len(fn.parents) == 2
assert fn.parents[0].name == "OuterClass"
assert fn.parents[1].name == "InnerClass"
elif fn.function_name == "module_level_async":
assert len(fn.parents) == 0
assert fn.qualified_name == "module_level_async"

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,287 @@
from __future__ import annotations
import asyncio
import os
import sqlite3
import sys
import tempfile
from pathlib import Path
import pytest
import dill as pickle
from codeflash.code_utils.codeflash_wrap_decorator import (
codeflash_behavior_async,
codeflash_performance_async,
)
from codeflash.verification.codeflash_capture import VerificationType
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
class TestAsyncWrapperSQLiteValidation:
@pytest.fixture
def test_env_setup(self, request):
original_env = {}
test_env = {
"CODEFLASH_LOOP_INDEX": "1",
"CODEFLASH_TEST_ITERATION": "0",
"CODEFLASH_TEST_MODULE": __name__,
"CODEFLASH_TEST_CLASS": "TestAsyncWrapperSQLiteValidation",
"CODEFLASH_TEST_FUNCTION": request.node.name,
"CODEFLASH_CURRENT_LINE_ID": "test_unit",
}
for key, value in test_env.items():
original_env[key] = os.environ.get(key)
os.environ[key] = value
yield test_env
for key, original_value in original_env.items():
if original_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = original_value
@pytest.fixture
def temp_db_path(self, test_env_setup):
iteration = test_env_setup["CODEFLASH_TEST_ITERATION"]
from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file
db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
yield db_path
if db_path.exists():
db_path.unlink()
@pytest.mark.asyncio
async def test_behavior_async_basic_function(self, test_env_setup, temp_db_path):
@codeflash_behavior_async
async def simple_async_add(a: int, b: int) -> int:
await asyncio.sleep(0.001)
return a + b
os.environ['CODEFLASH_CURRENT_LINE_ID'] = 'simple_async_add_59'
result = await simple_async_add(5, 3)
assert result == 8
assert temp_db_path.exists()
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test_results'")
assert cur.fetchone() is not None
cur.execute("SELECT * FROM test_results")
rows = cur.fetchall()
assert len(rows) == 1
row = rows[0]
(test_module_path, test_class_name, test_function_name, function_getting_tested,
loop_index, iteration_id, runtime, return_value_blob, verification_type) = row
assert test_module_path == __name__
assert test_class_name == "TestAsyncWrapperSQLiteValidation"
assert test_function_name == "test_behavior_async_basic_function"
assert function_getting_tested == "simple_async_add"
assert loop_index == 1
# Line ID will be the actual line number from the source code, not a simple counter
assert iteration_id.startswith("simple_async_add_") and iteration_id.endswith("_0")
assert runtime > 0
assert verification_type == VerificationType.FUNCTION_CALL.value
unpickled_data = pickle.loads(return_value_blob)
args, kwargs, return_val = unpickled_data
assert args == (5, 3)
assert kwargs == {}
assert return_val == 8
con.close()
@pytest.mark.asyncio
async def test_behavior_async_exception_handling(self, test_env_setup, temp_db_path):
@codeflash_behavior_async
async def async_divide(a: int, b: int) -> float:
await asyncio.sleep(0.001)
if b == 0:
raise ValueError("Cannot divide by zero")
return a / b
result = await async_divide(10, 2)
assert result == 5.0
with pytest.raises(ValueError, match="Cannot divide by zero"):
await async_divide(10, 0)
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("SELECT * FROM test_results ORDER BY iteration_id")
rows = cur.fetchall()
assert len(rows) == 2
success_row = rows[0]
success_data = pickle.loads(success_row[7]) # return_value_blob
args, kwargs, return_val = success_data
assert args == (10, 2)
assert return_val == 5.0
# Check exception record
exception_row = rows[1]
exception_data = pickle.loads(exception_row[7]) # return_value_blob
assert isinstance(exception_data, ValueError)
assert str(exception_data) == "Cannot divide by zero"
con.close()
@pytest.mark.asyncio
async def test_performance_async_no_database_storage(self, test_env_setup, temp_db_path, capsys):
"""Test performance async decorator doesn't store to database."""
@codeflash_performance_async
async def async_multiply(a: int, b: int) -> int:
"""Async function for performance testing."""
await asyncio.sleep(0.002)
return a * b
result = await async_multiply(4, 7)
assert result == 28
assert not temp_db_path.exists()
captured = capsys.readouterr()
output_lines = captured.out.strip().split('\n')
assert len([line for line in output_lines if "!$######" in line]) == 1
assert len([line for line in output_lines if "!######" in line and "######!" in line]) == 1
closing_tag = [line for line in output_lines if "!######" in line and "######!" in line][0]
assert "async_multiply" in closing_tag
timing_part = closing_tag.split(":")[-1].replace("######!", "")
timing_value = int(timing_part)
assert timing_value > 0 # Should have positive timing
@pytest.mark.asyncio
async def test_multiple_calls_indexing(self, test_env_setup, temp_db_path):
@codeflash_behavior_async
async def async_increment(value: int) -> int:
await asyncio.sleep(0.001)
return value + 1
# Call the function multiple times
results = []
for i in range(3):
result = await async_increment(i)
results.append(result)
assert results == [1, 2, 3]
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("SELECT iteration_id, return_value FROM test_results ORDER BY iteration_id")
rows = cur.fetchall()
assert len(rows) == 3
actual_ids = [row[0] for row in rows]
assert len(actual_ids) == 3
base_pattern = actual_ids[0].rsplit('_', 1)[0] # e.g., "async_increment_199"
expected_pattern = [f"{base_pattern}_{i}" for i in range(3)]
assert actual_ids == expected_pattern
for i, (_, return_value_blob) in enumerate(rows):
args, kwargs, return_val = pickle.loads(return_value_blob)
assert args == (i,)
assert return_val == i + 1
con.close()
@pytest.mark.asyncio
async def test_complex_async_function_with_kwargs(self, test_env_setup, temp_db_path):
@codeflash_behavior_async
async def complex_async_func(
pos_arg: str,
*args: int,
keyword_arg: str = "default",
**kwargs: str
) -> dict:
await asyncio.sleep(0.001)
return {
"pos_arg": pos_arg,
"args": args,
"keyword_arg": keyword_arg,
"kwargs": kwargs,
}
result = await complex_async_func(
"hello",
1, 2, 3,
keyword_arg="custom",
extra1="value1",
extra2="value2"
)
expected_result = {
"pos_arg": "hello",
"args": (1, 2, 3),
"keyword_arg": "custom",
"kwargs": {"extra1": "value1", "extra2": "value2"}
}
assert result == expected_result
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("SELECT return_value FROM test_results")
row = cur.fetchone()
stored_args, stored_kwargs, stored_result = pickle.loads(row[0])
assert stored_args == ("hello", 1, 2, 3)
assert stored_kwargs == {"keyword_arg": "custom", "extra1": "value1", "extra2": "value2"}
assert stored_result == expected_result
con.close()
@pytest.mark.asyncio
async def test_database_schema_validation(self, test_env_setup, temp_db_path):
@codeflash_behavior_async
async def schema_test_func() -> str:
return "schema_test"
await schema_test_func()
con = sqlite3.connect(temp_db_path)
cur = con.cursor()
cur.execute("PRAGMA table_info(test_results)")
columns = cur.fetchall()
expected_columns = [
(0, 'test_module_path', 'TEXT', 0, None, 0),
(1, 'test_class_name', 'TEXT', 0, None, 0),
(2, 'test_function_name', 'TEXT', 0, None, 0),
(3, 'function_getting_tested', 'TEXT', 0, None, 0),
(4, 'loop_index', 'INTEGER', 0, None, 0),
(5, 'iteration_id', 'TEXT', 0, None, 0),
(6, 'runtime', 'INTEGER', 0, None, 0),
(7, 'return_value', 'BLOB', 0, None, 0),
(8, 'verification_type', 'TEXT', 0, None, 0)
]
assert columns == expected_columns
con.close()

File diff suppressed because it is too large Load diff

View file

@ -12,6 +12,7 @@ from codeflash.code_utils.code_replacer import (
is_zero_diff,
replace_functions_and_add_imports,
replace_functions_in_file,
OptimFunctionCollector,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent
@ -3449,156 +3450,173 @@ def hydrate_input_text_actions_with_field_names(
assert new_code == expected
def test_duplicate_global_assignments_when_reverting_helpers():
root_dir = Path(__file__).parent.parent.resolve()
main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()
original_code = '''"""Chunking objects not specific to a particular chunking strategy."""
from __future__ import annotations
import collections
import copy
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
import regex
from typing_extensions import Self, TypeAlias
from unstructured.utils import lazyproperty
from unstructured.documents.elements import Element
# ================================================================================================
# MODEL
# ================================================================================================
CHUNK_MAX_CHARS_DEFAULT: int = 500
# ================================================================================================
# PRE-CHUNKER
# ================================================================================================
class PreChunker:
"""Gathers sequential elements into pre-chunks as length constraints allow.
The pre-chunker's responsibilities are:
- **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
either side of those boundaries into different sections. In this case, the primary indicator
of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
semantic boundary when `multipage_sections` is `False`.
- **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
into sections as big as possible without exceeding the chunk window size.
- **Minimize chunks that must be split mid-text.** Precompute the text length of each section
and only produce a section that exceeds the chunk window size when there is a single element
with text longer than that window.
A Table element is placed into a section by itself. CheckBox elements are dropped.
The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
a new "section", hence the "by-title" designation.
"""
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
self._elements = elements
self._opts = opts
@lazyproperty
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
"""The semantic-boundary detectors to be applied to break pre-chunks."""
return self._opts.boundary_predicates
def _is_in_new_semantic_unit(self, element: Element) -> bool:
"""True when `element` begins a new semantic unit such as a section or page."""
# -- all detectors need to be called to update state and avoid double counting
# -- boundaries that happen to coincide, like Table and new section on same element.
# -- Using `any()` would short-circuit on first True.
semantic_boundaries = [pred(element) for pred in self._boundary_predicates]
return any(semantic_boundaries)
'''
main_file.write_text(original_code, encoding="utf-8")
optim_code = f'''```python:{main_file.relative_to(root_dir)}
# ================================================================================================
# PRE-CHUNKER
# ================================================================================================
from __future__ import annotations
from typing import Iterable
from unstructured.documents.elements import Element
from unstructured.utils import lazyproperty
class PreChunker:
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
self._elements = elements
self._opts = opts
@lazyproperty
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
"""The semantic-boundary detectors to be applied to break pre-chunks."""
return self._opts.boundary_predicates
def _is_in_new_semantic_unit(self, element: Element) -> bool:
"""True when `element` begins a new semantic unit such as a section or page."""
# Use generator expression for lower memory usage and avoid building intermediate list
for pred in self._boundary_predicates:
if pred(element):
return True
return False
```
'''
# OptimFunctionCollector async function tests
def test_optim_function_collector_with_async_functions():
"""Test OptimFunctionCollector correctly collects async functions."""
import libcst as cst
source_code = """
def sync_function():
return "sync"
func = FunctionToOptimize(function_name="_is_in_new_semantic_unit", parents=[FunctionParent("PreChunker", "ClassDef")], file_path=main_file)
test_config = TestConfig(
tests_root=root_dir / "tests/pytest",
tests_project_rootdir=root_dir,
project_root_path=root_dir,
test_framework="pytest",
pytest_cmd="pytest",
async def async_function():
return "async"
class TestClass:
def sync_method(self):
return "sync_method"
async def async_method(self):
return "async_method"
"""
tree = cst.parse_module(source_code)
collector = OptimFunctionCollector(
function_names={(None, "sync_function"), (None, "async_function"), ("TestClass", "sync_method"), ("TestClass", "async_method")},
preexisting_objects=None
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
tree.visit(collector)
# Should collect both sync and async functions
assert len(collector.modified_functions) == 4
assert (None, "sync_function") in collector.modified_functions
assert (None, "async_function") in collector.modified_functions
assert ("TestClass", "sync_method") in collector.modified_functions
assert ("TestClass", "async_method") in collector.modified_functions
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
for helper_function_path in helper_function_paths:
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code
def test_optim_function_collector_new_async_functions():
"""Test OptimFunctionCollector identifies new async functions not in preexisting objects."""
import libcst as cst
source_code = """
def existing_function():
return "existing"
async def new_async_function():
return "new_async"
def new_sync_function():
return "new_sync"
class ExistingClass:
async def new_class_async_method(self):
return "new_class_async"
"""
# Only existing_function is in preexisting objects
preexisting_objects = {("existing_function", ())}
tree = cst.parse_module(source_code)
collector = OptimFunctionCollector(
function_names=set(), # Not looking for specific functions
preexisting_objects=preexisting_objects
)
tree.visit(collector)
# Should identify new functions (both sync and async)
assert len(collector.new_functions) == 2
function_names = [func.name.value for func in collector.new_functions]
assert "new_async_function" in function_names
assert "new_sync_function" in function_names
# Should identify new class methods
assert "ExistingClass" in collector.new_class_functions
assert len(collector.new_class_functions["ExistingClass"]) == 1
assert collector.new_class_functions["ExistingClass"][0].name.value == "new_class_async_method"
new_code = main_file.read_text(encoding="utf-8")
main_file.unlink(missing_ok=True)
def test_optim_function_collector_mixed_scenarios():
"""Test OptimFunctionCollector with complex mix of sync/async functions and classes."""
import libcst as cst
source_code = """
# Global functions
def global_sync():
pass
expected = '''"""Chunking objects not specific to a particular chunking strategy."""
from __future__ import annotations
import collections
import copy
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
import regex
from typing_extensions import Self, TypeAlias
from unstructured.utils import lazyproperty
from unstructured.documents.elements import Element
# ================================================================================================
# MODEL
# ================================================================================================
CHUNK_MAX_CHARS_DEFAULT: int = 500
# ================================================================================================
# PRE-CHUNKER
# ================================================================================================
class PreChunker:
"""Gathers sequential elements into pre-chunks as length constraints allow.
The pre-chunker's responsibilities are:
- **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
either side of those boundaries into different sections. In this case, the primary indicator
of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
semantic boundary when `multipage_sections` is `False`.
- **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
into sections as big as possible without exceeding the chunk window size.
- **Minimize chunks that must be split mid-text.** Precompute the text length of each section
and only produce a section that exceeds the chunk window size when there is a single element
with text longer than that window.
A Table element is placed into a section by itself. CheckBox elements are dropped.
The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
a new "section", hence the "by-title" designation.
"""
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
self._elements = elements
self._opts = opts
@lazyproperty
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
"""The semantic-boundary detectors to be applied to break pre-chunks."""
return self._opts.boundary_predicates
def _is_in_new_semantic_unit(self, element: Element) -> bool:
"""True when `element` begins a new semantic unit such as a section or page."""
# Use generator expression for lower memory usage and avoid building intermediate list
for pred in self._boundary_predicates:
if pred(element):
return True
return False
async def global_async():
pass
class ParentClass:
def __init__(self):
pass
def sync_method(self):
pass
async def async_method(self):
pass
class ChildClass:
async def child_async_method(self):
pass
def child_sync_method(self):
pass
"""
# Looking for specific functions
function_names = {
(None, "global_sync"),
(None, "global_async"),
("ParentClass", "sync_method"),
("ParentClass", "async_method"),
("ChildClass", "child_async_method")
}
tree = cst.parse_module(source_code)
collector = OptimFunctionCollector(
function_names=function_names,
preexisting_objects=None
)
tree.visit(collector)
# Should collect all specified functions (mix of sync and async)
assert len(collector.modified_functions) == 5
assert (None, "global_sync") in collector.modified_functions
assert (None, "global_async") in collector.modified_functions
assert ("ParentClass", "sync_method") in collector.modified_functions
assert ("ParentClass", "async_method") in collector.modified_functions
assert ("ChildClass", "child_async_method") in collector.modified_functions
# Should collect __init__ method
assert "ParentClass" in collector.modified_init_functions
def test_is_zero_diff_async_sleep():
original_code = '''
import time
async def task():
time.sleep(1)
return "done"
'''
assert new_code == expected
optimized_code = '''
import asyncio
async def task():
await asyncio.sleep(1)
return "done"
'''
assert not is_zero_diff(original_code, optimized_code)
def test_is_zero_diff_with_equivalent_code():
original_code = '''
import asyncio
async def task():
await asyncio.sleep(1)
return "done"
'''
optimized_code = '''
import asyncio
async def task():
"""A task that does something."""
await asyncio.sleep(1)
return "done"
'''
assert is_zero_diff(original_code, optimized_code)

View file

@ -17,10 +17,10 @@ from codeflash.code_utils.code_utils import (
is_class_defined_in_file,
module_name_from_file_path,
path_belongs_to_site_packages,
has_any_async_functions,
validate_python_code,
)
from codeflash.code_utils.concolic_utils import clean_concolic_tests
from codeflash.code_utils.coverage_utils import generate_candidates, prepare_coverage_files
from codeflash.code_utils.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files
@pytest.fixture
@ -254,7 +254,7 @@ def test_get_run_tmp_file_reuses_temp_directory() -> None:
def test_path_belongs_to_site_packages_with_site_package_path(monkeypatch: pytest.MonkeyPatch) -> None:
site_packages = [Path("/usr/local/lib/python3.9/site-packages")]
site_packages = [Path("/usr/local/lib/python3.9/site-packages").resolve()]
monkeypatch.setattr(site, "getsitepackages", lambda: site_packages)
file_path = Path("/usr/local/lib/python3.9/site-packages/some_package")
@ -277,6 +277,66 @@ def test_path_belongs_to_site_packages_with_relative_path(monkeypatch: pytest.Mo
assert path_belongs_to_site_packages(file_path) is False
def test_path_belongs_to_site_packages_with_symlinked_site_packages(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
real_site_packages = tmp_path / "real_site_packages"
real_site_packages.mkdir()
symlinked_site_packages = tmp_path / "symlinked_site_packages"
symlinked_site_packages.symlink_to(real_site_packages)
package_file = real_site_packages / "some_package" / "__init__.py"
package_file.parent.mkdir()
package_file.write_text("# package file")
monkeypatch.setattr(site, "getsitepackages", lambda: [str(symlinked_site_packages)])
assert path_belongs_to_site_packages(package_file) is True
symlinked_package_file = symlinked_site_packages / "some_package" / "__init__.py"
assert path_belongs_to_site_packages(symlinked_package_file) is True
def test_path_belongs_to_site_packages_with_complex_symlinks(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
real_site_packages = tmp_path / "real" / "lib" / "python3.9" / "site-packages"
real_site_packages.mkdir(parents=True)
link1 = tmp_path / "link1"
link1.symlink_to(real_site_packages.parent.parent.parent)
link2 = tmp_path / "link2"
link2.symlink_to(link1)
package_file = real_site_packages / "test_package" / "module.py"
package_file.parent.mkdir()
package_file.write_text("# test module")
site_packages_via_links = link2 / "lib" / "python3.9" / "site-packages"
monkeypatch.setattr(site, "getsitepackages", lambda: [str(site_packages_via_links)])
assert path_belongs_to_site_packages(package_file) is True
file_via_links = site_packages_via_links / "test_package" / "module.py"
assert path_belongs_to_site_packages(file_via_links) is True
def test_path_belongs_to_site_packages_resolved_paths_normalization(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
site_packages_dir = tmp_path / "lib" / "python3.9" / "site-packages"
site_packages_dir.mkdir(parents=True)
package_dir = site_packages_dir / "mypackage"
package_dir.mkdir()
package_file = package_dir / "module.py"
package_file.write_text("# module")
complex_site_packages_path = tmp_path / "lib" / "python3.9" / "other" / ".." / "site-packages" / "."
monkeypatch.setattr(site, "getsitepackages", lambda: [str(complex_site_packages_path)])
assert path_belongs_to_site_packages(package_file) is True
complex_file_path = tmp_path / "lib" / "python3.9" / "site-packages" / "other" / ".." / "mypackage" / "module.py"
assert path_belongs_to_site_packages(complex_file_path) is True
# tests for is_class_defined_in_file
def test_is_class_defined_in_file_with_existing_class(tmp_path: Path) -> None:
test_file = tmp_path / "test_file.py"
@ -308,6 +368,86 @@ def my_function():
assert is_class_defined_in_file("MyClass", test_file) is False
@pytest.fixture
def mock_code_context():
"""Mock CodeOptimizationContext for testing extract_dependent_function."""
from unittest.mock import MagicMock
from codeflash.models.models import CodeOptimizationContext
context = MagicMock(spec=CodeOptimizationContext)
context.preexisting_objects = []
return context
def test_extract_dependent_function_sync_and_async(mock_code_context):
"""Test extract_dependent_function with both sync and async functions."""
# Test sync function extraction
mock_code_context.testgen_context_code = """
def main_function():
pass
def helper_function():
pass
"""
assert extract_dependent_function("main_function", mock_code_context) == "helper_function"
# Test async function extraction
mock_code_context.testgen_context_code = """
def main_function():
pass
async def async_helper_function():
pass
"""
assert extract_dependent_function("main_function", mock_code_context) == "async_helper_function"
def test_extract_dependent_function_edge_cases(mock_code_context):
"""Test extract_dependent_function edge cases."""
# No dependent functions
mock_code_context.testgen_context_code = """
def main_function():
pass
"""
assert extract_dependent_function("main_function", mock_code_context) is False
# Multiple dependent functions
mock_code_context.testgen_context_code = """
def main_function():
pass
def helper1():
pass
async def helper2():
pass
"""
assert extract_dependent_function("main_function", mock_code_context) is False
def test_extract_dependent_function_mixed_scenarios(mock_code_context):
"""Test extract_dependent_function with mixed sync/async scenarios."""
# Async main with sync helper
mock_code_context.testgen_context_code = """
async def async_main():
pass
def sync_helper():
pass
"""
assert extract_dependent_function("async_main", mock_code_context) == "sync_helper"
# Only async functions
mock_code_context.testgen_context_code = """
async def async_main():
pass
async def async_helper():
pass
"""
assert extract_dependent_function("async_main", mock_code_context) == "async_helper"
def test_is_class_defined_in_file_with_non_existing_file() -> None:
non_existing_file = Path("/non/existing/file.py")
@ -445,25 +585,41 @@ def test_Grammar_copy():
assert cleaned_code == expected_cleaned_code.strip()
def test_has_any_async_functions_with_async_code() -> None:
def test_validate_python_code_valid() -> None:
code = "def hello():\n return 'world'"
result = validate_python_code(code)
assert result == code
def test_validate_python_code_invalid() -> None:
code = "def hello(:\n return 'world'"
with pytest.raises(ValueError, match="Invalid Python code"):
validate_python_code(code)
def test_validate_python_code_empty() -> None:
code = ""
result = validate_python_code(code)
assert result == code
def test_validate_python_code_complex_invalid() -> None:
code = "if True\n print('missing colon')"
with pytest.raises(ValueError, match="Invalid Python code.*line 1.*column 8"):
validate_python_code(code)
def test_validate_python_code_valid_complex() -> None:
code = """
def normal_function():
pass
async def async_function():
pass
def calculate(a, b):
if a > b:
return a + b
else:
return a * b
class MyClass:
def __init__(self):
self.value = 42
"""
result = has_any_async_functions(code)
assert result is True
def test_has_any_async_functions_without_async_code() -> None:
code = """
def normal_function():
pass
def another_function():
pass
"""
result = has_any_async_functions(code)
assert result is False
result = validate_python_code(code)
assert result == code

View file

@ -42,7 +42,7 @@ from codeflash.verification.codeflash_capture import get_test_info_from_stack
class MyClass:
def __init__(self):
self.x = 2
print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir!s}')}}|TEST_INFO_END")
print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir.as_posix()}')}}|TEST_INFO_END")
"""
test_file_name = "test_stack_info_temp.py"
@ -54,7 +54,7 @@ class MyClass:
with sample_code_path.open("w") as f:
f.write(sample_code)
result = execute_test_subprocess(
cwd=test_dir, env={}, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"]
cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
)
assert not result.stderr
assert result.returncode == 0
@ -117,7 +117,7 @@ from codeflash.verification.codeflash_capture import get_test_info_from_stack
class MyClass:
def __init__(self):
self.x = 2
print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir!s}')}}|TEST_INFO_END")
print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir.as_posix()}')}}|TEST_INFO_END")
"""
test_file_name = "test_stack_info_temp.py"
@ -129,7 +129,7 @@ class MyClass:
with sample_code_path.open("w") as f:
f.write(sample_code)
result = execute_test_subprocess(
cwd=test_dir, env={}, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"]
cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
)
assert not result.stderr
assert result.returncode == 0
@ -181,7 +181,7 @@ from codeflash.verification.codeflash_capture import get_test_info_from_stack
class MyClass:
def __init__(self):
self.x = 2
print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir!s}')}}|TEST_INFO_END")
print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir.as_posix()}')}}|TEST_INFO_END")
"""
test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
test_file_name = "test_stack_info_temp.py"
@ -194,7 +194,7 @@ class MyClass:
with sample_code_path.open("w") as f:
f.write(sample_code)
result = execute_test_subprocess(
cwd=test_dir, env={}, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"]
cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
)
assert not result.stderr
assert result.returncode == 0
@ -261,7 +261,7 @@ class MyClass:
def __init__(self):
self.x = 2
# Print out the detected test info each time we instantiate MyClass
print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir!s}')}}|TEST_INFO_END")
print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir.as_posix()}')}}|TEST_INFO_END")
"""
test_file_name = "test_stack_info_recursive_temp.py"
@ -279,7 +279,7 @@ class MyClass:
# Run pytest as a subprocess
result = execute_test_subprocess(
cwd=test_dir, env={}, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"]
cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
)
# Check for errors
@ -343,7 +343,7 @@ from codeflash.verification.codeflash_capture import get_test_info_from_stack
class MyClass:
def __init__(self):
self.x = 2
print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir!s}')}}|TEST_INFO_END")
print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir.as_posix()}')}}|TEST_INFO_END")
"""
test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
test_file_name = "test_stack_info_temp.py"
@ -356,7 +356,7 @@ class MyClass:
with sample_code_path.open("w") as f:
f.write(sample_code)
result = execute_test_subprocess(
cwd=test_dir, env={}, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"]
cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
)
assert not result.stderr
assert result.returncode == 0
@ -410,10 +410,11 @@ class TestUnittestExample(unittest.TestCase):
self.assertTrue(True)
"""
test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
tmp_dir_path = get_run_tmp_file(Path("test_return_values"))
sample_code = f"""
from codeflash.verification.codeflash_capture import codeflash_capture
class MyClass:
@codeflash_capture(function_name="some_function", tmp_dir_path="{get_run_tmp_file(Path("test_return_values"))}", tests_root="{test_dir!s}")
@codeflash_capture(function_name="some_function", tmp_dir_path="{tmp_dir_path.as_posix()}", tests_root="{test_dir.as_posix()}")
def __init__(self, x=2):
self.x = x
"""
@ -528,6 +529,7 @@ class TestUnittestExample(unittest.TestCase):
self.assertTrue(True)
"""
test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
tmp_dir_path = get_run_tmp_file(Path("test_return_values"))
# MyClass did not have an init function, we created the init function with the codeflash_capture decorator using instrumentation
sample_code = f"""
from codeflash.verification.codeflash_capture import codeflash_capture
@ -536,7 +538,7 @@ class ParentClass:
self.x = 2
class MyClass(ParentClass):
@codeflash_capture(function_name="some_function", tmp_dir_path="{get_run_tmp_file(Path("test_return_values"))}", tests_root="{test_dir!s}")
@codeflash_capture(function_name="some_function", tmp_dir_path="{tmp_dir_path.as_posix()}", tests_root="{test_dir.as_posix()}")
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
"""
@ -648,14 +650,15 @@ def test_example_test():
"""
test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
tmp_dir_path = get_run_tmp_file(Path("test_return_values"))
sample_code = f"""
from codeflash.verification.codeflash_capture import codeflash_capture
class MyClass:
@codeflash_capture(
function_name="some_function",
tmp_dir_path="{get_run_tmp_file(Path("test_return_values"))}",
tests_root="{test_dir!s}"
tmp_dir_path="{tmp_dir_path.as_posix()}",
tests_root="{test_dir.as_posix()}"
)
def __init__(self, x=2):
self.x = x
@ -765,13 +768,14 @@ def test_helper_classes():
assert MyClass().target_function() == 6
"""
test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
tmp_dir_path = get_run_tmp_file(Path("test_return_values"))
original_code = f"""
from codeflash.verification.codeflash_capture import codeflash_capture
from code_to_optimize.tests.pytest.helper_file_1 import HelperClass1
from code_to_optimize.tests.pytest.helper_file_2 import HelperClass2, AnotherHelperClass
class MyClass:
@codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))}', tests_root="{test_dir!s}" , is_fto=True)
@codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{tmp_dir_path.as_posix()}', tests_root="{test_dir.as_posix()}" , is_fto=True)
def __init__(self):
self.x = 1
@ -785,7 +789,7 @@ class MyClass:
from codeflash.verification.codeflash_capture import codeflash_capture
class HelperClass1:
@codeflash_capture(function_name='HelperClass1.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))}', tests_root="{test_dir!s}", is_fto=False)
@codeflash_capture(function_name='HelperClass1.__init__', tmp_dir_path='{tmp_dir_path.as_posix()}', tests_root="{test_dir.as_posix()}", is_fto=False)
def __init__(self):
self.y = 1
@ -797,7 +801,7 @@ class HelperClass1:
from codeflash.verification.codeflash_capture import codeflash_capture
class HelperClass2:
@codeflash_capture(function_name='HelperClass2.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))}', tests_root="{test_dir!s}", is_fto=False)
@codeflash_capture(function_name='HelperClass2.__init__', tmp_dir_path='{tmp_dir_path.as_posix()}', tests_root="{test_dir.as_posix()}", is_fto=False)
def __init__(self):
self.z = 2
@ -805,7 +809,7 @@ class HelperClass2:
return 2
class AnotherHelperClass:
@codeflash_capture(function_name='AnotherHelperClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))}', tests_root="{test_dir!s}", is_fto=False)
@codeflash_capture(function_name='AnotherHelperClass.__init__', tmp_dir_path='{tmp_dir_path.as_posix()}', tests_root="{test_dir.as_posix()}", is_fto=False)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

View file

@ -787,6 +787,126 @@ def test_jax():
assert not comparator(aa, cc)
def test_xarray():
try:
import xarray as xr
import numpy as np
except ImportError:
pytest.skip()
# Test basic DataArray
a = xr.DataArray([1, 2, 3], dims=['x'])
b = xr.DataArray([1, 2, 3], dims=['x'])
c = xr.DataArray([1, 2, 4], dims=['x'])
assert comparator(a, b)
assert not comparator(a, c)
# Test DataArray with coordinates
d = xr.DataArray([1, 2, 3], coords={'x': [0, 1, 2]}, dims=['x'])
e = xr.DataArray([1, 2, 3], coords={'x': [0, 1, 2]}, dims=['x'])
f = xr.DataArray([1, 2, 3], coords={'x': [0, 1, 3]}, dims=['x'])
assert comparator(d, e)
assert not comparator(d, f)
# Test DataArray with attributes
g = xr.DataArray([1, 2, 3], dims=['x'], attrs={'units': 'meters'})
h = xr.DataArray([1, 2, 3], dims=['x'], attrs={'units': 'meters'})
i = xr.DataArray([1, 2, 3], dims=['x'], attrs={'units': 'feet'})
assert comparator(g, h)
assert not comparator(g, i)
# Test 2D DataArray
j = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=['x', 'y'])
k = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=['x', 'y'])
l = xr.DataArray([[1, 2, 3], [4, 5, 7]], dims=['x', 'y'])
assert comparator(j, k)
assert not comparator(j, l)
# Test DataArray with different dimensions
m = xr.DataArray([1, 2, 3], dims=['x'])
n = xr.DataArray([1, 2, 3], dims=['y'])
assert not comparator(m, n)
# Test DataArray with NaN values
o = xr.DataArray([1.0, np.nan, 3.0], dims=['x'])
p = xr.DataArray([1.0, np.nan, 3.0], dims=['x'])
q = xr.DataArray([1.0, 2.0, 3.0], dims=['x'])
assert comparator(o, p)
assert not comparator(o, q)
# Test Dataset
r = xr.Dataset({
'temp': (['x', 'y'], [[1, 2], [3, 4]]),
'pressure': (['x', 'y'], [[5, 6], [7, 8]])
})
s = xr.Dataset({
'temp': (['x', 'y'], [[1, 2], [3, 4]]),
'pressure': (['x', 'y'], [[5, 6], [7, 8]])
})
t = xr.Dataset({
'temp': (['x', 'y'], [[1, 2], [3, 4]]),
'pressure': (['x', 'y'], [[5, 6], [7, 9]])
})
assert comparator(r, s)
assert not comparator(r, t)
# Test Dataset with coordinates
u = xr.Dataset({
'temp': (['x', 'y'], [[1, 2], [3, 4]])
}, coords={'x': [0, 1], 'y': [0, 1]})
v = xr.Dataset({
'temp': (['x', 'y'], [[1, 2], [3, 4]])
}, coords={'x': [0, 1], 'y': [0, 1]})
w = xr.Dataset({
'temp': (['x', 'y'], [[1, 2], [3, 4]])
}, coords={'x': [0, 2], 'y': [0, 1]})
assert comparator(u, v)
assert not comparator(u, w)
# Test Dataset with attributes
x = xr.Dataset({'temp': (['x'], [1, 2, 3])}, attrs={'source': 'sensor'})
y = xr.Dataset({'temp': (['x'], [1, 2, 3])}, attrs={'source': 'sensor'})
z = xr.Dataset({'temp': (['x'], [1, 2, 3])}, attrs={'source': 'model'})
assert comparator(x, y)
assert not comparator(x, z)
# Test Dataset with different variables
aa = xr.Dataset({'temp': (['x'], [1, 2, 3])})
bb = xr.Dataset({'temp': (['x'], [1, 2, 3])})
cc = xr.Dataset({'pressure': (['x'], [1, 2, 3])})
assert comparator(aa, bb)
assert not comparator(aa, cc)
# Test empty Dataset
dd = xr.Dataset()
ee = xr.Dataset()
assert comparator(dd, ee)
# Test DataArray with different shapes
ff = xr.DataArray([1, 2, 3], dims=['x'])
gg = xr.DataArray([[1, 2, 3]], dims=['x', 'y'])
assert not comparator(ff, gg)
# Test DataArray with different data types
# Note: xarray.identical() considers int and float arrays with same values as identical
hh = xr.DataArray(np.array([1, 2, 3], dtype='int32'), dims=['x'])
ii = xr.DataArray(np.array([1, 2, 3], dtype='int64'), dims=['x'])
# xarray is permissive with dtype comparisons, treats these as identical
assert comparator(hh, ii)
# Test DataArray with infinity
jj = xr.DataArray([1.0, np.inf, 3.0], dims=['x'])
kk = xr.DataArray([1.0, np.inf, 3.0], dims=['x'])
ll = xr.DataArray([1.0, -np.inf, 3.0], dims=['x'])
assert comparator(jj, kk)
assert not comparator(jj, ll)
# Test Dataset vs DataArray (different types)
mm = xr.DataArray([1, 2, 3], dims=['x'])
nn = xr.Dataset({'data': (['x'], [1, 2, 3])})
assert not comparator(mm, nn)
def test_returns():
a = Success(5)
b = Success(5)
@ -1502,4 +1622,147 @@ def test_collections() -> None:
d = "hello"
assert comparator(a, b)
assert not comparator(a, c)
assert not comparator(a, d)
assert not comparator(a, d)
def test_attrs():
try:
import attrs # type: ignore
except ImportError:
pytest.skip()
@attrs.define
class Person:
name: str
age: int = 10
a = Person("Alice", 25)
b = Person("Alice", 25)
c = Person("Bob", 25)
d = Person("Alice", 30)
assert comparator(a, b)
assert not comparator(a, c)
assert not comparator(a, d)
@attrs.frozen
class Point:
x: int
y: int
p1 = Point(1, 2)
p2 = Point(1, 2)
p3 = Point(2, 3)
assert comparator(p1, p2)
assert not comparator(p1, p3)
@attrs.define(slots=True)
class Vehicle:
brand: str
model: str
year: int = 2020
v1 = Vehicle("Toyota", "Camry", 2021)
v2 = Vehicle("Toyota", "Camry", 2021)
v3 = Vehicle("Honda", "Civic", 2021)
assert comparator(v1, v2)
assert not comparator(v1, v3)
@attrs.define
class ComplexClass:
public_field: str
private_field: str = attrs.field(repr=False)
non_eq_field: int = attrs.field(eq=False, default=0)
computed: str = attrs.field(init=False, eq=True)
def __attrs_post_init__(self):
self.computed = f"{self.public_field}_{self.private_field}"
c1 = ComplexClass("test", "secret")
c2 = ComplexClass("test", "secret")
c3 = ComplexClass("different", "secret")
c1.non_eq_field = 100
c2.non_eq_field = 200
assert comparator(c1, c2)
assert not comparator(c1, c3)
@attrs.define
class Address:
street: str
city: str
@attrs.define
class PersonWithAddress:
name: str
address: Address
addr1 = Address("123 Main St", "Anytown")
addr2 = Address("123 Main St", "Anytown")
addr3 = Address("456 Oak Ave", "Anytown")
person1 = PersonWithAddress("John", addr1)
person2 = PersonWithAddress("John", addr2)
person3 = PersonWithAddress("John", addr3)
assert comparator(person1, person2)
assert not comparator(person1, person3)
@attrs.define
class Container:
items: list
metadata: dict
cont1 = Container([1, 2, 3], {"type": "numbers"})
cont2 = Container([1, 2, 3], {"type": "numbers"})
cont3 = Container([1, 2, 4], {"type": "numbers"})
assert comparator(cont1, cont2)
assert not comparator(cont1, cont3)
@attrs.define
class BaseClass:
name: str
value: int
@attrs.define
class ExtendedClass:
name: str
value: int
extra_field: str = "default"
base = BaseClass("test", 42)
extended = ExtendedClass("test", 42, "extra")
assert not comparator(base, extended)
@attrs.define
class WithNonEqFields:
name: str
timestamp: float = attrs.field(eq=False) # Should be ignored
debug_info: str = attrs.field(eq=False, default="debug")
obj1 = WithNonEqFields("test", 1000.0, "info1")
obj2 = WithNonEqFields("test", 9999.0, "info2") # Different non-eq fields
obj3 = WithNonEqFields("different", 1000.0, "info1")
assert comparator(obj1, obj2) # Should be equal despite different timestamp/debug_info
assert not comparator(obj1, obj3) # Should be different due to name
@attrs.define
class MinimalClass:
name: str
value: int
@attrs.define
class ExtendedClass:
name: str
value: int
extra_field: str = "default"
metadata: dict = attrs.field(factory=dict)
timestamp: float = attrs.field(eq=False, default=0.0) # This should be ignored
minimal = MinimalClass("test", 42)
extended = ExtendedClass("test", 42, "extra", {"key": "value"}, 1000.0)
assert not comparator(minimal, extended)

View file

@ -14,7 +14,13 @@ from codeflash.models.models import (
TestResults,
TestType,
)
from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
from codeflash.result.critic import (
coverage_critic,
performance_gain,
quantity_of_tests_critic,
speedup_critic,
throughput_gain,
)
def test_performance_gain() -> None:
@ -429,3 +435,159 @@ def test_coverage_critic() -> None:
)
assert coverage_critic(unittest_coverage, "unittest") is True
def test_throughput_gain() -> None:
"""Test throughput_gain calculation."""
# Test basic throughput improvement
assert throughput_gain(original_throughput=100, optimized_throughput=150) == 0.5 # 50% improvement
# Test no improvement
assert throughput_gain(original_throughput=100, optimized_throughput=100) == 0.0
# Test regression
assert throughput_gain(original_throughput=100, optimized_throughput=80) == -0.2 # 20% regression
# Test zero original throughput (edge case)
assert throughput_gain(original_throughput=0, optimized_throughput=50) == 0.0
# Test large improvement
assert throughput_gain(original_throughput=50, optimized_throughput=200) == 3.0 # 300% improvement
def test_speedup_critic_with_async_throughput() -> None:
"""Test speedup_critic with async throughput evaluation."""
original_code_runtime = 10000 # 10 microseconds
original_async_throughput = 100
# Test case 1: Both runtime and throughput improve significantly
candidate_result = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=8000, # 20% runtime improvement
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=8000,
async_throughput=120, # 20% throughput improvement
)
assert speedup_critic(
candidate_result=candidate_result,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
disable_gh_action_noise=True
)
# Test case 2: Runtime improves significantly, throughput doesn't meet threshold (should pass)
candidate_result = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=8000, # 20% runtime improvement
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=8000,
async_throughput=105, # Only 5% throughput improvement (below 10% threshold)
)
assert speedup_critic(
candidate_result=candidate_result,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
disable_gh_action_noise=True
)
# Test case 3: Throughput improves significantly, runtime doesn't meet threshold (should pass)
candidate_result = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=9800, # Only 2% runtime improvement (below 5% threshold)
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=9800,
async_throughput=120, # 20% throughput improvement
)
assert speedup_critic(
candidate_result=candidate_result,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
disable_gh_action_noise=True
)
# Test case 4: No throughput data - should fall back to runtime-only evaluation
candidate_result = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=8000, # 20% runtime improvement
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=8000,
async_throughput=None, # No throughput data
)
assert speedup_critic(
candidate_result=candidate_result,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=None, # No original throughput data
best_throughput_until_now=None,
disable_gh_action_noise=True
)
# Test case 5: Test best_throughput_until_now comparison
candidate_result = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=8000, # 20% runtime improvement
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=8000,
async_throughput=115, # 15% throughput improvement
)
# Should pass when no best throughput yet
assert speedup_critic(
candidate_result=candidate_result,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
disable_gh_action_noise=True
)
# Should fail when there's a better throughput already
assert not speedup_critic(
candidate_result=candidate_result,
original_code_runtime=original_code_runtime,
best_runtime_until_now=7000, # Better runtime already exists
original_async_throughput=original_async_throughput,
best_throughput_until_now=120, # Better throughput already exists
disable_gh_action_noise=True
)
# Test case 6: Zero original throughput (edge case)
candidate_result = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=8000, # 20% runtime improvement
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=8000,
async_throughput=50,
)
# Should pass when original throughput is 0 (throughput evaluation skipped)
assert speedup_critic(
candidate_result=candidate_result,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=0, # Zero original throughput
best_throughput_until_now=None,
disable_gh_action_noise=True
)

View file

@ -14,6 +14,11 @@ from codeflash.models.models import CodeString, CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield Path(tmpdirname)
def test_remove_duplicate_imports():
"""Test that duplicate imports are removed when should_sort_imports is True."""
original_code = "import os\nimport os\n"
@ -37,17 +42,15 @@ def test_sorting_imports():
assert new_code == "import os\nimport sys\nimport unittest\n"
def test_sort_imports_without_formatting():
def test_sort_imports_without_formatting(temp_dir):
"""Test that imports are sorted when formatting is disabled and should_sort_imports is True."""
with tempfile.NamedTemporaryFile() as tmp:
tmp.write(b"import sys\nimport unittest\nimport os\n")
tmp.flush()
tmp_path = Path(tmp.name)
temp_file = temp_dir / "test_file.py"
temp_file.write_text("import sys\nimport unittest\nimport os\n")
new_code = format_code(formatter_cmds=["disabled"], path=tmp_path)
assert new_code is not None
new_code = sort_imports(new_code)
assert new_code == "import os\nimport sys\nimport unittest\n"
new_code = format_code(formatter_cmds=["disabled"], path=temp_file)
assert new_code is not None
new_code = sort_imports(new_code)
assert new_code == "import os\nimport sys\nimport unittest\n"
def test_dedup_and_sort_imports_deduplicates():
@ -101,7 +104,7 @@ def foo():
assert actual == expected
def test_formatter_cmds_non_existent():
def test_formatter_cmds_non_existent(temp_dir):
"""Test that default formatter-cmds is used when it doesn't exist in the toml."""
config_data = """
[tool.codeflash]
@ -110,113 +113,99 @@ tests-root = "tests"
test-framework = "pytest"
ignore-paths = []
"""
config_file = temp_dir / "config.toml"
config_file.write_text(config_data)
with tempfile.NamedTemporaryFile(suffix=".toml", delete=False) as tmp:
tmp.write(config_data.encode())
tmp.flush()
tmp_path = Path(tmp.name)
try:
config, _ = parse_config_file(tmp_path)
assert config["formatter_cmds"] == ["black $file"]
finally:
os.remove(tmp_path)
config, _ = parse_config_file(config_file)
assert config["formatter_cmds"] == ["black $file"]
try:
import black
except ImportError:
pytest.skip("black is not installed")
original_code = b"""
import os
import sys
def foo():
return os.path.join(sys.path[0], 'bar')"""
expected = """import os
import sys
def foo():
return os.path.join(sys.path[0], "bar")
"""
with tempfile.NamedTemporaryFile() as tmp:
tmp.write(original_code)
tmp.flush()
tmp_path = tmp.name
actual = format_code(formatter_cmds=["black $file"], path=Path(tmp_path))
assert actual == expected
def test_formatter_black():
try:
import black
except ImportError:
pytest.skip("black is not installed")
original_code = b"""
import os
import sys
def foo():
return os.path.join(sys.path[0], 'bar')"""
expected = """import os
import sys
def foo():
return os.path.join(sys.path[0], "bar")
"""
with tempfile.NamedTemporaryFile() as tmp:
tmp.write(original_code)
tmp.flush()
tmp_path = tmp.name
actual = format_code(formatter_cmds=["black $file"], path=Path(tmp_path))
assert actual == expected
def test_formatter_ruff():
try:
import ruff # type: ignore
except ImportError:
pytest.skip("ruff is not installed")
original_code = b"""
import os
import sys
def foo():
return os.path.join(sys.path[0], 'bar')"""
expected = """import os
import sys
def foo():
return os.path.join(sys.path[0], "bar")
"""
with tempfile.NamedTemporaryFile(suffix=".py") as tmp:
tmp.write(original_code)
tmp.flush()
tmp_path = tmp.name
actual = format_code(
formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"], path=Path(tmp_path)
)
assert actual == expected
def test_formatter_error():
original_code = """
import os
import sys
def foo():
return os.path.join(sys.path[0], 'bar')"""
with tempfile.NamedTemporaryFile("w") as tmp:
tmp.write(original_code)
tmp.flush()
tmp_path = tmp.name
try:
new_code = format_code(formatter_cmds=["exit 1"], path=Path(tmp_path), exit_on_failure=False)
assert new_code == original_code
except Exception as e:
assert False, f"Shouldn't throw an exception even if the formatter is not found: {e}"
expected = """import os
import sys
def foo():
return os.path.join(sys.path[0], \"bar\")
"""
temp_file = temp_dir / "test_file.py"
temp_file.write_text(original_code)
actual = format_code(formatter_cmds=["black $file"], path=temp_file)
assert actual == expected
def test_formatter_black(temp_dir):
try:
import black
except ImportError:
pytest.skip("black is not installed")
original_code = """
import os
import sys
def foo():
return os.path.join(sys.path[0], 'bar')"""
expected = """import os
import sys
def foo():
return os.path.join(sys.path[0], \"bar\")
"""
temp_file = temp_dir / "test_file.py"
temp_file.write_text(original_code)
actual = format_code(formatter_cmds=["black $file"], path=temp_file)
assert actual == expected
def test_formatter_ruff(temp_dir):
try:
import ruff # type: ignore
except ImportError:
pytest.skip("ruff is not installed")
original_code = """
import os
import sys
def foo():
return os.path.join(sys.path[0], 'bar')"""
expected = """import os
import sys
def foo():
return os.path.join(sys.path[0], \"bar\")
"""
temp_file = temp_dir / "test_file.py"
temp_file.write_text(original_code)
actual = format_code(
formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"], path=temp_file
)
assert actual == expected
def test_formatter_error(tmp_path: Path):
original_code = """
import os
import sys
def foo():
return os.path.join(sys.path[0], 'bar')"""
temp_file = tmp_path / "test_formatter_error.py"
temp_file.write_text(original_code, encoding="utf-8")
try:
new_code = format_code(formatter_cmds=["exit 1"], path=temp_file, exit_on_failure=False)
assert new_code == original_code
except Exception as e:
assert False, f"Shouldn't throw an exception even if the formatter is not found: {e}"
def _run_formatting_test(source_code: str, should_content_change: bool, expected = None, optimized_function: str = ""):

View file

@ -21,11 +21,15 @@ def test_function_eligible_for_optimization() -> None:
return a**2
"""
functions_found = {}
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
f.write(function)
f.flush()
functions_found = find_all_functions_in_file(Path(f.name))
assert functions_found[Path(f.name)][0].function_name == "test_function_eligible_for_optimization"
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
file_path = temp_dir_path / "test_function.py"
with file_path.open("w") as f:
f.write(function)
functions_found = find_all_functions_in_file(file_path)
assert functions_found[file_path][0].function_name == "test_function_eligible_for_optimization"
# Has no return statement
function = """def test_function_not_eligible_for_optimization():
@ -33,28 +37,40 @@ def test_function_eligible_for_optimization() -> None:
print(a)
"""
functions_found = {}
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
f.write(function)
f.flush()
functions_found = find_all_functions_in_file(Path(f.name))
assert len(functions_found[Path(f.name)]) == 0
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
file_path = temp_dir_path / "test_function.py"
with file_path.open("w") as f:
f.write(function)
functions_found = find_all_functions_in_file(file_path)
assert len(functions_found[file_path]) == 0
# we want to trigger an error in the function discovery
function = """def test_invalid_code():"""
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
f.write(function)
f.flush()
functions_found = find_all_functions_in_file(Path(f.name))
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
file_path = temp_dir_path / "test_function.py"
with file_path.open("w") as f:
f.write(function)
functions_found = find_all_functions_in_file(file_path)
assert functions_found == {}
def test_find_top_level_function_or_method():
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
f.write(
"""def functionA():
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
file_path = temp_dir_path / "test_function.py"
with file_path.open("w") as f:
f.write(
"""def functionA():
def functionB():
return 5
class E:
@ -76,42 +92,48 @@ class AirbyteEntrypoint(object):
def non_classmethod_function(cls, name):
return cls.name
"""
)
f.flush()
path_obj_name = Path(f.name)
assert inspect_top_level_functions_or_methods(path_obj_name, "functionA").is_top_level
assert not inspect_top_level_functions_or_methods(path_obj_name, "functionB").is_top_level
assert inspect_top_level_functions_or_methods(path_obj_name, "functionC", class_name="A").is_top_level
assert not inspect_top_level_functions_or_methods(path_obj_name, "functionD", class_name="A").is_top_level
assert not inspect_top_level_functions_or_methods(path_obj_name, "functionF", class_name="E").is_top_level
assert not inspect_top_level_functions_or_methods(path_obj_name, "functionA").has_args
)
assert inspect_top_level_functions_or_methods(file_path, "functionA").is_top_level
assert not inspect_top_level_functions_or_methods(file_path, "functionB").is_top_level
assert inspect_top_level_functions_or_methods(file_path, "functionC", class_name="A").is_top_level
assert not inspect_top_level_functions_or_methods(file_path, "functionD", class_name="A").is_top_level
assert not inspect_top_level_functions_or_methods(file_path, "functionF", class_name="E").is_top_level
assert not inspect_top_level_functions_or_methods(file_path, "functionA").has_args
staticmethod_func = inspect_top_level_functions_or_methods(
path_obj_name, "handle_record_counts", class_name=None, line_no=15
file_path, "handle_record_counts", class_name=None, line_no=15
)
assert staticmethod_func.is_staticmethod
assert staticmethod_func.staticmethod_class_name == "AirbyteEntrypoint"
assert inspect_top_level_functions_or_methods(
path_obj_name, "functionE", class_name="AirbyteEntrypoint"
file_path, "functionE", class_name="AirbyteEntrypoint"
).is_classmethod
assert not inspect_top_level_functions_or_methods(
path_obj_name, "non_classmethod_function", class_name="AirbyteEntrypoint"
file_path, "non_classmethod_function", class_name="AirbyteEntrypoint"
).is_top_level
# needed because this will be traced with a class_name being passed
# we want to write invalid code to ensure that the function discovery does not crash
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
f.write(
"""def functionA():
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
file_path = temp_dir_path / "test_function.py"
with file_path.open("w") as f:
f.write(
"""def functionA():
"""
)
f.flush()
path_obj_name = Path(f.name)
assert not inspect_top_level_functions_or_methods(path_obj_name, "functionA")
)
assert not inspect_top_level_functions_or_methods(file_path, "functionA")
def test_class_method_discovery():
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
f.write(
"""class A:
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
file_path = temp_dir_path / "test_function.py"
with file_path.open("w") as f:
f.write(
"""class A:
def functionA():
return True
def functionB():
@ -123,21 +145,20 @@ class X:
return False
def functionA():
return True"""
)
f.flush()
)
test_config = TestConfig(
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
)
path_obj_name = Path(f.name)
functions, functions_count, _ = get_functions_to_optimize(
optimize_all=None,
replay_test=None,
file=path_obj_name,
file=file_path,
only_get_this_function="A.functionA",
test_cfg=test_config,
ignore_paths=[Path("/bruh/")],
project_root=path_obj_name.parent,
module_root=path_obj_name.parent,
project_root=file_path.parent,
module_root=file_path.parent,
)
assert len(functions) == 1
for file in functions:
@ -148,12 +169,12 @@ def functionA():
functions, functions_count, _ = get_functions_to_optimize(
optimize_all=None,
replay_test=None,
file=path_obj_name,
file=file_path,
only_get_this_function="X.functionA",
test_cfg=test_config,
ignore_paths=[Path("/bruh/")],
project_root=path_obj_name.parent,
module_root=path_obj_name.parent,
project_root=file_path.parent,
module_root=file_path.parent,
)
assert len(functions) == 1
for file in functions:
@ -164,12 +185,12 @@ def functionA():
functions, functions_count, _ = get_functions_to_optimize(
optimize_all=None,
replay_test=None,
file=path_obj_name,
file=file_path,
only_get_this_function="functionA",
test_cfg=test_config,
ignore_paths=[Path("/bruh/")],
project_root=path_obj_name.parent,
module_root=path_obj_name.parent,
project_root=file_path.parent,
module_root=file_path.parent,
)
assert len(functions) == 1
for file in functions:
@ -178,8 +199,12 @@ def functionA():
def test_nested_function():
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
f.write(
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
file_path = temp_dir_path / "test_function.py"
with file_path.open("w") as f:
f.write(
"""
import copy
@ -223,28 +248,31 @@ def propagate_attributes(
traverse(source_node_id)
return modified_nodes
"""
)
f.flush()
)
test_config = TestConfig(
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
)
path_obj_name = Path(f.name)
functions, functions_count, _ = get_functions_to_optimize(
optimize_all=None,
replay_test=None,
file=path_obj_name,
file=file_path,
test_cfg=test_config,
only_get_this_function=None,
ignore_paths=[Path("/bruh/")],
project_root=path_obj_name.parent,
module_root=path_obj_name.parent,
project_root=file_path.parent,
module_root=file_path.parent,
)
assert len(functions) == 1
assert functions_count == 1
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
f.write(
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
file_path = temp_dir_path / "test_function.py"
with file_path.open("w") as f:
f.write(
"""
def outer_function():
def inner_function():
@ -252,28 +280,31 @@ def outer_function():
return inner_function
"""
)
f.flush()
)
test_config = TestConfig(
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
)
path_obj_name = Path(f.name)
functions, functions_count, _ = get_functions_to_optimize(
optimize_all=None,
replay_test=None,
file=path_obj_name,
file=file_path,
test_cfg=test_config,
only_get_this_function=None,
ignore_paths=[Path("/bruh/")],
project_root=path_obj_name.parent,
module_root=path_obj_name.parent,
project_root=file_path.parent,
module_root=file_path.parent,
)
assert len(functions) == 1
assert functions_count == 1
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
f.write(
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
file_path = temp_dir_path / "test_function.py"
with file_path.open("w") as f:
f.write(
"""
def outer_function():
def inner_function():
@ -283,21 +314,20 @@ def outer_function():
pass
return inner_function, another_inner_function
"""
)
f.flush()
)
test_config = TestConfig(
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
)
path_obj_name = Path(f.name)
functions, functions_count, _ = get_functions_to_optimize(
optimize_all=None,
replay_test=None,
file=path_obj_name,
file=file_path,
test_cfg=test_config,
only_get_this_function=None,
ignore_paths=[Path("/bruh/")],
project_root=path_obj_name.parent,
module_root=path_obj_name.parent,
project_root=file_path.parent,
module_root=file_path.parent,
)
assert len(functions) == 1

View file

@ -3,13 +3,20 @@ import tempfile
from codeflash.code_utils.code_extractor import get_code
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
import pytest
from pathlib import Path
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield Path(tmpdirname)
def test_get_code_function() -> None:
def test_get_code_function(temp_dir: Path) -> None:
code = """def test(self):
return self._test"""
with tempfile.NamedTemporaryFile("w") as f:
with (temp_dir / "temp_file.py").open(mode="w") as f:
f.write(code)
f.flush()
@ -18,14 +25,14 @@ def test_get_code_function() -> None:
assert contextual_dunder_methods == set()
def test_get_code_property() -> None:
def test_get_code_property(temp_dir: Path) -> None:
code = """class TestClass:
def __init__(self):
self._test = 5
@property
def test(self):
return self._test"""
with tempfile.NamedTemporaryFile("w") as f:
with (temp_dir / "temp_file.py").open(mode="w") as f:
f.write(code)
f.flush()
@ -36,7 +43,7 @@ def test_get_code_property() -> None:
assert contextual_dunder_methods == {("TestClass", "__init__")}
def test_get_code_class() -> None:
def test_get_code_class(temp_dir: Path) -> None:
code = """
class TestClass:
def __init__(self):
@ -54,7 +61,7 @@ class TestClass:
@property
def test(self):
return self._test"""
with tempfile.NamedTemporaryFile("w") as f:
with (temp_dir / "temp_file.py").open(mode="w") as f:
f.write(code)
f.flush()
@ -65,7 +72,7 @@ class TestClass:
assert contextual_dunder_methods == {("TestClass", "__init__")}
def test_get_code_bubble_sort_class() -> None:
def test_get_code_bubble_sort_class(temp_dir: Path) -> None:
code = """
def hi():
pass
@ -105,7 +112,7 @@ class BubbleSortClass:
arr[j + 1] = temp
return arr
"""
with tempfile.NamedTemporaryFile("w") as f:
with (temp_dir / "temp_file.py").open(mode="w") as f:
f.write(code)
f.flush()
@ -116,7 +123,7 @@ class BubbleSortClass:
assert contextual_dunder_methods == {("BubbleSortClass", "__init__"), ("BubbleSortClass", "__call__")}
def test_get_code_indent() -> None:
def test_get_code_indent(temp_dir: Path) -> None:
code = """def hi():
pass
@ -168,7 +175,7 @@ def non():
def helper(self, arr, j):
return arr[j] > arr[j + 1]
"""
with tempfile.NamedTemporaryFile("w") as f:
with (temp_dir / "temp_file.py").open(mode="w") as f:
f.write(code)
f.flush()
new_code, contextual_dunder_methods = get_code(
@ -198,7 +205,7 @@ def non():
def unsorter(self, arr):
return shuffle(arr)
"""
with tempfile.NamedTemporaryFile("w") as f:
with (temp_dir / "temp_file.py").open(mode="w") as f:
f.write(code)
f.flush()
new_code, contextual_dunder_methods = get_code(
@ -212,7 +219,7 @@ def non():
assert contextual_dunder_methods == {("BubbleSortClass", "__init__"), ("BubbleSortClass", "__call__")}
def test_get_code_multiline_class_def() -> None:
def test_get_code_multiline_class_def(temp_dir: Path) -> None:
code = """class StatementAssignmentVariableConstantMutable(
StatementAssignmentVariableMixin, StatementAssignmentVariableConstantMutableBase
):
@ -235,7 +242,7 @@ def test_get_code_multiline_class_def() -> None:
def computeStatement(self, trace_collection):
return self, None, None
"""
with tempfile.NamedTemporaryFile("w") as f:
with (temp_dir / "temp_file.py").open(mode="w") as f:
f.write(code)
f.flush()
@ -252,13 +259,13 @@ def test_get_code_multiline_class_def() -> None:
assert contextual_dunder_methods == set()
def test_get_code_dataclass_attribute():
def test_get_code_dataclass_attribute(temp_dir: Path) -> None:
code = """@dataclass
class CustomDataClass:
name: str = ""
data: List[int] = field(default_factory=list)"""
with tempfile.NamedTemporaryFile("w") as f:
with (temp_dir / "temp_file.py").open(mode="w") as f:
f.write(code)
f.flush()
@ -269,4 +276,4 @@ class CustomDataClass:
[FunctionToOptimize("name", f.name, [FunctionParent("CustomDataClass", "ClassDef")])]
)
assert new_code is None
assert contextual_dunder_methods == set()
assert contextual_dunder_methods == set()

View file

@ -213,11 +213,12 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
lifespan=self.__duration__,
)
'''
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(code)
f.flush()
file_path = Path(f.name).resolve()
project_root_path = file_path.parent.resolve()
with tempfile.TemporaryDirectory() as tempdir:
tempdir_path = Path(tempdir)
file_path = (tempdir_path / "typed_code_helper.py").resolve()
file_path.write_text(code, encoding="utf-8")
project_root_path = tempdir_path.resolve()
project_root_path = tempdir_path.resolve()
function_to_optimize = FunctionToOptimize(
function_name="__call__",
file_path=file_path,
@ -440,4 +441,4 @@ def sorter_deps(arr):
code_context.helper_functions[0].fully_qualified_name
== "code_to_optimize.bubble_sort_dep1_helper.dep1_comparer"
)
assert code_context.helper_functions[1].fully_qualified_name == "code_to_optimize.bubble_sort_dep2_swap.dep2_swap"
assert code_context.helper_functions[1].fully_qualified_name == "code_to_optimize.bubble_sort_dep2_swap.dep2_swap"

View file

@ -123,7 +123,7 @@ def test_sort():
assert new_test is not None
assert new_test.replace('"', "'") == expected.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
with test_path.open("w") as f:
@ -276,16 +276,16 @@ def test_sort():
fto = FunctionToOptimize(
function_name="sorter", parents=[FunctionParent(name="BubbleSorter", type="ClassDef")], file_path=Path(fto_path)
)
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(code)
f.flush()
with tempfile.TemporaryDirectory() as tmpdirname:
tmp_test_path = Path(tmpdirname) / "test_class_method_behavior_results_temp.py"
tmp_test_path.write_text(code, encoding="utf-8")
success, new_test = inject_profiling_into_existing_test(
Path(f.name), [CodePosition(7, 13), CodePosition(12, 13)], fto, Path(f.name).parent, "pytest"
tmp_test_path, [CodePosition(7, 13), CodePosition(12, 13)], fto, tmp_test_path.parent, "pytest"
)
assert success
assert new_test.replace('"', "'") == expected.format(
module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
module_path=tmp_test_path.stem, tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
).replace('"', "'")
tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve()
test_path = tests_root / "test_class_method_behavior_results_temp.py"
@ -295,7 +295,7 @@ def test_sort():
try:
new_test = expected.format(
module_path="code_to_optimize.tests.pytest.test_class_method_behavior_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
)
with test_path.open("w") as f:
@ -486,4 +486,4 @@ class BubbleSorter:
finally:
fto_path.write_text(original_code, "utf-8")
test_path.unlink(missing_ok=True)
test_path_perf.unlink(missing_ok=True)
test_path_perf.unlink(missing_ok=True)

View file

@ -0,0 +1,807 @@
import tempfile
from pathlib import Path
import uuid
import os
import sys
import pytest
from codeflash.code_utils.instrument_existing_tests import (
add_async_decorator_to_function,
inject_profiling_into_existing_test,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, TestingMode
@pytest.fixture
def temp_dir():
"""Create a temporary directory for test files."""
with tempfile.TemporaryDirectory() as temp:
yield Path(temp)
# @pytest.fixture
# def unique_test_iteration():
# """Provide a unique test iteration ID and clean up database after test."""
# # Generate unique iteration ID
# iteration_id = str(uuid.uuid4())[:8]
# # Store original environment variable
# original_iteration = os.environ.get("CODEFLASH_TEST_ITERATION")
# # Set unique iteration for this test
# os.environ["CODEFLASH_TEST_ITERATION"] = iteration_id
# try:
# yield iteration_id
# finally:
# # Cleanup: restore original environment and delete database file
# if original_iteration is not None:
# os.environ["CODEFLASH_TEST_ITERATION"] = original_iteration
# elif "CODEFLASH_TEST_ITERATION" in os.environ:
# del os.environ["CODEFLASH_TEST_ITERATION"]
# # Clean up database file
# try:
# from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file
# db_path = get_run_tmp_file(Path(f"test_return_values_{iteration_id}.sqlite"))
# if db_path.exists():
# db_path.unlink()
# except Exception:
# pass # Ignore cleanup errors
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_decorator_application_behavior_mode():
async_function_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
@codeflash_behavior_async
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
func = FunctionToOptimize(
function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
)
modified_code, decorator_added = add_async_decorator_to_function(async_function_code, func, TestingMode.BEHAVIOR)
assert decorator_added
assert modified_code.strip() == expected_decorated_code.strip()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_decorator_application_performance_mode():
async_function_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_performance_async
@codeflash_performance_async
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
func = FunctionToOptimize(
function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
)
modified_code, decorator_added = add_async_decorator_to_function(async_function_code, func, TestingMode.PERFORMANCE)
assert decorator_added
assert modified_code.strip() == expected_decorated_code.strip()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_class_method_decorator_application():
async_class_code = '''
import asyncio
class Calculator:
"""Test class with async methods."""
async def async_method(self, a: int, b: int) -> int:
"""Async method in class."""
await asyncio.sleep(0.005)
return a ** b
def sync_method(self, a: int, b: int) -> int:
"""Sync method in class."""
return a - b
'''
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
class Calculator:
"""Test class with async methods."""
@codeflash_behavior_async
async def async_method(self, a: int, b: int) -> int:
"""Async method in class."""
await asyncio.sleep(0.005)
return a ** b
def sync_method(self, a: int, b: int) -> int:
"""Sync method in class."""
return a - b
'''
func = FunctionToOptimize(
function_name="async_method",
file_path=Path("test_async.py"),
parents=[{"name": "Calculator", "type": "ClassDef"}],
is_async=True,
)
modified_code, decorator_added = add_async_decorator_to_function(async_class_code, func, TestingMode.BEHAVIOR)
assert decorator_added
assert modified_code.strip() == expected_decorated_code.strip()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_decorator_no_duplicate_application():
already_decorated_code = '''
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async
import asyncio
@codeflash_behavior_async
async def async_function(x: int, y: int) -> int:
"""Already decorated async function."""
await asyncio.sleep(0.01)
return x * y
'''
expected_reformatted_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
@codeflash_behavior_async
async def async_function(x: int, y: int) -> int:
"""Already decorated async function."""
await asyncio.sleep(0.01)
return x * y
'''
func = FunctionToOptimize(
function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
)
modified_code, decorator_added = add_async_decorator_to_function(already_decorated_code, func, TestingMode.BEHAVIOR)
assert not decorator_added
assert modified_code.strip() == expected_reformatted_code.strip()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_inject_profiling_async_function_behavior_mode(temp_dir):
source_module_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
source_file = temp_dir / "my_module.py"
source_file.write_text(source_module_code)
async_test_code = '''
import asyncio
import pytest
from my_module import async_function
@pytest.mark.asyncio
async def test_async_function():
"""Test async function behavior."""
result = await async_function(5, 3)
assert result == 15
result2 = await async_function(2, 4)
assert result2 == 8
'''
test_file = temp_dir / "test_async.py"
test_file.write_text(async_test_code)
func = FunctionToOptimize(function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True)
# First instrument the source module
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
source_success, instrumented_source = instrument_source_module_with_async_decorators(
source_file, func, TestingMode.BEHAVIOR
)
assert source_success is True
assert instrumented_source is not None
assert "@codeflash_behavior_async" in instrumented_source
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
assert "codeflash_behavior_async" in instrumented_source
source_file.write_text(instrumented_source)
success, instrumented_test_code = inject_profiling_into_existing_test(
test_file, [CodePosition(8, 18), CodePosition(11, 19)], func, temp_dir, "pytest", mode=TestingMode.BEHAVIOR
)
# For async functions, once source is decorated, test injection should fail
# This is expected behavior - async instrumentation happens at the decorator level
assert success is False
assert instrumented_test_code is None
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_inject_profiling_async_function_performance_mode(temp_dir):
source_module_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
source_file = temp_dir / "my_module.py"
source_file.write_text(source_module_code)
# Create the test file
async_test_code = '''
import asyncio
import pytest
from my_module import async_function
@pytest.mark.asyncio
async def test_async_function():
"""Test async function performance."""
result = await async_function(5, 3)
assert result == 15
'''
test_file = temp_dir / "test_async.py"
test_file.write_text(async_test_code)
func = FunctionToOptimize(function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True)
# First instrument the source module
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
source_success, instrumented_source = instrument_source_module_with_async_decorators(
source_file, func, TestingMode.PERFORMANCE
)
assert source_success is True
assert instrumented_source is not None
assert "@codeflash_performance_async" in instrumented_source
# Check for the import with line continuation formatting
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
assert "codeflash_performance_async" in instrumented_source
# Write the instrumented source back
source_file.write_text(instrumented_source)
# Now test the full pipeline with source module path
success, instrumented_test_code = inject_profiling_into_existing_test(
test_file, [CodePosition(8, 18)], func, temp_dir, "pytest", mode=TestingMode.PERFORMANCE
)
# For async functions, once source is decorated, test injection should fail
# This is expected behavior - async instrumentation happens at the decorator level
assert success is False
assert instrumented_test_code is None
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_mixed_sync_async_instrumentation(temp_dir):
source_module_code = '''
import asyncio
def sync_function(x: int, y: int) -> int:
"""Regular sync function."""
return x * y
async def async_function(x: int, y: int) -> int:
"""Simple async function."""
await asyncio.sleep(0.01)
return x * y
'''
source_file = temp_dir / "my_module.py"
source_file.write_text(source_module_code)
mixed_test_code = '''
import asyncio
import pytest
from my_module import sync_function, async_function
@pytest.mark.asyncio
async def test_mixed_functions():
"""Test both sync and async functions."""
sync_result = sync_function(10, 5)
assert sync_result == 50
async_result = await async_function(3, 4)
assert async_result == 12
'''
test_file = temp_dir / "test_mixed.py"
test_file.write_text(mixed_test_code)
async_func = FunctionToOptimize(
function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True
)
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
source_success, instrumented_source = instrument_source_module_with_async_decorators(
source_file, async_func, TestingMode.BEHAVIOR
)
assert source_success
assert instrumented_source is not None
assert "@codeflash_behavior_async" in instrumented_source
assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
assert "codeflash_behavior_async" in instrumented_source
# Sync function should remain unchanged
assert "def sync_function(x: int, y: int) -> int:" in instrumented_source
# Write instrumented source back
source_file.write_text(instrumented_source)
success, instrumented_test_code = inject_profiling_into_existing_test(
test_file,
[CodePosition(8, 18), CodePosition(11, 19)],
async_func,
temp_dir,
"pytest",
mode=TestingMode.BEHAVIOR,
)
# Async functions should not be instrumented at the test level
assert not success
assert instrumented_test_code is None
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_function_qualified_name_handling():
nested_async_code = '''
import asyncio
class OuterClass:
class InnerClass:
async def nested_async_method(self, x: int) -> int:
"""Nested async method."""
await asyncio.sleep(0.001)
return x * 2
'''
func = FunctionToOptimize(
function_name="nested_async_method",
file_path=Path("test_nested.py"),
parents=[{"name": "OuterClass", "type": "ClassDef"}, {"name": "InnerClass", "type": "ClassDef"}],
is_async=True,
)
modified_code, decorator_added = add_async_decorator_to_function(nested_async_code, func, TestingMode.BEHAVIOR)
expected_output = (
"""import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_behavior_async
class OuterClass:
class InnerClass:
@codeflash_behavior_async
async def nested_async_method(self, x: int) -> int:
\"\"\"Nested async method.\"\"\"
await asyncio.sleep(0.001)
return x * 2
"""
)
assert modified_code.strip() == expected_output.strip()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_decorator_with_existing_decorators():
"""Test async decorator application when function already has other decorators."""
decorated_async_code = '''
import asyncio
from functools import wraps
def my_decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
return await func(*args, **kwargs)
return wrapper
@my_decorator
async def async_function(x: int, y: int) -> int:
"""Async function with existing decorator."""
await asyncio.sleep(0.01)
return x * y
'''
func = FunctionToOptimize(
function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
)
modified_code, decorator_added = add_async_decorator_to_function(decorated_async_code, func, TestingMode.BEHAVIOR)
assert decorator_added
# Should add codeflash decorator above existing decorators
assert "@codeflash_behavior_async" in modified_code
assert "@my_decorator" in modified_code
# Codeflash decorator should come first
codeflash_pos = modified_code.find("@codeflash_behavior_async")
my_decorator_pos = modified_code.find("@my_decorator")
assert codeflash_pos < my_decorator_pos
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_sync_function_not_affected_by_async_logic():
sync_function_code = '''
def sync_function(x: int, y: int) -> int:
"""Regular sync function."""
return x + y
'''
sync_func = FunctionToOptimize(
function_name="sync_function",
file_path=Path("test_sync.py"),
parents=[],
is_async=False,
)
modified_code, decorator_added = add_async_decorator_to_function(
sync_function_code, sync_func, TestingMode.BEHAVIOR
)
assert not decorator_added
assert modified_code == sync_function_code
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_inject_profiling_async_multiple_calls_same_test(temp_dir):
"""Test that multiple async function calls within the same test function get correctly numbered 0, 1, 2, etc."""
source_module_code = '''
import asyncio
async def async_sorter(items):
"""Simple async sorter for testing."""
await asyncio.sleep(0.001)
return sorted(items)
'''
source_file = temp_dir / "async_sorter.py"
source_file.write_text(source_module_code)
test_code_multiple_calls = """
import asyncio
import pytest
from async_sorter import async_sorter
@pytest.mark.asyncio
async def test_single_call():
result = await async_sorter([42])
assert result == [42]
@pytest.mark.asyncio
async def test_multiple_calls():
result1 = await async_sorter([3, 1, 2])
result2 = await async_sorter([5, 4])
result3 = await async_sorter([9, 8, 7, 6])
assert result1 == [1, 2, 3]
assert result2 == [4, 5]
assert result3 == [6, 7, 8, 9]
"""
test_file = temp_dir / "test_async_sorter.py"
test_file.write_text(test_code_multiple_calls)
func = FunctionToOptimize(
function_name="async_sorter", parents=[], file_path=Path("async_sorter.py"), is_async=True
)
# First instrument the source module with async decorators
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
source_success, instrumented_source = instrument_source_module_with_async_decorators(
source_file, func, TestingMode.BEHAVIOR
)
assert source_success
assert instrumented_source is not None
assert "@codeflash_behavior_async" in instrumented_source
source_file.write_text(instrumented_source)
import ast
tree = ast.parse(test_code_multiple_calls)
call_positions = []
for node in ast.walk(tree):
if isinstance(node, ast.Await) and isinstance(node.value, ast.Call):
if (hasattr(node.value.func, "id") and node.value.func.id == "async_sorter") or (
hasattr(node.value.func, "attr") and node.value.func.attr == "async_sorter"
):
call_positions.append(CodePosition(node.lineno, node.col_offset))
assert len(call_positions) == 4
success, instrumented_test_code = inject_profiling_into_existing_test(
test_file, call_positions, func, temp_dir, "pytest", mode=TestingMode.BEHAVIOR
)
assert success
assert instrumented_test_code is not None
assert "os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'" in instrumented_test_code
# Count occurrences of each line_id to verify numbering
line_id_0_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'")
line_id_1_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '1'")
line_id_2_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '2'")
assert line_id_0_count == 2, f"Expected 2 occurrences of line_id '0', got {line_id_0_count}"
assert line_id_1_count == 1, f"Expected 1 occurrence of line_id '1', got {line_id_1_count}"
assert line_id_2_count == 1, f"Expected 1 occurrence of line_id '2', got {line_id_2_count}"
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_behavior_decorator_return_values_and_test_ids():
"""Test that async behavior decorator correctly captures return values, test IDs, and stores data in database."""
import asyncio
import os
import sqlite3
from pathlib import Path
import dill as pickle
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async
@codeflash_behavior_async
async def test_async_multiply(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.001) # Small delay to simulate async work
return x * y
test_env = {
"CODEFLASH_TEST_MODULE": "test_module",
"CODEFLASH_TEST_CLASS": None,
"CODEFLASH_TEST_FUNCTION": "test_async_multiply_function",
"CODEFLASH_CURRENT_LINE_ID": "0",
"CODEFLASH_LOOP_INDEX": "1",
"CODEFLASH_TEST_ITERATION": "2",
}
original_env = {k: os.environ.get(k) for k in test_env}
for k, v in test_env.items():
if v is not None:
os.environ[k] = v
elif k in os.environ:
del os.environ[k]
try:
result = asyncio.run(test_async_multiply(6, 7))
assert result == 42, f"Expected return value 42, got {result}"
from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file
db_path = get_run_tmp_file(Path(f"test_return_values_2.sqlite"))
# Verify database exists and has data
assert db_path.exists(), f"Database file not created at {db_path}"
# Read and verify database contents
con = sqlite3.connect(db_path)
cur = con.cursor()
cur.execute("SELECT * FROM test_results")
rows = cur.fetchall()
assert len(rows) == 1, f"Expected 1 database row, got {len(rows)}"
row = rows[0]
(
test_module,
test_class,
test_function,
function_name,
loop_index,
iteration_id,
runtime,
return_value_blob,
verification_type,
) = row
assert test_module == "test_module", f"Expected test_module 'test_module', got '{test_module}'"
assert test_class is None, f"Expected test_class None, got '{test_class}'"
assert test_function == "test_async_multiply_function", (
f"Expected test_function 'test_async_multiply_function', got '{test_function}'"
)
assert function_name == "test_async_multiply", (
f"Expected function_name 'test_async_multiply', got '{function_name}'"
)
assert loop_index == 1, f"Expected loop_index 1, got {loop_index}"
assert iteration_id == "0_0", f"Expected iteration_id '0_0', got '{iteration_id}'"
assert verification_type == "function_call", (
f"Expected verification_type 'function_call', got '{verification_type}'"
)
unpickled_data = pickle.loads(return_value_blob)
args, kwargs, actual_return_value = unpickled_data
assert args == (6, 7), f"Expected args (6, 7), got {args}"
assert kwargs == {}, f"Expected empty kwargs, got {kwargs}"
assert actual_return_value == 42, f"Expected stored return value 42, got {actual_return_value}"
con.close()
finally:
for k, v in original_env.items():
if v is not None:
os.environ[k] = v
elif k in os.environ:
del os.environ[k]
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_decorator_comprehensive_return_values_and_test_ids():
import asyncio
import os
import sqlite3
from pathlib import Path
import dill as pickle
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async, get_run_tmp_file
@codeflash_behavior_async
async def async_multiply_add(x: int, y: int, z: int = 1) -> int:
"""Async function that multiplies x*y then adds z."""
await asyncio.sleep(0.001)
result = (x * y) + z
return result
test_env = {
"CODEFLASH_TEST_MODULE": "test_comprehensive_module",
"CODEFLASH_TEST_CLASS": "AsyncTestClass",
"CODEFLASH_TEST_FUNCTION": "test_comprehensive_async_function",
"CODEFLASH_CURRENT_LINE_ID": "3",
"CODEFLASH_LOOP_INDEX": "2",
"CODEFLASH_TEST_ITERATION": "3",
}
original_env = {k: os.environ.get(k) for k in test_env}
for k, v in test_env.items():
if v is not None:
os.environ[k] = v
elif k in os.environ:
del os.environ[k]
try:
test_cases = [
{"args": (5, 3), "kwargs": {}, "expected": 16}, # (5 * 3) + 1 = 16
{"args": (2, 4), "kwargs": {"z": 10}, "expected": 18}, # (2 * 4) + 10 = 18
{"args": (7, 6), "kwargs": {}, "expected": 43}, # (7 * 6) + 1 = 43
]
results = []
for test_case in test_cases:
result = asyncio.run(async_multiply_add(*test_case["args"], **test_case["kwargs"]))
results.append(result)
# Verify each return value is exactly correct
assert result == test_case["expected"], (
f"Expected {test_case['expected']}, got {result} for args {test_case['args']}, kwargs {test_case['kwargs']}"
)
db_path = get_run_tmp_file(Path(f"test_return_values_3.sqlite"))
assert db_path.exists(), f"Database not created at {db_path}"
con = sqlite3.connect(db_path)
cur = con.cursor()
cur.execute(
"SELECT test_module_path, test_class_name, test_function_name, function_getting_tested, loop_index, iteration_id, runtime, return_value, verification_type FROM test_results ORDER BY rowid"
)
rows = cur.fetchall()
assert len(rows) == 3, f"Expected 3 database rows, got {len(rows)}"
for i, (
test_module,
test_class,
test_function,
function_name,
loop_index,
iteration_id,
runtime,
return_value_blob,
verification_type,
) in enumerate(rows):
assert test_module == "test_comprehensive_module", (
f"Row {i}: Expected test_module 'test_comprehensive_module', got '{test_module}'"
)
assert test_class == "AsyncTestClass", f"Row {i}: Expected test_class 'AsyncTestClass', got '{test_class}'"
assert test_function == "test_comprehensive_async_function", (
f"Row {i}: Expected test_function 'test_comprehensive_async_function', got '{test_function}'"
)
assert function_name == "async_multiply_add", (
f"Row {i}: Expected function_name 'async_multiply_add', got '{function_name}'"
)
assert loop_index == 2, f"Row {i}: Expected loop_index 2, got {loop_index}"
assert verification_type == "function_call", (
f"Row {i}: Expected verification_type 'function_call', got '{verification_type}'"
)
expected_iteration_id = f"3_{i}"
assert iteration_id == expected_iteration_id, (
f"Row {i}: Expected iteration_id '{expected_iteration_id}', got '{iteration_id}'"
)
args, kwargs, actual_return_value = pickle.loads(return_value_blob)
expected_args = test_cases[i]["args"]
expected_kwargs = test_cases[i]["kwargs"]
expected_return = test_cases[i]["expected"]
assert args == expected_args, f"Row {i}: Expected args {expected_args}, got {args}"
assert kwargs == expected_kwargs, f"Row {i}: Expected kwargs {expected_kwargs}, got {kwargs}"
assert actual_return_value == expected_return, (
f"Row {i}: Expected return value {expected_return}, got {actual_return_value}"
)
con.close()
finally:
for k, v in original_env.items():
if v is not None:
os.environ[k] = v
elif k in os.environ:
del os.environ[k]

View file

@ -22,7 +22,7 @@ from codeflash.verification.codeflash_capture import codeflash_capture
class MyClass:
@codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=True)
@codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)
def __init__(self):
self.x = 1
@ -86,7 +86,7 @@ from codeflash.verification.codeflash_capture import codeflash_capture
class MyClass(ParentClass):
@codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=True)
@codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -128,7 +128,7 @@ from codeflash.verification.codeflash_capture import codeflash_capture
class MyClass:
@codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=True)
@codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)
def __init__(self):
self.x = 1
@ -184,7 +184,7 @@ from codeflash.verification.codeflash_capture import codeflash_capture
class MyClass:
@codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=True)
@codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)
def __init__(self):
self.x = 1
@ -197,7 +197,7 @@ from codeflash.verification.codeflash_capture import codeflash_capture
class HelperClass:
@codeflash_capture(function_name='HelperClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=False)
@codeflash_capture(function_name='HelperClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=False)
def __init__(self):
self.y = 1
@ -271,7 +271,7 @@ from codeflash.verification.codeflash_capture import codeflash_capture
class MyClass:
@codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=True)
@codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)
def __init__(self):
self.x = 1
@ -289,7 +289,7 @@ from codeflash.verification.codeflash_capture import codeflash_capture
class HelperClass1:
@codeflash_capture(function_name='HelperClass1.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=False)
@codeflash_capture(function_name='HelperClass1.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=False)
def __init__(self):
self.y = 1
@ -304,7 +304,7 @@ from codeflash.verification.codeflash_capture import codeflash_capture
class HelperClass2:
@codeflash_capture(function_name='HelperClass2.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=False)
@codeflash_capture(function_name='HelperClass2.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=False)
def __init__(self):
self.z = 2
@ -313,7 +313,7 @@ class HelperClass2:
class AnotherHelperClass:
@codeflash_capture(function_name='AnotherHelperClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=False)
@codeflash_capture(function_name='AnotherHelperClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=False)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

View file

@ -37,7 +37,7 @@ def test_add_decorator_imports_helper_in_class():
line_profiler_output_file = add_decorator_imports(
func_optimizer.function_to_optimize, code_context)
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}')
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')
from code_to_optimize.bubble_sort_in_class import BubbleSortClass
@ -106,7 +106,7 @@ def test_add_decorator_imports_helper_in_nested_class():
line_profiler_output_file = add_decorator_imports(
func_optimizer.function_to_optimize, code_context)
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}')
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')
from code_to_optimize.bubble_sort_in_nested_class import WrapperClass
@ -151,7 +151,7 @@ def test_add_decorator_imports_nodeps():
line_profiler_output_file = add_decorator_imports(
func_optimizer.function_to_optimize, code_context)
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}')
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')
@codeflash_line_profile
@ -200,7 +200,7 @@ def test_add_decorator_imports_helper_outside():
line_profiler_output_file = add_decorator_imports(
func_optimizer.function_to_optimize, code_context)
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}')
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')
from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
@ -275,7 +275,7 @@ class helper:
line_profiler_output_file = add_decorator_imports(
func_optimizer.function_to_optimize, code_context)
expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}')
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')
@codeflash_line_profile

View file

@ -6,7 +6,7 @@ import os
import sys
import tempfile
from pathlib import Path
import pytest
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.instrument_existing_tests import (
FunctionImportedAsVisitor,
@ -24,6 +24,8 @@ from codeflash.models.models import (
TestsInFile,
TestType,
)
import platform
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@ -87,7 +89,38 @@ codeflash_wrap_perfonly_string = """def codeflash_wrap(codeflash_wrapped, codefl
"""
def test_perfinjector_bubble_sort() -> None:
def build_expected_unittest_imports(extra_imports: str = "") -> str:
imports = """import gc
import os
import sqlite3
import time
import unittest
import dill as pickle"""
if platform.system() != "Windows":
imports += "\nimport timeout_decorator"
if extra_imports:
imports += "\n" + extra_imports
return imports
def build_expected_pytest_imports(extra_imports: str = "") -> str:
"""Helper to build platform-aware imports for pytest tests."""
imports = """import gc
import os
import time
import pytest"""
if extra_imports:
imports += "\n" + extra_imports
return imports
# create a temporary directory for the test results
@pytest.fixture
def tmp_dir():
with tempfile.TemporaryDirectory() as tmpdirname:
yield Path(tmpdirname)
def test_perfinjector_bubble_sort(tmp_dir) -> None:
code = """import unittest
from code_to_optimize.bubble_sort import sorter
@ -106,54 +139,27 @@ class TestPigLatin(unittest.TestCase):
input = list(reversed(range(5000)))
self.assertEqual(sorter(input), list(range(5000)))
"""
expected = """import gc
imports = """import gc
import os
import sqlite3
import time
import unittest
import dill as pickle
import timeout_decorator
from code_to_optimize.bubble_sort import sorter
def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
test_id = f'{{codeflash_test_module_name}}:{{codeflash_test_class_name}}:{{codeflash_test_name}}:{{codeflash_line_id}}:{{codeflash_loop_index}}'
if not hasattr(codeflash_wrap, 'index'):
codeflash_wrap.index = {{}}
if test_id in codeflash_wrap.index:
codeflash_wrap.index[test_id] += 1
else:
codeflash_wrap.index[test_id] = 0
codeflash_test_index = codeflash_wrap.index[test_id]
invocation_id = f'{{codeflash_line_id}}_{{codeflash_test_index}}'
"""
expected += """test_stdout_tag = f'{{codeflash_test_module_name}}:{{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}}{{codeflash_test_name}}:{{codeflash_function_name}}:{{codeflash_loop_index}}:{{invocation_id}}'
"""
expected += """print(f'!$######{{test_stdout_tag}}######$!')
exception = None
gc.disable()
try:
counter = time.perf_counter_ns()
return_value = codeflash_wrapped(*args, **kwargs)
codeflash_duration = time.perf_counter_ns() - counter
except Exception as e:
codeflash_duration = time.perf_counter_ns() - counter
exception = e
gc.enable()
print(f'!######{{test_stdout_tag}}######!')
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
codeflash_con.commit()
if exception:
raise exception
return return_value
class TestPigLatin(unittest.TestCase):
@timeout_decorator.timeout(15)
def test_sort(self):
import dill as pickle"""
if platform.system() != "Windows":
imports += "\nimport timeout_decorator"
imports += "\n\nfrom code_to_optimize.bubble_sort import sorter"
wrapper_func = codeflash_wrap_string
test_class_header = "class TestPigLatin(unittest.TestCase):"
test_decorator = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
expected = imports + "\n\n\n" + wrapper_func + "\n" + test_class_header + "\n\n"
if test_decorator:
expected += test_decorator + "\n"
expected += """ def test_sort(self):
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
@ -169,7 +175,8 @@ class TestPigLatin(unittest.TestCase):
self.assertEqual(codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, codeflash_cur, codeflash_con, input), list(range(5000)))
codeflash_con.close()
"""
with tempfile.NamedTemporaryFile(mode="w") as f:
with (tmp_dir / "test_sort.py").open("w") as f:
f.write(code)
f.flush()
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path(f.name))
@ -186,11 +193,11 @@ class TestPigLatin(unittest.TestCase):
os.chdir(original_cwd)
assert success
assert new_test.replace('"', "'") == expected.format(
module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
module_path=Path(f.name).stem, tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
).replace('"', "'")
def test_perfinjector_only_replay_test() -> None:
def test_perfinjector_only_replay_test(tmp_dir) -> None:
code = """import dill as pickle
import pytest
from codeflash.tracing.replay_test import get_next_arg_and_return
@ -269,7 +276,7 @@ def test_prepare_image_for_yolo():
assert compare_results(return_val_1, ret)
codeflash_con.close()
"""
with tempfile.NamedTemporaryFile(mode="w") as f:
with (tmp_dir / "test_return_values.py").open("w") as f:
f.write(code)
f.flush()
func = FunctionToOptimize(function_name="prepare_image_for_yolo", parents=[], file_path=Path("module.py"))
@ -282,7 +289,7 @@ def test_prepare_image_for_yolo():
os.chdir(original_cwd)
assert success
assert new_test.replace('"', "'") == expected.format(
module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
module_path=Path(f.name).stem, tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
).replace('"', "'")
@ -389,7 +396,7 @@ def test_sort():
assert new_test is not None
assert new_test.replace('"', "'") == expected.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
).replace('"', "'")
success, new_perf_test = inject_profiling_into_existing_test(
@ -404,7 +411,7 @@ def test_sort():
assert new_perf_test is not None
assert new_perf_test.replace('"', "'") == expected_perfonly.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
).replace('"', "'")
with test_path.open("w") as f:
@ -532,7 +539,8 @@ result: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
line_profiler_output_file=line_profiler_output_file,
)
tmp_lpr = list(line_profile_results["timings"].keys())
assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1] == 2
if sys.platform != "win32":
assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1] == 2
finally:
if computed_fn_opt:
func_optimizer.write_code_and_helpers(
@ -643,11 +651,11 @@ def test_sort_parametrized(input, expected_output):
assert new_test is not None
assert new_test.replace('"', "'") == expected.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
assert new_test_perf.replace('"', "'") == expected_perfonly.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
#
# Overwrite old test with new instrumented test
@ -799,7 +807,8 @@ result: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
line_profiler_output_file=line_profiler_output_file,
)
tmp_lpr = list(line_profile_results["timings"].keys())
assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1] == 3
if sys.platform != "win32":
assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1] == 3
finally:
if computed_fn_opt:
func_optimizer.write_code_and_helpers(
@ -916,7 +925,7 @@ def test_sort_parametrized_loop(input, expected_output):
assert new_test is not None
assert new_test.replace('"', "'") == expected.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
).replace('"', "'")
# Overwrite old test with new instrumented test
@ -925,7 +934,7 @@ def test_sort_parametrized_loop(input, expected_output):
assert new_test_perf.replace('"', "'") == expected_perf.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
).replace('"', "'")
# Overwrite old test with new instrumented test
@ -1154,7 +1163,8 @@ result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2
line_profiler_output_file=line_profiler_output_file,
)
tmp_lpr = list(line_profile_results["timings"].keys())
assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1] == 6
if sys.platform != "win32":
assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1] == 6
finally:
if computed_fn_opt:
func_optimizer.write_code_and_helpers(
@ -1271,12 +1281,12 @@ def test_sort():
assert new_test_behavior is not None
assert new_test_behavior.replace('"', "'") == expected.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
).replace('"', "'")
assert new_test_perf.replace('"', "'") == expected_perf.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
).replace('"', "'")
# Overwrite old test with new instrumented test
@ -1432,7 +1442,8 @@ result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2
line_profiler_output_file=line_profiler_output_file,
)
tmp_lpr = list(line_profile_results["timings"].keys())
assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1] == 3
if sys.platform != "win32":
assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1] == 3
finally:
if computed_fn_opt is True:
func_optimizer.write_code_and_helpers(
@ -1446,6 +1457,7 @@ result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2
def test_perfinjector_bubble_sort_unittest_results() -> None:
code = """import unittest
from code_to_optimize.bubble_sort import sorter
@ -1466,8 +1478,74 @@ class TestPigLatin(unittest.TestCase):
self.assertEqual(output, list(range(50)))
"""
expected = (
"""import gc
is_windows = platform.system() == "Windows"
if is_windows:
expected = (
"""import gc
import os
import sqlite3
import time
import unittest
import dill as pickle
from code_to_optimize.bubble_sort import sorter
"""
+ codeflash_wrap_string
+ """
class TestPigLatin(unittest.TestCase):
def test_sort(self):
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
codeflash_cur = codeflash_con.cursor()
codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)')
input = [5, 4, 3, 2, 1, 0]
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, input)
self.assertEqual(output, [0, 1, 2, 3, 4, 5])
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, input)
self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
input = list(reversed(range(50)))
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, codeflash_cur, codeflash_con, input)
self.assertEqual(output, list(range(50)))
codeflash_con.close()
"""
)
expected_perf = (
"""import gc
import os
import time
import unittest
from code_to_optimize.bubble_sort import sorter
"""
+ codeflash_wrap_perfonly_string
+ """
class TestPigLatin(unittest.TestCase):
def test_sort(self):
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
input = [5, 4, 3, 2, 1, 0]
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, input)
self.assertEqual(output, [0, 1, 2, 3, 4, 5])
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, input)
self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
input = list(reversed(range(50)))
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, input)
self.assertEqual(output, list(range(50)))
"""
)
else:
expected = (
"""import gc
import os
import sqlite3
import time
@ -1480,8 +1558,8 @@ from code_to_optimize.bubble_sort import sorter
"""
+ codeflash_wrap_string
+ """
+ codeflash_wrap_string
+ """
class TestPigLatin(unittest.TestCase):
@timeout_decorator.timeout(15)
@ -1502,9 +1580,9 @@ class TestPigLatin(unittest.TestCase):
self.assertEqual(output, list(range(50)))
codeflash_con.close()
"""
)
expected_perf = (
"""import gc
)
expected_perf = (
"""import gc
import os
import time
import unittest
@ -1515,8 +1593,8 @@ from code_to_optimize.bubble_sort import sorter
"""
+ codeflash_wrap_perfonly_string
+ """
+ codeflash_wrap_perfonly_string
+ """
class TestPigLatin(unittest.TestCase):
@timeout_decorator.timeout(15)
@ -1532,7 +1610,7 @@ class TestPigLatin(unittest.TestCase):
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, input)
self.assertEqual(output, list(range(50)))
"""
)
)
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
test_path = (
Path(__file__).parent.resolve()
@ -1580,11 +1658,11 @@ class TestPigLatin(unittest.TestCase):
assert new_test_behavior is not None
assert new_test_behavior.replace('"', "'") == expected.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
).replace('"', "'")
assert new_test_perf.replace('"', "'") == expected_perf.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
).replace('"', "'")
#
# Overwrite old test with new instrumented test
@ -1744,28 +1822,18 @@ class TestPigLatin(unittest.TestCase):
self.assertEqual(output, expected_output)
"""
expected_behavior = (
"""import gc
import os
import sqlite3
import time
import unittest
import dill as pickle
import timeout_decorator
from parameterized import parameterized
from code_to_optimize.bubble_sort import sorter
"""
+ codeflash_wrap_string
+ """
class TestPigLatin(unittest.TestCase):
# Build expected behavior output with platform-aware imports
imports_behavior = build_expected_unittest_imports("from parameterized import parameterized")
imports_behavior += "\n\nfrom code_to_optimize.bubble_sort import sorter"
test_decorator_behavior = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
test_class_behavior = """class TestPigLatin(unittest.TestCase):
@parameterized.expand([([5, 4, 3, 2, 1, 0], [0, 1, 2, 3, 4, 5]), ([5.0, 4.0, 3.0, 2.0, 1.0, 0.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), (list(reversed(range(50))), list(range(50)))])
@timeout_decorator.timeout(15)
def test_sort(self, input, expected_output):
"""
if test_decorator_behavior:
test_class_behavior += test_decorator_behavior + "\n"
test_class_behavior += """ def test_sort(self, input, expected_output):
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
@ -1775,32 +1843,32 @@ class TestPigLatin(unittest.TestCase):
self.assertEqual(output, expected_output)
codeflash_con.close()
"""
)
expected_perf = (
"""import gc
expected_behavior = imports_behavior + "\n\n\n" + codeflash_wrap_string + "\n" + test_class_behavior
# Build expected perf output with platform-aware imports
imports_perf = """import gc
import os
import time
import unittest
import timeout_decorator
from parameterized import parameterized
from code_to_optimize.bubble_sort import sorter
"""
+ codeflash_wrap_perfonly_string
+ """
class TestPigLatin(unittest.TestCase):
if platform.system() != "Windows":
imports_perf += "\nimport timeout_decorator"
imports_perf += "\nfrom parameterized import parameterized\n\nfrom code_to_optimize.bubble_sort import sorter"
test_decorator_perf = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
test_class_perf = """class TestPigLatin(unittest.TestCase):
@parameterized.expand([([5, 4, 3, 2, 1, 0], [0, 1, 2, 3, 4, 5]), ([5.0, 4.0, 3.0, 2.0, 1.0, 0.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), (list(reversed(range(50))), list(range(50)))])
@timeout_decorator.timeout(15)
def test_sort(self, input, expected_output):
"""
if test_decorator_perf:
test_class_perf += test_decorator_perf + "\n"
test_class_perf += """ def test_sort(self, input, expected_output):
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0', codeflash_loop_index, input)
self.assertEqual(output, expected_output)
"""
)
expected_perf = imports_perf + "\n\n\n" + codeflash_wrap_perfonly_string + "\n" + test_class_perf
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
test_path = (
Path(__file__).parent.resolve()
@ -1837,13 +1905,13 @@ class TestPigLatin(unittest.TestCase):
assert new_test_behavior is not None
assert new_test_behavior.replace('"', "'") == expected_behavior.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
).replace('"', "'")
assert new_test_perf is not None
assert new_test_perf.replace('"', "'") == expected_perf.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
).replace('"', "'")
#
@ -2003,26 +2071,17 @@ class TestPigLatin(unittest.TestCase):
output = sorter(input)
self.assertEqual(output, expected_output)"""
expected_behavior = (
"""import gc
import os
import sqlite3
import time
import unittest
import dill as pickle
import timeout_decorator
from code_to_optimize.bubble_sort import sorter
# Build expected behavior output with platform-aware imports
imports_behavior = build_expected_unittest_imports()
imports_behavior += "\n\nfrom code_to_optimize.bubble_sort import sorter"
test_decorator_behavior = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
test_class_behavior = """class TestPigLatin(unittest.TestCase):
"""
+ codeflash_wrap_string
+ """
class TestPigLatin(unittest.TestCase):
@timeout_decorator.timeout(15)
def test_sort(self):
if test_decorator_behavior:
test_class_behavior += test_decorator_behavior + "\n"
test_class_behavior += """ def test_sort(self):
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
@ -2037,26 +2096,28 @@ class TestPigLatin(unittest.TestCase):
self.assertEqual(output, expected_output)
codeflash_con.close()
"""
)
expected_behavior = imports_behavior + "\n\n\n" + codeflash_wrap_string + "\n" + test_class_behavior
expected_perf = (
"""import gc
# Build expected perf output with platform-aware imports
imports_perf = """import gc
import os
import time
import unittest
import timeout_decorator
from code_to_optimize.bubble_sort import sorter
"""
if platform.system() != "Windows":
imports_perf += "\nimport timeout_decorator"
imports_perf += "\n\nfrom code_to_optimize.bubble_sort import sorter"
else:
imports_perf += "\nfrom code_to_optimize.bubble_sort import sorter"
test_decorator_perf = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
test_class_perf = """class TestPigLatin(unittest.TestCase):
"""
+ codeflash_wrap_perfonly_string
+ """
class TestPigLatin(unittest.TestCase):
@timeout_decorator.timeout(15)
def test_sort(self):
if test_decorator_perf:
test_class_perf += test_decorator_perf + "\n"
test_class_perf += """ def test_sort(self):
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
inputs = [[5, 4, 3, 2, 1, 0], [5.0, 4.0, 3.0, 2.0, 1.0, 0.0], list(reversed(range(50)))]
expected_outputs = [[0, 1, 2, 3, 4, 5], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0], list(range(50))]
@ -2066,7 +2127,8 @@ class TestPigLatin(unittest.TestCase):
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '2_2', codeflash_loop_index, input)
self.assertEqual(output, expected_output)
"""
)
expected_perf = imports_perf + "\n\n\n" + codeflash_wrap_perfonly_string + "\n" + test_class_perf
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
test_path = (
Path(__file__).parent.resolve()
@ -2103,11 +2165,11 @@ class TestPigLatin(unittest.TestCase):
assert new_test_behavior is not None
assert new_test_behavior.replace('"', "'") == expected_behavior.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
).replace('"', "'")
assert new_test_perf.replace('"', "'") == expected_perf.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
).replace('"', "'")
#
# # Overwrite old test with new instrumented test
@ -2269,28 +2331,18 @@ class TestPigLatin(unittest.TestCase):
self.assertEqual(output, expected_output)
"""
expected_behavior = (
"""import gc
import os
import sqlite3
import time
import unittest
import dill as pickle
import timeout_decorator
from parameterized import parameterized
from code_to_optimize.bubble_sort import sorter
"""
+ codeflash_wrap_string
+ """
class TestPigLatin(unittest.TestCase):
# Build expected behavior output with platform-aware imports
imports_behavior = build_expected_unittest_imports("from parameterized import parameterized")
imports_behavior += "\n\nfrom code_to_optimize.bubble_sort import sorter"
test_decorator_behavior = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
test_class_behavior = """class TestPigLatin(unittest.TestCase):
@parameterized.expand([([5, 4, 3, 2, 1, 0], [0, 1, 2, 3, 4, 5]), ([5.0, 4.0, 3.0, 2.0, 1.0, 0.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), (list(reversed(range(50))), list(range(50)))])
@timeout_decorator.timeout(15)
def test_sort(self, input, expected_output):
"""
if test_decorator_behavior:
test_class_behavior += test_decorator_behavior + "\n"
test_class_behavior += """ def test_sort(self, input, expected_output):
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
@ -2301,33 +2353,33 @@ class TestPigLatin(unittest.TestCase):
self.assertEqual(output, expected_output)
codeflash_con.close()
"""
)
expected_perf = (
"""import gc
expected_behavior = imports_behavior + "\n\n\n" + codeflash_wrap_string + "\n" + test_class_behavior
# Build expected perf output with platform-aware imports
imports_perf = """import gc
import os
import time
import unittest
import timeout_decorator
from parameterized import parameterized
from code_to_optimize.bubble_sort import sorter
"""
+ codeflash_wrap_perfonly_string
+ """
class TestPigLatin(unittest.TestCase):
if platform.system() != "Windows":
imports_perf += "\nimport timeout_decorator"
imports_perf += "\nfrom parameterized import parameterized\n\nfrom code_to_optimize.bubble_sort import sorter"
test_decorator_perf = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
test_class_perf = """class TestPigLatin(unittest.TestCase):
@parameterized.expand([([5, 4, 3, 2, 1, 0], [0, 1, 2, 3, 4, 5]), ([5.0, 4.0, 3.0, 2.0, 1.0, 0.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), (list(reversed(range(50))), list(range(50)))])
@timeout_decorator.timeout(15)
def test_sort(self, input, expected_output):
"""
if test_decorator_perf:
test_class_perf += test_decorator_perf + "\n"
test_class_perf += """ def test_sort(self, input, expected_output):
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
for i in range(2):
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0_0', codeflash_loop_index, input)
self.assertEqual(output, expected_output)
"""
)
expected_perf = imports_perf + "\n\n\n" + codeflash_wrap_perfonly_string + "\n" + test_class_perf
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
test_path = (
Path(__file__).parent.resolve()
@ -2362,11 +2414,11 @@ class TestPigLatin(unittest.TestCase):
assert new_test_behavior is not None
assert new_test_behavior.replace('"', "'") == expected_behavior.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
).replace('"', "'")
assert new_test_perf.replace('"', "'") == expected_perf.format(
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_loop_results_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
).replace('"', "'")
#
# Overwrite old test with new instrumented test
@ -2663,7 +2715,7 @@ def test_class_name_A_function_name():
assert success
assert new_test is not None
assert new_test.replace('"', "'") == expected.format(
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
module_path="tests.pytest.test_class_function_instrumentation_temp",
).replace('"', "'")
@ -2734,7 +2786,7 @@ def test_common_tags_1():
assert new_test is not None
assert new_test.replace('"', "'") == expected.format(
module_path="tests.pytest.test_wrong_function_instrumentation_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
finally:
test_path.unlink(missing_ok=True)
@ -2797,7 +2849,7 @@ def test_sort():
assert new_test is not None
assert new_test.replace('"', "'") == expected.format(
module_path="tests.pytest.test_conditional_instrumentation_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
).replace('"', "'")
finally:
test_path.unlink(missing_ok=True)
@ -2874,7 +2926,7 @@ def test_sort():
assert success
formatted_expected = expected.format(
module_path="tests.pytest.test_perfinjector_bubble_sort_results_temp",
tmp_dir_path=str(get_run_tmp_file(Path("test_return_values"))),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
)
assert new_test is not None
assert new_test.replace('"', "'") == formatted_expected.replace('"', "'")
@ -2882,7 +2934,7 @@ def test_sort():
test_path.unlink(missing_ok=True)
def test_class_method_instrumentation() -> None:
def test_class_method_instrumentation(tmp_path: Path) -> None:
code = """from codeflash.optimization.optimizer import Optimizer
def test_code_replacement10() -> None:
get_code_output = '''random code'''
@ -2952,24 +3004,24 @@ def test_code_replacement10() -> None:
"""
)
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(code)
f.flush()
func = FunctionToOptimize(
function_name="get_code_optimization_context",
parents=[FunctionParent("Optimizer", "ClassDef")],
file_path=Path(f.name),
)
original_cwd = Path.cwd()
run_cwd = Path(__file__).parent.parent.resolve()
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
Path(f.name), [CodePosition(22, 28), CodePosition(28, 28)], func, Path(f.name).parent, "pytest"
)
os.chdir(original_cwd)
test_file_path = tmp_path / "test_class_method_instrumentation.py"
test_file_path.write_text(code, encoding="utf-8")
func = FunctionToOptimize(
function_name="get_code_optimization_context",
parents=[FunctionParent("Optimizer", "ClassDef")],
file_path=test_file_path,
)
original_cwd = Path.cwd()
run_cwd = Path(__file__).parent.parent.resolve()
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
test_file_path, [CodePosition(22, 28), CodePosition(28, 28)], func, test_file_path.parent, "pytest"
)
os.chdir(original_cwd)
assert success
assert new_test.replace('"', "'") == expected.replace('"', "'").format(
module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
module_path=test_file_path.stem, tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
)
@ -3034,7 +3086,7 @@ def test_sleepfunc_sequence_short(n, expected_total_sleep_time):
assert new_test is not None
assert new_test.replace('"', "'") == expected.format(
module_path="code_to_optimize.tests.pytest.test_time_correction_instrumentation_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
# Overwrite old test with new instrumented test
with test_path.open("w") as f:
@ -3101,30 +3153,29 @@ class TestPigLatin(unittest.TestCase):
output = accurate_sleepfunc(n)
"""
expected = (
"""import gc
# Build expected output with platform-aware imports
imports = """import gc
import os
import time
import unittest
import timeout_decorator
from parameterized import parameterized
from code_to_optimize.sleeptime import accurate_sleepfunc
"""
+ codeflash_wrap_perfonly_string
+ """
class TestPigLatin(unittest.TestCase):
if platform.system() != "Windows":
imports += "\nimport timeout_decorator"
imports += "\nfrom parameterized import parameterized\n\nfrom code_to_optimize.sleeptime import accurate_sleepfunc"
test_decorator = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
test_class = """class TestPigLatin(unittest.TestCase):
@parameterized.expand([(0.01, 0.01), (0.02, 0.02)])
@timeout_decorator.timeout(15)
def test_sleepfunc_sequence_short(self, n, expected_total_sleep_time):
"""
if test_decorator:
test_class += test_decorator + "\n"
test_class += """ def test_sleepfunc_sequence_short(self, n, expected_total_sleep_time):
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
output = codeflash_wrap(accurate_sleepfunc, '{module_path}', 'TestPigLatin', 'test_sleepfunc_sequence_short', 'accurate_sleepfunc', '0', codeflash_loop_index, n)
"""
)
expected = imports + "\n\n\n" + codeflash_wrap_perfonly_string + "\n" + test_class
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/sleeptime.py").resolve()
test_path = (
Path(__file__).parent.resolve()
@ -3153,7 +3204,7 @@ class TestPigLatin(unittest.TestCase):
assert new_test is not None
assert new_test.replace('"', "'") == expected.format(
module_path="code_to_optimize.tests.unittest.test_time_correction_instrumentation_unittest_temp",
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
).replace('"', "'")
# Overwrite old test with new instrumented test
with test_path.open("w") as f:

View file

@ -18,6 +18,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, TestFile, TestFiles, TestingMode, TestsInFile, TestType
from codeflash.optimization.optimizer import Optimizer
from codeflash.verification.equivalence import compare_test_results
import time
try:
import sqlalchemy
@ -156,6 +157,9 @@ def test_picklepatch_with_database_connection():
with pytest.raises(PicklePlaceholderAccessError):
reloaded["connection"].execute("SELECT 1")
cursor.close()
conn.close()
def test_picklepatch_with_generator():
"""Test that a data structure containing a generator is replaced by
@ -287,17 +291,26 @@ def test_run_and_parse_picklepatch() -> None:
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert "code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket" in function_to_results
# Close the connection to allow file cleanup on Windows
conn.close()
time.sleep(1)
test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"][0]
assert total_time > 0.0
assert function_time > 0.0
assert percent > 0.0
test_name, total_time, function_time, percent = \
function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"][0]
assert total_time > 0.0
assert function_time > 0.0
assert percent > 0.0
# Handle the case where function runs too fast to be measured
unused_socket_results = function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"]
if unused_socket_results:
test_name, total_time, function_time, percent = unused_socket_results[0]
assert total_time >= 0.0
# Function might be too fast, so we allow 0.0 function_time
assert function_time >= 0.0
assert percent >= 0.0
used_socket_results = function_to_results["code_to_optimize.bubble_sort_picklepatch_test_used_socket.bubble_sort_with_used_socket"]
# on windows , if the socket is not used we might not have resultssss
if used_socket_results:
test_name, total_time, function_time, percent = used_socket_results[0]
assert total_time >= 0.0
assert function_time >= 0.0
assert percent >= 0.0
bubble_sort_unused_socket_path = (project_root / "code_to_optimize"/ "bubble_sort_picklepatch_test_unused_socket.py").as_posix()
bubble_sort_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").as_posix()
@ -318,7 +331,9 @@ def test_run_and_parse_picklepatch() -> None:
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
conn.close()
conn.close()
time.sleep(1)
# Generate replay test
generate_replay_test(output_file, replay_tests_dir)
@ -510,4 +525,3 @@ def bubble_sort_with_used_socket(data_container):
shutil.rmtree(replay_tests_dir, ignore_errors=True)
fto_unused_socket_path.write_text(original_fto_unused_socket_code)
fto_used_socket_path.write_text(original_fto_used_socket_code)

View file

@ -38,6 +38,12 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
self.test_rc_path = "test_shell_rc"
self.api_key = "cf-1234567890abcdef"
os.environ["SHELL"] = "/bin/bash" # Set a default shell for testing
# Set up platform-specific export syntax
if os.name == "nt": # Windows
self.api_key_export = f'set CODEFLASH_API_KEY={self.api_key}'
else: # Unix-like systems
self.api_key_export = f'export CODEFLASH_API_KEY="{self.api_key}"'
def tearDown(self):
"""Cleanup the temporary shell configuration file after testing."""
@ -50,25 +56,37 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
with patch("codeflash.code_utils.shell_utils.get_shell_rc_path") as mock_get_shell_rc_path:
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch(
"builtins.open", mock_open(read_data=f'export CODEFLASH_API_KEY="{self.api_key}"\n')
"builtins.open", mock_open(read_data=f'{self.api_key_export}\n')
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")
with patch(
"builtins.open", mock_open(read_data=f'export CODEFLASH_API_KEY=\'{self.api_key}\'\n')
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")
with patch(
"builtins.open", mock_open(read_data=f'#export CODEFLASH_API_KEY=\'{self.api_key}\'\n')
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), None)
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")
with patch(
"builtins.open", mock_open(read_data=f'export CODEFLASH_API_KEY={self.api_key}\n')
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")
if os.name != "nt":
with patch(
"builtins.open", mock_open(read_data=f'export CODEFLASH_API_KEY=\'{self.api_key}\'\n')
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")
with patch(
"builtins.open", mock_open(read_data=f'#export CODEFLASH_API_KEY=\'{self.api_key}\'\n')
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), None)
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")
with patch(
"builtins.open", mock_open(read_data=f'export CODEFLASH_API_KEY={self.api_key}\n')
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")
elif os.name == "nt":
with patch(
"builtins.open", mock_open(read_data=f'REM set CODEFLASH_API_KEY={self.api_key}\n')
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), None)
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
@ -83,23 +101,41 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
def test_malformed_api_key_export(self, mock_get_shell_rc_path):
"""Test with a malformed API key export."""
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch("builtins.open", mock_open(read_data=f"export API_KEY={self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
with patch("builtins.open", mock_open(read_data=f"CODEFLASH_API_KEY={self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
with patch("builtins.open", mock_open(read_data=f"export CODEFLASH_API_KEY=sk-{self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
if os.name == "nt":
with patch("builtins.open", mock_open(read_data=f"set API_KEY={self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
with patch("builtins.open", mock_open(read_data=f"CODEFLASH_API_KEY={self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
with patch("builtins.open", mock_open(read_data=f"set CODEFLASH_API_KEY=sk-{self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
else:
with patch("builtins.open", mock_open(read_data=f"export API_KEY={self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
with patch("builtins.open", mock_open(read_data=f"CODEFLASH_API_KEY={self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
with patch("builtins.open", mock_open(read_data=f"export CODEFLASH_API_KEY=sk-{self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
def test_multiple_api_key_exports(self, mock_get_shell_rc_path):
"""Test with multiple API key exports."""
mock_get_shell_rc_path.return_value = self.test_rc_path
if os.name == "nt": # Windows
first_export = 'set CODEFLASH_API_KEY=cf-firstkey'
second_export = f'set CODEFLASH_API_KEY={self.api_key}'
else:
first_export = 'export CODEFLASH_API_KEY="cf-firstkey"'
second_export = f'export CODEFLASH_API_KEY="{self.api_key}"'
with patch(
"builtins.open",
mock_open(read_data=f'export CODEFLASH_API_KEY="cf-firstkey"\nexport CODEFLASH_API_KEY="{self.api_key}"\n'),
mock_open(read_data=f'{first_export}\n{second_export}\n'),
):
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
@ -109,7 +145,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch(
"builtins.open",
mock_open(read_data=f'# Setting API Key\nexport CODEFLASH_API_KEY="{self.api_key}"\n# Done\n'),
mock_open(read_data=f'# Setting API Key\n{self.api_key_export}\n# Done\n'),
):
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
@ -117,7 +153,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
def test_api_key_in_comment(self, mock_get_shell_rc_path):
"""Test with API key export in a comment."""
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch("builtins.open", mock_open(read_data=f'# export CODEFLASH_API_KEY="{self.api_key}"\n')):
with patch("builtins.open", mock_open(read_data=f'# {self.api_key_export}\n')):
self.assertIsNone(read_api_key_from_shell_config())
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")

View file

@ -34,12 +34,12 @@ class TestUnittestRunnerSorter(unittest.TestCase):
tests_project_rootdir=cur_dir_path.parent,
)
with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp:
with tempfile.TemporaryDirectory(dir=cur_dir_path) as temp_dir:
test_file_path = Path(temp_dir) / "test_xx.py"
test_files = TestFiles(
test_files=[TestFile(instrumented_behavior_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST)]
test_files=[TestFile(instrumented_behavior_file_path=test_file_path, test_type=TestType.EXISTING_UNIT_TEST)]
)
fp.write(code.encode("utf-8"))
fp.flush()
test_file_path.write_text(code, encoding="utf-8")
result_file, process, _, _ = run_behavioral_tests(
test_files,
test_framework=config.test_framework,
@ -78,12 +78,12 @@ def test_sort():
else:
test_env["PYTHONPATH"] += os.pathsep + str(config.project_root_path)
with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp:
with tempfile.TemporaryDirectory(dir=cur_dir_path) as temp_dir:
test_file_path = Path(temp_dir) / "test_xx.py"
test_files = TestFiles(
test_files=[TestFile(instrumented_behavior_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST)]
test_files=[TestFile(instrumented_behavior_file_path=test_file_path, test_type=TestType.EXISTING_UNIT_TEST)]
)
fp.write(code.encode("utf-8"))
fp.flush()
test_file_path.write_text(code, encoding="utf-8")
result_file, process, _, _ = run_behavioral_tests(
test_files,
test_framework=config.test_framework,
@ -125,12 +125,12 @@ def test_sort():
else:
test_env["PYTHONPATH"] += os.pathsep + str(config.project_root_path)
with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp:
with tempfile.TemporaryDirectory(dir=cur_dir_path) as temp_dir:
test_file_path = Path(temp_dir) / "test_xx.py"
test_files = TestFiles(
test_files=[TestFile(instrumented_behavior_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST)]
test_files=[TestFile(instrumented_behavior_file_path=test_file_path, test_type=TestType.EXISTING_UNIT_TEST)]
)
fp.write(code.encode("utf-8"))
fp.flush()
test_file_path.write_text(code, encoding="utf-8")
result_file, process, _, _ = run_behavioral_tests(
test_files,
test_framework=config.test_framework,

View file

@ -9,6 +9,7 @@ from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
from codeflash.benchmarking.replay_test import generate_replay_test
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
from codeflash.benchmarking.utils import validate_and_format_benchmark_table
import time
def test_trace_benchmarks() -> None:
@ -154,7 +155,7 @@ from codeflash.benchmarking.replay_test import get_next_arg_and_return
from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
functions = ['compute_and_sort', 'sorter']
trace_file_path = r"{output_file}"
trace_file_path = r"{output_file.as_posix()}"
def test_code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort_test_compute_and_sort():
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_compute_and_sort", function_name="compute_and_sort", file_path=r"{process_and_bubble_sort_path}", num_to_get=100):
@ -196,6 +197,8 @@ def test_trace_multithreaded_benchmark() -> None:
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
function_calls = cursor.fetchall()
conn.close()
# Assert the length of function calls
assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
@ -204,9 +207,9 @@ def test_trace_multithreaded_benchmark() -> None:
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_codeflash_trace.sorter"][0]
assert total_time > 0.0
assert function_time > 0.0
assert percent > 0.0
assert total_time >= 0.0
assert function_time >= 0.0
assert percent >= 0.0
bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
# Expected function calls
@ -279,8 +282,10 @@ def test_trace_benchmark_decorator() -> None:
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
# Close connection
cursor.close()
conn.close()
time.sleep(2)
finally:
# cleanup
output_file.unlink(missing_ok=True)
time.sleep(1)

View file

@ -3,7 +3,6 @@ import dataclasses
import pickle
import sqlite3
import sys
import tempfile
import threading
import time
from collections.abc import Generator
@ -59,29 +58,26 @@ class TraceConfig:
class TestTracer:
@pytest.fixture
def trace_config(self) -> Generator[Path, None, None]:
def trace_config(self, tmp_path: Path) -> Generator[TraceConfig, None, None]:
"""Create a temporary pyproject.toml config file."""
# Create a temporary directory structure
temp_dir = Path(tempfile.mkdtemp())
tests_dir = temp_dir / "tests"
tests_dir = tmp_path / "tests"
tests_dir.mkdir(exist_ok=True)
# Use the current working directory as module root so test files are included
current_dir = Path.cwd()
with tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False, dir=temp_dir) as f:
f.write(f"""
config_path = tmp_path / "pyproject.toml"
config_path.write_text(f"""
[tool.codeflash]
module-root = "{current_dir}"
tests-root = "{tests_dir}"
module-root = "{current_dir.as_posix()}"
tests-root = "{tests_dir.as_posix()}"
test-framework = "pytest"
ignore-paths = []
""")
config_path = Path(f.name)
with tempfile.NamedTemporaryFile(suffix=".trace", delete=False) as f:
trace_path = Path(f.name)
trace_path.unlink(missing_ok=True) # Remove the file, we just want the path
replay_test_pkl_path = temp_dir / "replay_test.pkl"
""", encoding="utf-8")
trace_path = tmp_path / "trace_file.trace"
replay_test_pkl_path = tmp_path / "replay_test.pkl"
config, found_config_path = parse_config_file(config_path)
trace_config = TraceConfig(
trace_file=trace_path,
@ -92,11 +88,6 @@ ignore-paths = []
)
yield trace_config
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)
trace_path.unlink(missing_ok=True)
replay_test_pkl_path.unlink(missing_ok=True)
@pytest.fixture(autouse=True)
def reset_tracer_state(self) -> Generator[None, None, None]:

View file

@ -133,7 +133,7 @@ def test_discover_tests_pytest_with_temp_dir_root():
assert len(discovered_tests) == 1
assert len(discovered_tests["dummy_code.dummy_function"]) == 2
dummy_tests = discovered_tests["dummy_code.dummy_function"]
assert all(test.tests_in_file.test_file == test_file_path for test in dummy_tests)
assert all(test.tests_in_file.test_file.resolve() == test_file_path.resolve() for test in dummy_tests)
assert {test.tests_in_file.test_function for test in dummy_tests} == {
"test_dummy_parametrized_function[True]",
"test_dummy_function",
@ -204,16 +204,13 @@ def test_discover_tests_pytest_with_multi_level_dirs():
# Check if the test files at all levels are discovered
assert len(discovered_tests) == 3
assert next(iter(discovered_tests["root_code.root_function"])).tests_in_file.test_file == root_test_file_path
assert (
next(iter(discovered_tests["level1.level1_code.level1_function"])).tests_in_file.test_file
== level1_test_file_path
)
discovered_root_test = next(iter(discovered_tests["root_code.root_function"])).tests_in_file.test_file
assert discovered_root_test.resolve() == root_test_file_path.resolve()
discovered_level1_test = next(iter(discovered_tests["level1.level1_code.level1_function"])).tests_in_file.test_file
assert discovered_level1_test.resolve() == level1_test_file_path.resolve()
assert (
next(iter(discovered_tests["level1.level2.level2_code.level2_function"])).tests_in_file.test_file
== level2_test_file_path
)
discovered_level2_test = next(iter(discovered_tests["level1.level2.level2_code.level2_function"])).tests_in_file.test_file
assert discovered_level2_test.resolve() == level2_test_file_path.resolve()
def test_discover_tests_pytest_dirs():
@ -295,20 +292,15 @@ def test_discover_tests_pytest_dirs():
# Check if the test files at all levels are discovered
assert len(discovered_tests) == 4
assert next(iter(discovered_tests["root_code.root_function"])).tests_in_file.test_file == root_test_file_path
assert (
next(iter(discovered_tests["level1.level1_code.level1_function"])).tests_in_file.test_file
== level1_test_file_path
)
assert (
next(iter(discovered_tests["level1.level2.level2_code.level2_function"])).tests_in_file.test_file
== level2_test_file_path
)
discovered_root_test = next(iter(discovered_tests["root_code.root_function"])).tests_in_file.test_file
assert discovered_root_test.resolve() == root_test_file_path.resolve()
discovered_level1_test = next(iter(discovered_tests["level1.level1_code.level1_function"])).tests_in_file.test_file
assert discovered_level1_test.resolve() == level1_test_file_path.resolve()
discovered_level2_test = next(iter(discovered_tests["level1.level2.level2_code.level2_function"])).tests_in_file.test_file
assert discovered_level2_test.resolve() == level2_test_file_path.resolve()
assert (
next(iter(discovered_tests["level1.level3.level3_code.level3_function"])).tests_in_file.test_file
== level3_test_file_path
)
discovered_level3_test = next(iter(discovered_tests["level1.level3.level3_code.level3_function"])).tests_in_file.test_file
assert discovered_level3_test.resolve() == level3_test_file_path.resolve()
def test_discover_tests_pytest_with_class():
@ -342,10 +334,8 @@ def test_discover_tests_pytest_with_class():
# Check if the test class and method are discovered
assert len(discovered_tests) == 1
assert (
next(iter(discovered_tests["some_class_code.SomeClass.some_method"])).tests_in_file.test_file
== test_file_path
)
discovered_class_test = next(iter(discovered_tests["some_class_code.SomeClass.some_method"])).tests_in_file.test_file
assert discovered_class_test.resolve() == test_file_path.resolve()
def test_discover_tests_pytest_with_double_nested_directories():
@ -383,12 +373,10 @@ def test_discover_tests_pytest_with_double_nested_directories():
# Check if the test class and method are discovered
assert len(discovered_tests) == 1
assert (
next(
iter(discovered_tests["nested.more_nested.nested_class_code.NestedClass.nested_method"])
).tests_in_file.test_file
== test_file_path
)
discovered_nested_test = next(
iter(discovered_tests["nested.more_nested.nested_class_code.NestedClass.nested_method"])
).tests_in_file.test_file
assert discovered_nested_test.resolve() == test_file_path.resolve()
def test_discover_tests_with_code_in_dir_and_test_in_subdir():
@ -433,7 +421,8 @@ def test_discover_tests_with_code_in_dir_and_test_in_subdir():
# Check if the test file is discovered and associated with the code file
assert len(discovered_tests) == 1
assert next(iter(discovered_tests["code.some_code.some_function"])).tests_in_file.test_file == test_file_path
discovered_test_file = next(iter(discovered_tests["code.some_code.some_function"])).tests_in_file.test_file
assert discovered_test_file.resolve() == test_file_path.resolve()
def test_discover_tests_pytest_with_nested_class():
@ -469,10 +458,8 @@ def test_discover_tests_pytest_with_nested_class():
# Check if the test for the nested class method is discovered
assert len(discovered_tests) == 1
assert (
next(iter(discovered_tests["nested_class_code.OuterClass.InnerClass.inner_method"])).tests_in_file.test_file
== test_file_path
)
discovered_inner_test = next(iter(discovered_tests["nested_class_code.OuterClass.InnerClass.inner_method"])).tests_in_file.test_file
assert discovered_inner_test.resolve() == test_file_path.resolve()
def test_discover_tests_pytest_separate_moduledir():
@ -509,7 +496,8 @@ def test_discover_tests_pytest_separate_moduledir():
# Check if the test for the nested class method is discovered
assert len(discovered_tests) == 1
assert next(iter(discovered_tests["mypackage.code.find_common_tags"])).tests_in_file.test_file == test_file_path
discovered_test_file = next(iter(discovered_tests["mypackage.code.find_common_tags"])).tests_in_file.test_file
assert discovered_test_file.resolve() == test_file_path.resolve()
def test_unittest_discovery_with_pytest():
@ -554,7 +542,7 @@ class TestCalculator(unittest.TestCase):
assert "calculator.Calculator.add" in discovered_tests
assert len(discovered_tests["calculator.Calculator.add"]) == 1
calculator_test = next(iter(discovered_tests["calculator.Calculator.add"]))
assert calculator_test.tests_in_file.test_file == test_file_path
assert calculator_test.tests_in_file.test_file.resolve() == test_file_path.resolve()
assert calculator_test.tests_in_file.test_function == "test_add"
@ -622,7 +610,7 @@ class TestCalculator(ExtendedTestCase):
assert "calculator.Calculator.add" in discovered_tests
assert len(discovered_tests["calculator.Calculator.add"]) == 1
calculator_test = next(iter(discovered_tests["calculator.Calculator.add"]))
assert calculator_test.tests_in_file.test_file == test_file_path
assert calculator_test.tests_in_file.test_file.resolve() == test_file_path.resolve()
assert calculator_test.tests_in_file.test_function == "test_add"
@ -720,7 +708,7 @@ class TestCalculator(unittest.TestCase):
assert "calculator.Calculator.add" in discovered_tests
assert len(discovered_tests["calculator.Calculator.add"]) == 1
calculator_test = next(iter(discovered_tests["calculator.Calculator.add"]))
assert calculator_test.tests_in_file.test_file == test_file_path
assert calculator_test.tests_in_file.test_file.resolve() == test_file_path.resolve()
assert calculator_test.tests_in_file.test_function == "test_add_with_parameters"

View file

@ -9,6 +9,8 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
from codeflash.context.unused_definition_remover import revert_unused_helper_functions
@pytest.fixture
@ -225,6 +227,15 @@ def helper_function_2(x):
def test_no_unused_helpers_no_revert(temp_project):
"""Test that when all helpers are still used, nothing is reverted."""
temp_dir, main_file, test_cfg = temp_project
# Store original content to verify nothing changes
original_content = main_file.read_text()
revert_unused_helper_functions(temp_dir, [], {})
# Verify the file content remains unchanged
assert main_file.read_text() == original_content, "File should remain unchanged when no helpers to revert"
# Optimized version that still calls both helpers
optimized_code = """
@ -308,17 +319,23 @@ def helper_function_1(x):
def helper_function_2(x):
\"\"\"Second helper function.\"\"\"
return x * 3
def helper_function_1(y): # Duplicate name to test line 575
\"\"\"Overloaded helper function.\"\"\"
return y + 10
""")
# Optimized version that only calls one helper
# Optimized version that only calls one helper with aliased import
optimized_code = """
```python:main.py
from helpers import helper_function_1
from helpers import helper_function_1 as h1
import helpers as h_module
def entrypoint_function(n):
\"\"\"Optimized function that only calls one helper.\"\"\"
result1 = helper_function_1(n)
return result1 + n * 3 # Inlined helper_function_2
\"\"\"Optimized function that only calls one helper with aliasing.\"\"\"
result1 = h1(n) # Using aliased import
# Inlined helper_function_2 functionality: n * 3
return result1 + n * 3 # Fully inlined helper_function_2
```
"""
@ -1460,3 +1477,595 @@ class MathUtils:
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)
def test_async_entrypoint_with_async_helpers():
"""Test that unused async helper functions are correctly detected when entrypoint is async."""
temp_dir = Path(tempfile.mkdtemp())
try:
# Main file with async entrypoint and async helpers
main_file = temp_dir / "main.py"
main_file.write_text("""
async def async_helper_1(x):
\"\"\"First async helper function.\"\"\"
return x * 2
async def async_helper_2(x):
\"\"\"Second async helper function.\"\"\"
return x * 3
async def async_entrypoint(n):
\"\"\"Async entrypoint function that calls async helpers.\"\"\"
result1 = await async_helper_1(n)
result2 = await async_helper_2(n)
return result1 + result2
""")
# Optimized version that only calls one async helper
optimized_code = """
```python:main.py
async def async_helper_1(x):
\"\"\"First async helper function.\"\"\"
return x * 2
async def async_helper_2(x):
\"\"\"Second async helper function - should be unused.\"\"\"
return x * 3
async def async_entrypoint(n):
\"\"\"Optimized async entrypoint that only calls one helper.\"\"\"
result1 = await async_helper_1(n)
return result1 + n * 3 # Inlined async_helper_2
```
"""
# Create test config
test_cfg = TestConfig(
tests_root=temp_dir / "tests",
tests_project_rootdir=temp_dir,
project_root_path=temp_dir,
test_framework="pytest",
pytest_cmd="pytest",
)
# Create FunctionToOptimize instance for async function
function_to_optimize = FunctionToOptimize(
file_path=main_file,
function_name="async_entrypoint",
parents=[],
is_async=True
)
# Create function optimizer
optimizer = FunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
)
# Get original code context
ctx_result = optimizer.get_code_optimization_context()
assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"
code_context = ctx_result.unwrap()
# Test unused helper detection
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))
# Should detect async_helper_2 as unused
unused_names = {uh.qualified_name for uh in unused_helpers}
expected_unused = {"async_helper_2"}
assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"
finally:
# Cleanup
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)
def test_sync_entrypoint_with_async_helpers():
"""Test that unused async helper functions are detected when entrypoint is sync."""
temp_dir = Path(tempfile.mkdtemp())
try:
# Main file with sync entrypoint and async helpers
main_file = temp_dir / "main.py"
main_file.write_text("""
import asyncio
async def async_helper_1(x):
\"\"\"First async helper function.\"\"\"
return x * 2
async def async_helper_2(x):
\"\"\"Second async helper function.\"\"\"
return x * 3
def sync_entrypoint(n):
\"\"\"Sync entrypoint function that calls async helpers.\"\"\"
result1 = asyncio.run(async_helper_1(n))
result2 = asyncio.run(async_helper_2(n))
return result1 + result2
""")
# Optimized version that only calls one async helper
optimized_code = """
```python:main.py
import asyncio
async def async_helper_1(x):
\"\"\"First async helper function.\"\"\"
return x * 2
async def async_helper_2(x):
\"\"\"Second async helper function - should be unused.\"\"\"
return x * 3
def sync_entrypoint(n):
\"\"\"Optimized sync entrypoint that only calls one async helper.\"\"\"
result1 = asyncio.run(async_helper_1(n))
return result1 + n * 3 # Inlined async_helper_2
```
"""
# Create test config
test_cfg = TestConfig(
tests_root=temp_dir / "tests",
tests_project_rootdir=temp_dir,
project_root_path=temp_dir,
test_framework="pytest",
pytest_cmd="pytest",
)
# Create FunctionToOptimize instance for sync function
function_to_optimize = FunctionToOptimize(
file_path=main_file,
function_name="sync_entrypoint",
parents=[]
)
# Create function optimizer
optimizer = FunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
)
# Get original code context
ctx_result = optimizer.get_code_optimization_context()
assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"
code_context = ctx_result.unwrap()
# Test unused helper detection
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))
# Should detect async_helper_2 as unused
unused_names = {uh.qualified_name for uh in unused_helpers}
expected_unused = {"async_helper_2"}
assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"
finally:
# Cleanup
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)
def test_mixed_sync_and_async_helpers():
"""Test detection when both sync and async helpers are mixed."""
temp_dir = Path(tempfile.mkdtemp())
try:
# Main file with mixed sync and async helpers
main_file = temp_dir / "main.py"
main_file.write_text("""
import asyncio
def sync_helper_1(x):
\"\"\"Sync helper function.\"\"\"
return x * 2
async def async_helper_1(x):
\"\"\"Async helper function.\"\"\"
return x * 3
def sync_helper_2(x):
\"\"\"Another sync helper function.\"\"\"
return x * 4
async def async_helper_2(x):
\"\"\"Another async helper function.\"\"\"
return x * 5
async def mixed_entrypoint(n):
\"\"\"Async entrypoint function that calls both sync and async helpers.\"\"\"
sync_result = sync_helper_1(n)
async_result = await async_helper_1(n)
sync_result2 = sync_helper_2(n)
async_result2 = await async_helper_2(n)
return sync_result + async_result + sync_result2 + async_result2
""")
# Optimized version that only calls some helpers
optimized_code = """
```python:main.py
import asyncio
def sync_helper_1(x):
\"\"\"Sync helper function.\"\"\"
return x * 2
async def async_helper_1(x):
\"\"\"Async helper function.\"\"\"
return x * 3
def sync_helper_2(x):
\"\"\"Another sync helper function - should be unused.\"\"\"
return x * 4
async def async_helper_2(x):
\"\"\"Another async helper function - should be unused.\"\"\"
return x * 5
async def mixed_entrypoint(n):
\"\"\"Optimized async entrypoint that only calls some helpers.\"\"\"
sync_result = sync_helper_1(n)
async_result = await async_helper_1(n)
return sync_result + async_result + n * 4 + n * 5 # Inlined both helper_2 functions
```
"""
# Create test config
test_cfg = TestConfig(
tests_root=temp_dir / "tests",
tests_project_rootdir=temp_dir,
project_root_path=temp_dir,
test_framework="pytest",
pytest_cmd="pytest",
)
# Create FunctionToOptimize instance for async function
function_to_optimize = FunctionToOptimize(
file_path=main_file,
function_name="mixed_entrypoint",
parents=[],
is_async=True
)
# Create function optimizer
optimizer = FunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
)
# Get original code context
ctx_result = optimizer.get_code_optimization_context()
assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"
code_context = ctx_result.unwrap()
# Test unused helper detection
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))
# Should detect both sync_helper_2 and async_helper_2 as unused
unused_names = {uh.qualified_name for uh in unused_helpers}
expected_unused = {"sync_helper_2", "async_helper_2"}
assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"
finally:
# Cleanup
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)
def test_async_class_methods():
"""Test unused async method detection in classes."""
temp_dir = Path(tempfile.mkdtemp())
try:
# Main file with class containing async methods
main_file = temp_dir / "main.py"
main_file.write_text("""
class AsyncProcessor:
async def entrypoint_method(self, n):
\"\"\"Async main method that calls async helper methods.\"\"\"
result1 = await self.async_helper_method_1(n)
result2 = await self.async_helper_method_2(n)
return result1 + result2
async def async_helper_method_1(self, x):
\"\"\"First async helper method.\"\"\"
return x * 2
async def async_helper_method_2(self, x):
\"\"\"Second async helper method.\"\"\"
return x * 3
def sync_helper_method(self, x):
\"\"\"Sync helper method.\"\"\"
return x * 4
""")
# Optimized version that only calls one async helper
optimized_code = """
```python:main.py
class AsyncProcessor:
async def entrypoint_method(self, n):
\"\"\"Optimized async method that only calls one helper.\"\"\"
result1 = await self.async_helper_method_1(n)
return result1 + n * 3 # Inlined async_helper_method_2
async def async_helper_method_1(self, x):
\"\"\"First async helper method.\"\"\"
return x * 2
async def async_helper_method_2(self, x):
\"\"\"Second async helper method - should be unused.\"\"\"
return x * 3
def sync_helper_method(self, x):
\"\"\"Sync helper method - should be unused.\"\"\"
return x * 4
```
"""
# Create test config
test_cfg = TestConfig(
tests_root=temp_dir / "tests",
tests_project_rootdir=temp_dir,
project_root_path=temp_dir,
test_framework="pytest",
pytest_cmd="pytest",
)
# Create FunctionToOptimize instance for async class method
from codeflash.models.models import FunctionParent
function_to_optimize = FunctionToOptimize(
file_path=main_file,
function_name="entrypoint_method",
parents=[FunctionParent(name="AsyncProcessor", type="ClassDef")],
is_async=True
)
# Create function optimizer
optimizer = FunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
)
# Get original code context
ctx_result = optimizer.get_code_optimization_context()
assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"
code_context = ctx_result.unwrap()
# Test unused helper detection
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))
# Should detect async_helper_method_2 as unused (sync_helper_method may not be discovered as helper)
unused_names = {uh.qualified_name for uh in unused_helpers}
expected_unused = {"AsyncProcessor.async_helper_method_2"}
assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"
finally:
# Cleanup
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)
def test_async_helper_revert_functionality():
"""Test that unused async helper functions are correctly reverted to original definitions."""
temp_dir = Path(tempfile.mkdtemp())
try:
# Main file with async functions
main_file = temp_dir / "main.py"
main_file.write_text("""
async def async_helper_1(x):
\"\"\"First async helper function.\"\"\"
return x * 2
async def async_helper_2(x):
\"\"\"Second async helper function.\"\"\"
return x * 3
async def async_entrypoint(n):
\"\"\"Async entrypoint function that calls async helpers.\"\"\"
result1 = await async_helper_1(n)
result2 = await async_helper_2(n)
return result1 + result2
""")
# Optimized version that only calls one helper and modifies the unused one
optimized_code = """
```python:main.py
async def async_helper_1(x):
\"\"\"First async helper function.\"\"\"
return x * 2
async def async_helper_2(x):
\"\"\"Modified async helper function - should be reverted.\"\"\"
return x * 10 # This change should be reverted
async def async_entrypoint(n):
\"\"\"Optimized async entrypoint that only calls one helper.\"\"\"
result1 = await async_helper_1(n)
return result1 + n * 3 # Inlined async_helper_2
```
"""
# Create test config
test_cfg = TestConfig(
tests_root=temp_dir / "tests",
tests_project_rootdir=temp_dir,
project_root_path=temp_dir,
test_framework="pytest",
pytest_cmd="pytest",
)
# Create FunctionToOptimize instance for async function
function_to_optimize = FunctionToOptimize(
file_path=main_file,
function_name="async_entrypoint",
parents=[],
is_async=True
)
# Create function optimizer
optimizer = FunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
)
# Get original code context
ctx_result = optimizer.get_code_optimization_context()
assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"
code_context = ctx_result.unwrap()
# Store original helper code
original_helper_code = {main_file: main_file.read_text()}
# Apply optimization and test reversion
optimizer.replace_function_and_helpers_with_optimized_code(
code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code
)
# Check final file content
final_content = main_file.read_text()
# The entrypoint should be optimized
assert "result1 + n * 3" in final_content, "Async entrypoint function should be optimized"
# async_helper_2 should be reverted to original (return x * 3, not x * 10)
assert "return x * 3" in final_content, "async_helper_2 should be reverted to original"
assert "return x * 10" not in final_content, "async_helper_2 should not contain the modified version"
# async_helper_1 should remain (it's still called)
assert "async def async_helper_1(x):" in final_content, "async_helper_1 should still exist"
finally:
# Cleanup
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)
def test_async_generators_and_coroutines():
"""Test detection with async generators and coroutines."""
temp_dir = Path(tempfile.mkdtemp())
try:
# Main file with async generators and coroutines
main_file = temp_dir / "main.py"
main_file.write_text("""
import asyncio
async def async_generator_helper(n):
\"\"\"Async generator helper.\"\"\"
for i in range(n):
yield i * 2
async def coroutine_helper(x):
\"\"\"Coroutine helper.\"\"\"
await asyncio.sleep(0.1)
return x * 3
async def another_coroutine_helper(x):
\"\"\"Another coroutine helper.\"\"\"
await asyncio.sleep(0.1)
return x * 4
async def async_entrypoint_with_generators(n):
\"\"\"Async entrypoint function that uses generators and coroutines.\"\"\"
results = []
async for value in async_generator_helper(n):
results.append(value)
final_result = await coroutine_helper(sum(results))
another_result = await another_coroutine_helper(n)
return final_result + another_result
""")
# Optimized version that doesn't use one of the coroutines
optimized_code = """
```python:main.py
import asyncio
async def async_generator_helper(n):
\"\"\"Async generator helper.\"\"\"
for i in range(n):
yield i * 2
async def coroutine_helper(x):
\"\"\"Coroutine helper.\"\"\"
await asyncio.sleep(0.1)
return x * 3
async def another_coroutine_helper(x):
\"\"\"Another coroutine helper - should be unused.\"\"\"
await asyncio.sleep(0.1)
return x * 4
async def async_entrypoint_with_generators(n):
\"\"\"Optimized async entrypoint that inlines one coroutine.\"\"\"
results = []
async for value in async_generator_helper(n):
results.append(value)
final_result = await coroutine_helper(sum(results))
return final_result + n * 4 # Inlined another_coroutine_helper
```
"""
# Create test config
test_cfg = TestConfig(
tests_root=temp_dir / "tests",
tests_project_rootdir=temp_dir,
project_root_path=temp_dir,
test_framework="pytest",
pytest_cmd="pytest",
)
# Create FunctionToOptimize instance for async function
function_to_optimize = FunctionToOptimize(
file_path=main_file,
function_name="async_entrypoint_with_generators",
parents=[],
is_async=True
)
# Create function optimizer
optimizer = FunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
)
# Get original code context
ctx_result = optimizer.get_code_optimization_context()
assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"
code_context = ctx_result.unwrap()
# Test unused helper detection
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))
# Should detect another_coroutine_helper as unused
unused_names = {uh.qualified_name for uh in unused_helpers}
expected_unused = {"another_coroutine_helper"}
assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"
finally:
# Cleanup
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)

65
tests/test_worktree.py Normal file
View file

@ -0,0 +1,65 @@
from argparse import Namespace
from pathlib import Path
import pytest
from codeflash.cli_cmds.cli import process_pyproject_config
from codeflash.optimization.optimizer import Optimizer
def test_mirror_paths_for_worktree_mode(monkeypatch: pytest.MonkeyPatch):
repo_root = Path(__file__).resolve().parent.parent
project_root = repo_root / "code_to_optimize" / "code_directories" / "nested_module_root"
monkeypatch.setattr("codeflash.optimization.optimizer.git_root_dir", lambda: project_root)
args = Namespace()
args.benchmark = False
args.benchmarks_root = None
args.config_file = project_root / "pyproject.toml"
args.file = project_root / "src" / "app" / "main.py"
args.worktree = True
new_args = process_pyproject_config(args)
optimizer = Optimizer(new_args)
worktree_dir = repo_root / "worktree"
optimizer.mirror_paths_for_worktree_mode(worktree_dir)
assert optimizer.args.project_root == worktree_dir / "src"
assert optimizer.args.test_project_root == worktree_dir / "src"
assert optimizer.args.module_root == worktree_dir / "src" / "app"
assert optimizer.args.tests_root == worktree_dir / "src" / "tests"
assert optimizer.args.file == worktree_dir / "src" / "app" / "main.py"
assert optimizer.test_cfg.tests_root == worktree_dir / "src" / "tests"
assert optimizer.test_cfg.project_root_path == worktree_dir / "src" # same as project_root
assert optimizer.test_cfg.tests_project_rootdir == worktree_dir / "src" # same as test_project_root
# test on our repo
monkeypatch.setattr("codeflash.optimization.optimizer.git_root_dir", lambda: repo_root)
args = Namespace()
args.benchmark = False
args.benchmarks_root = None
args.config_file = repo_root / "pyproject.toml"
args.file = repo_root / "codeflash/optimization/optimizer.py"
args.worktree = True
new_args = process_pyproject_config(args)
optimizer = Optimizer(new_args)
worktree_dir = repo_root / "worktree"
optimizer.mirror_paths_for_worktree_mode(worktree_dir)
assert optimizer.args.project_root == worktree_dir
assert optimizer.args.test_project_root == worktree_dir
assert optimizer.args.module_root == worktree_dir / "codeflash"
assert optimizer.args.tests_root == worktree_dir / "tests"
assert optimizer.args.file == worktree_dir / "codeflash/optimization/optimizer.py"
assert optimizer.test_cfg.tests_root == worktree_dir / "tests"
assert optimizer.test_cfg.project_root_path == worktree_dir # same as project_root
assert optimizer.test_cfg.tests_project_rootdir == worktree_dir # same as test_project_root

1032
uv.lock

File diff suppressed because it is too large Load diff