Merge branch 'main' of github.com:codeflash-ai/codeflash into testgen/multi-files
commit 1c7c2b88ba
86 changed files with 7953 additions and 2191 deletions
.github/workflows/e2e-async.yaml (vendored, new file, +69)
@@ -0,0 +1,69 @@
+name: E2E - Async
+
+on:
+  pull_request:
+    paths:
+      - '**' # Trigger for all paths
+
+  workflow_dispatch:
+
+jobs:
+  async-optimization:
+    # Dynamically determine if environment is needed only when workflow files change and contributor is external
+    environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }}
+
+    runs-on: ubuntu-latest
+    env:
+      CODEFLASH_AIS_SERVER: prod
+      POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }}
+      CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
+      COLUMNS: 110
+      MAX_RETRIES: 3
+      RETRY_DELAY: 5
+      EXPECTED_IMPROVEMENT_PCT: 10
+      CODEFLASH_END_TO_END: 1
+    steps:
+      - name: 🛎️ Checkout
+        uses: actions/checkout@v4
+        with:
+          ref: ${{ github.event.pull_request.head.ref }}
+          repository: ${{ github.event.pull_request.head.repo.full_name }}
+          fetch-depth: 0
+          token: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Validate PR
+        run: |
+          # Check for any workflow changes
+          if git diff --name-only "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" | grep -q "^.github/workflows/"; then
+            echo "⚠️ Workflow changes detected."
+
+            # Get the PR author
+            AUTHOR="${{ github.event.pull_request.user.login }}"
+            echo "PR Author: $AUTHOR"
+
+            # Allowlist check
+            if [[ "$AUTHOR" == "misrasaurabh1" || "$AUTHOR" == "KRRT7" ]]; then
+              echo "✅ Authorized user ($AUTHOR). Proceeding."
+            elif [[ "${{ github.event.pull_request.state }}" == "open" ]]; then
+              echo "✅ PR triggered by 'pull_request_target' and is open. Assuming protection rules are in place. Proceeding."
+            else
+              echo "⛔ Unauthorized user ($AUTHOR) attempting to modify workflows. Exiting."
+              exit 1
+            fi
+          else
+            echo "✅ No workflow file changes detected. Proceeding."
+          fi
+
+      - name: Set up Python 3.11 for CLI
+        uses: astral-sh/setup-uv@v5
+        with:
+          python-version: 3.11.6
+
+      - name: Install dependencies (CLI)
+        run: |
+          uv sync
+
+      - name: Run Codeflash to optimize async code
+        id: optimize_async_code
+        run: |
+          uv run python tests/scripts/end_to_end_test_async.py
@@ -20,7 +20,7 @@ jobs:
       COLUMNS: 110
       MAX_RETRIES: 3
       RETRY_DELAY: 5
-      EXPECTED_IMPROVEMENT_PCT: 300
+      EXPECTED_IMPROVEMENT_PCT: 70
       CODEFLASH_END_TO_END: 1
     steps:
       - name: 🛎️ Checkout
@@ -20,7 +20,7 @@ jobs:
       COLUMNS: 110
       MAX_RETRIES: 3
       RETRY_DELAY: 5
-      EXPECTED_IMPROVEMENT_PCT: 300
+      EXPECTED_IMPROVEMENT_PCT: 40
      CODEFLASH_END_TO_END: 1
     steps:
       - name: 🛎️ Checkout
.github/workflows/unit-tests.yaml (vendored, 1 line changed)
@@ -24,7 +24,6 @@ jobs:
         uses: astral-sh/setup-uv@v5
         with:
           python-version: ${{ matrix.python-version }}
-          version: "0.5.30"

       - name: install dependencies
         run: uv sync
.github/workflows/windows-unit-tests.yml (vendored, new file, +30)
@@ -0,0 +1,30 @@
+name: windows-unit-tests
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+  workflow_dispatch:
+
+jobs:
+  windows-unit-tests:
+    continue-on-error: true
+    runs-on: windows-latest
+    env:
+      PYTHONIOENCODING: utf-8
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+          token: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v5
+        with:
+          python-version: "3.13"
+
+      - name: install dependencies
+        run: uv sync
+
+      - name: Unit tests
+        run: uv run pytest tests/
code_to_optimize/async_bubble_sort.py (new file, +43)
@@ -0,0 +1,43 @@
+import asyncio
+from typing import List, Union
+
+
+async def async_sorter(lst: List[Union[int, float]]) -> List[Union[int, float]]:
+    """
+    Async bubble sort implementation for testing.
+    """
+    print("codeflash stdout: Async sorting list")
+
+    await asyncio.sleep(0.01)
+
+    n = len(lst)
+    for i in range(n):
+        for j in range(0, n - i - 1):
+            if lst[j] > lst[j + 1]:
+                lst[j], lst[j + 1] = lst[j + 1], lst[j]
+
+    result = lst.copy()
+    print(f"result: {result}")
+    return result
+
+
+class AsyncBubbleSorter:
+    """Class with async sorting method for testing."""
+
+    async def sorter(self, lst: List[Union[int, float]]) -> List[Union[int, float]]:
+        """
+        Async bubble sort implementation within a class.
+        """
+        print("codeflash stdout: AsyncBubbleSorter.sorter() called")
+
+        # Add some async delay
+        await asyncio.sleep(0.005)
+
+        n = len(lst)
+        for i in range(n):
+            for j in range(0, n - i - 1):
+                if lst[j] > lst[j + 1]:
+                    lst[j], lst[j + 1] = lst[j + 1], lst[j]
+
+        result = lst.copy()
+        return result
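A quick way to exercise this fixture (a sketch, not part of the diff; the import path mirrors the file location above):

    import asyncio

    from code_to_optimize.async_bubble_sort import AsyncBubbleSorter, async_sorter


    async def main() -> None:
        print(await async_sorter([5, 2, 9, 1]))              # -> [1, 2, 5, 9]
        print(await AsyncBubbleSorter().sorter([3.5, 1.2]))  # -> [1.2, 3.5]


    asyncio.run(main())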
code_to_optimize/code_directories/async_e2e/main.py (new file, +16)
@@ -0,0 +1,16 @@
+import time
+import asyncio
+
+
+async def retry_with_backoff(func, max_retries=3):
+    if max_retries < 1:
+        raise ValueError("max_retries must be at least 1")
+    last_exception = None
+    for attempt in range(max_retries):
+        try:
+            return await func()
+        except Exception as e:
+            last_exception = e
+            if attempt < max_retries - 1:
+                time.sleep(0.0001 * attempt)
+    raise last_exception
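A minimal caller sketch for the helper above (the flaky coroutine is hypothetical):

    import asyncio


    async def flaky():
        # pretend this sometimes raises before eventually succeeding
        return "ok"


    print(asyncio.run(retry_with_backoff(flaky, max_retries=3)))  # -> "ok"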
@@ -0,0 +1,6 @@
+[tool.codeflash]
+disable-telemetry = true
+formatter-cmds = ["ruff check --exit-zero --fix $file", "ruff format $file"]
+module-root = "."
+test-framework = "pytest"
+tests-root = "tests"
@@ -0,0 +1,8 @@
+[tool.codeflash]
+# All paths are relative to this pyproject.toml's directory.
+module-root = "src/app"
+tests-root = "src/tests"
+test-framework = "pytest"
+ignore-paths = []
+disable-telemetry = true
+formatter-cmds = ["disabled"]
@@ -0,0 +1,10 @@
+def sorter(arr):
+    print("codeflash stdout: Sorting list")
+    for i in range(len(arr)):
+        for j in range(len(arr) - 1):
+            if arr[j] > arr[j + 1]:
+                temp = arr[j]
+                arr[j] = arr[j + 1]
+                arr[j + 1] = temp
+    print(f"result: {arr}")
+    return arr
@@ -102,6 +102,8 @@ class AiServiceClient:
         trace_id: str,
         num_candidates: int = 10,
         experiment_metadata: ExperimentMetadata | None = None,
+        *,
+        is_async: bool = False,
     ) -> list[OptimizedCandidate]:
         """Optimize the given python code for performance by making a request to the Django endpoint.

@@ -133,6 +135,7 @@ class AiServiceClient:
             "repo_owner": git_repo_owner,
             "repo_name": git_repo_name,
             "n_candidates": N_CANDIDATES_EFFECTIVE,
+            "is_async": is_async,
         }

         logger.info("!lsp|Generating optimized candidates…")

@@ -299,6 +302,9 @@ class AiServiceClient:
         annotated_tests: str,
         optimization_id: str,
         original_explanation: str,
+        original_throughput: str | None = None,
+        optimized_throughput: str | None = None,
+        throughput_improvement: str | None = None,
     ) -> str:
         """Optimize the given python code for performance by making a request to the Django endpoint.

@@ -315,6 +321,9 @@ class AiServiceClient:
         - annotated_tests: str - test functions annotated with runtime
         - optimization_id: str - unique id of opt candidate
         - original_explanation: str - original_explanation generated for the opt candidate
+        - original_throughput: str | None - throughput for the baseline code (operations per second)
+        - optimized_throughput: str | None - throughput for the optimized code (operations per second)
+        - throughput_improvement: str | None - throughput improvement percentage

         Returns
         -------

@@ -334,6 +343,9 @@ class AiServiceClient:
             "optimization_id": optimization_id,
             "original_explanation": original_explanation,
             "dependency_code": dependency_code,
+            "original_throughput": original_throughput,
+            "optimized_throughput": optimized_throughput,
+            "throughput_improvement": throughput_improvement,
         }
         logger.info("loading|Generating explanation")
         console.rule()

@@ -488,6 +500,7 @@ class AiServiceClient:
             "test_index": test_index,
             "python_version": platform.python_version(),
             "codeflash_version": codeflash_version,
+            "is_async": function_to_optimize.is_async,
         }
         try:
             response = self.make_ai_service_request("/testgen", payload=payload, timeout=600)
@@ -26,7 +26,7 @@ if TYPE_CHECKING:

 from packaging import version

-if os.environ.get("CODEFLASH_CFAPI_SERVER", default="prod").lower() == "local":
+if os.environ.get("CODEFLASH_CFAPI_SERVER", "prod").lower() == "local":
     CFAPI_BASE_URL = "http://localhost:3001"
     logger.info(f"Using local CF API at {CFAPI_BASE_URL}.")
     console.rule()
@@ -4,6 +4,7 @@ import pickle
 import sqlite3
 import threading
 import time
+from pathlib import Path
 from typing import Any, Callable

 from codeflash.picklepatch.pickle_patcher import PicklePatcher

@@ -143,12 +144,13 @@ class CodeflashTrace:
                 print("Pickle limit reached")
                 self._thread_local.active_functions.remove(func_id)
                 overhead_time = time.thread_time_ns() - end_time
+                normalized_file_path = Path(func.__code__.co_filename).as_posix()
                 self.function_calls_data.append(
                     (
                         func.__name__,
                         class_name,
                         func.__module__,
-                        func.__code__.co_filename,
+                        normalized_file_path,
                         benchmark_function_name,
                         benchmark_module_path,
                         benchmark_line_number,

@@ -169,12 +171,13 @@ class CodeflashTrace:
                 # Add to the list of function calls without pickled args. Used for timing info only
                 self._thread_local.active_functions.remove(func_id)
                 overhead_time = time.thread_time_ns() - end_time
+                normalized_file_path = Path(func.__code__.co_filename).as_posix()
                 self.function_calls_data.append(
                     (
                         func.__name__,
                         class_name,
                         func.__module__,
-                        func.__code__.co_filename,
+                        normalized_file_path,
                         benchmark_function_name,
                         benchmark_module_path,
                         benchmark_line_number,

@@ -192,12 +195,13 @@ class CodeflashTrace:
                 # Add to the list of function calls with pickled args, to be used for replay tests
                 self._thread_local.active_functions.remove(func_id)
                 overhead_time = time.thread_time_ns() - end_time
+                normalized_file_path = Path(func.__code__.co_filename).as_posix()
                 self.function_calls_data.append(
                     (
                         func.__name__,
                         class_name,
                         func.__module__,
-                        func.__code__.co_filename,
+                        normalized_file_path,
                         benchmark_function_name,
                         benchmark_module_path,
                         benchmark_line_number,
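The repeated normalization above exists so the stored file path matches the posix-style paths used when the trace database is queried later, regardless of OS. A small illustration (Windows-style input assumed):

    from pathlib import Path

    # On Windows, func.__code__.co_filename looks like "C:\\repo\\pkg\\mod.py";
    # as_posix() converts the separators so lookups by file_path stay consistent:
    print(Path("C:\\repo\\pkg\\mod.py").as_posix())  # on Windows -> "C:/repo/pkg/mod.py"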
@@ -30,19 +30,24 @@ def get_next_arg_and_return(
     cur = db.cursor()
     limit = num_to_get

+    normalized_file_path = Path(file_path).as_posix()
+
     if class_name is not None:
         cursor = cur.execute(
             "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? LIMIT ?",
-            (benchmark_function_name, function_name, file_path, class_name, limit),
+            (benchmark_function_name, function_name, normalized_file_path, class_name, limit),
         )
     else:
         cursor = cur.execute(
             "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?",
-            (benchmark_function_name, function_name, file_path, limit),
+            (benchmark_function_name, function_name, normalized_file_path, limit),
         )

-    while (val := cursor.fetchone()) is not None:
-        yield val[9], val[10]  # pickled_args, pickled_kwargs
+    try:
+        while (val := cursor.fetchone()) is not None:
+            yield val[9], val[10]  # pickled_args, pickled_kwargs
+    finally:
+        db.close()


 def get_function_alias(module: str, function_name: str) -> str:

@@ -166,7 +171,7 @@ trace_file_path = r"{trace_file}"
     module_name = func.get("module_name")
     function_name = func.get("function_name")
     class_name = func.get("class_name")
-    file_path = func.get("file_path")
+    file_path = Path(func.get("file_path")).as_posix()
     benchmark_function_name = func.get("benchmark_function_name")
     function_properties = func.get("function_properties")
     if not class_name:
@@ -1,3 +1,4 @@
+import importlib.util
 import logging
 import sys
 from argparse import SUPPRESS, ArgumentParser, Namespace

@@ -96,6 +97,12 @@ def parse_args() -> Namespace:
     )
     parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
     parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization")
+    parser.add_argument(
+        "--async",
+        default=False,
+        action="store_true",
+        help="Enable optimization of async functions. By default, async functions are excluded from optimization.",
+    )

     args, unknown_args = parser.parse_known_args()
     sys.argv[:] = [sys.argv[0], *unknown_args]

@@ -139,6 +146,14 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
     if env_utils.is_ci():
         args.no_pr = True

+    if getattr(args, "async", False) and importlib.util.find_spec("pytest_asyncio") is None:
+        logger.warning(
+            "Warning: The --async flag requires pytest-asyncio to be installed.\n"
+            "Please install it using:\n"
+            '  pip install "codeflash[asyncio]"'
+        )
+        raise SystemExit(1)
+
     return args
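One subtlety worth noting: because async is a Python keyword, the option value cannot be read as args.async, which is why the validation above goes through getattr. A minimal sketch:

    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--async", default=False, action="store_true")
    args = parser.parse_args(["--async"])
    # `args.async` would be a SyntaxError; the keyword-safe spelling is:
    print(getattr(args, "async"))  # -> True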
@@ -2,7 +2,6 @@ from __future__ import annotations

 import datetime
 import json
-import sys
 import time
 import uuid
 from pathlib import Path

@@ -11,13 +10,16 @@ from typing import TYPE_CHECKING, Any, Optional
 from rich.prompt import Confirm

 from codeflash.cli_cmds.console import console
+from codeflash.code_utils.compat import codeflash_temp_dir

 if TYPE_CHECKING:
     import argparse


 class CodeflashRunCheckpoint:
-    def __init__(self, module_root: Path, checkpoint_dir: Path = Path("/tmp")) -> None:  # noqa: S108
+    def __init__(self, module_root: Path, checkpoint_dir: Path | None = None) -> None:
+        if checkpoint_dir is None:
+            checkpoint_dir = codeflash_temp_dir
         self.module_root = module_root
         self.checkpoint_dir = Path(checkpoint_dir)
         # Create a unique checkpoint file name

@@ -37,7 +39,7 @@ class CodeflashRunCheckpoint:
             "last_updated": time.time(),
         }

-        with self.checkpoint_path.open("w") as f:
+        with self.checkpoint_path.open("w", encoding="utf-8") as f:
             f.write(json.dumps(metadata) + "\n")

     def add_function_to_checkpoint(

@@ -66,7 +68,7 @@ class CodeflashRunCheckpoint:
             **additional_info,
         }

-        with self.checkpoint_path.open("a") as f:
+        with self.checkpoint_path.open("a", encoding="utf-8") as f:
             f.write(json.dumps(function_data) + "\n")

         # Update the metadata last_updated timestamp

@@ -75,7 +77,7 @@ class CodeflashRunCheckpoint:
     def _update_metadata_timestamp(self) -> None:
         """Update the last_updated timestamp in the metadata."""
         # Read the first line (metadata)
-        with self.checkpoint_path.open() as f:
+        with self.checkpoint_path.open(encoding="utf-8") as f:
             metadata = json.loads(f.readline())
             rest_content = f.read()

@@ -84,7 +86,7 @@ class CodeflashRunCheckpoint:

         # Write all lines to a temporary file

-        with self.checkpoint_path.open("w") as f:
+        with self.checkpoint_path.open("w", encoding="utf-8") as f:
             f.write(json.dumps(metadata) + "\n")
             f.write(rest_content)

@@ -94,7 +96,7 @@ class CodeflashRunCheckpoint:
         self.checkpoint_path.unlink(missing_ok=True)

         for file in self.checkpoint_dir.glob("codeflash_checkpoint_*.jsonl"):
-            with file.open() as f:
+            with file.open(encoding="utf-8") as f:
                 # Skip the first line (metadata)
                 first_line = next(f)
                 metadata = json.loads(first_line)

@@ -116,7 +118,7 @@ def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dic
     to_delete = []

     for file in checkpoint_dir.glob("codeflash_checkpoint_*.jsonl"):
-        with file.open() as f:
+        with file.open(encoding="utf-8") as f:
             # Skip the first line (metadata)
             first_line = next(f)
             metadata = json.loads(first_line)

@@ -139,8 +141,8 @@ def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dic

 def ask_should_use_checkpoint_get_functions(args: argparse.Namespace) -> Optional[dict[str, dict[str, str]]]:
     previous_checkpoint_functions = None
-    if args.all and (sys.platform == "linux" or sys.platform == "darwin") and Path("/tmp").is_dir():  # noqa: S108 #TODO: use the temp dir from codeutils-compat.py
-        previous_checkpoint_functions = get_all_historical_functions(args.module_root, Path("/tmp"))  # noqa: S108
+    if args.all and codeflash_temp_dir.is_dir():
+        previous_checkpoint_functions = get_all_historical_functions(args.module_root, codeflash_temp_dir)
     if previous_checkpoint_functions and Confirm.ask(
         "Previous Checkpoint detected from an incomplete optimization run, shall I continue the optimization from that point?",
         default=True,
@@ -272,6 +272,8 @@ class DottedImportCollector(cst.CSTVisitor):
             if child.module is None:
                 continue
             module = self.get_full_dotted_name(child.module)
+            if isinstance(child.names, cst.ImportStar):
+                continue
             for alias in child.names:
                 if isinstance(alias, cst.ImportAlias):
                     name = alias.name.value

@@ -414,6 +416,73 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
     return transformed_module.code


+def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
+    try:
+        module_path = module_name.replace(".", "/")
+        possible_paths = [project_root / f"{module_path}.py", project_root / f"{module_path}/__init__.py"]
+
+        module_file = None
+        for path in possible_paths:
+            if path.exists():
+                module_file = path
+                break
+
+        if module_file is None:
+            logger.warning(f"Could not find module file for {module_name}, skipping star import resolution")
+            return set()
+
+        with module_file.open(encoding="utf8") as f:
+            module_code = f.read()
+
+        tree = ast.parse(module_code)
+
+        all_names = None
+        for node in ast.walk(tree):
+            if (
+                isinstance(node, ast.Assign)
+                and len(node.targets) == 1
+                and isinstance(node.targets[0], ast.Name)
+                and node.targets[0].id == "__all__"
+            ):
+                if isinstance(node.value, (ast.List, ast.Tuple)):
+                    all_names = []
+                    for elt in node.value.elts:
+                        if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
+                            all_names.append(elt.value)
+                        elif isinstance(elt, ast.Str):  # Python < 3.8 compatibility
+                            all_names.append(elt.s)
+                break
+
+        if all_names is not None:
+            return set(all_names)
+
+        public_names = set()
+        for node in tree.body:
+            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
+                if not node.name.startswith("_"):
+                    public_names.add(node.name)
+            elif isinstance(node, ast.Assign):
+                for target in node.targets:
+                    if isinstance(target, ast.Name) and not target.id.startswith("_"):
+                        public_names.add(target.id)
+            elif isinstance(node, ast.AnnAssign):
+                if isinstance(node.target, ast.Name) and not node.target.id.startswith("_"):
+                    public_names.add(node.target.id)
+            elif isinstance(node, ast.Import) or (
+                isinstance(node, ast.ImportFrom) and not any(alias.name == "*" for alias in node.names)
+            ):
+                for alias in node.names:
+                    name = alias.asname or alias.name
+                    if not name.startswith("_"):
+                        public_names.add(name)
+
+        return public_names  # noqa: TRY300
+
+    except Exception as e:
+        logger.warning(f"Error resolving star import for {module_name}: {e}")
+        return set()
+
+
 def add_needed_imports_from_module(
     src_module_code: str,
     dst_module_code: str,

@@ -468,9 +537,23 @@ def add_needed_imports_from_module(
                 f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod  # avoid circular deps
             ):
                 continue  # Skip adding imports for helper functions already in the context
-            if f"{mod}.{obj}" not in dotted_import_collector.imports:
-                AddImportsVisitor.add_needed_import(dst_context, mod, obj)
-                RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
+
+            # Handle star imports by resolving them to actual symbol names
+            if obj == "*":
+                resolved_symbols = resolve_star_import(mod, project_root)
+                logger.debug(f"Resolved star import from {mod}: {resolved_symbols}")
+
+                for symbol in resolved_symbols:
+                    if (
+                        f"{mod}.{symbol}" not in helper_functions_fqn
+                        and f"{mod}.{symbol}" not in dotted_import_collector.imports
+                    ):
+                        AddImportsVisitor.add_needed_import(dst_context, mod, symbol)
+                        RemoveImportsVisitor.remove_unused_import(dst_context, mod, symbol)
+            else:
+                if f"{mod}.{obj}" not in dotted_import_collector.imports:
+                    AddImportsVisitor.add_needed_import(dst_context, mod, obj)
+                    RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
     except Exception as e:
         logger.exception(f"Error adding imports to destination module code: {e}")
     return dst_module_code
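To make the resolution rules concrete, a sketch under assumed module contents (hypothetical paths and names):

    # /repo/pkg/helpers.py:
    #     __all__ = ["alpha"]
    #     def alpha(): ...
    #     def _private(): ...
    #
    # resolve_star_import("pkg.helpers", Path("/repo")) -> {"alpha"}
    #
    # If __all__ were absent, the fallback scan of top-level definitions would
    # return every public name instead: still {"alpha"} here, since _private
    # starts with an underscore and is skipped.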
@@ -173,14 +173,14 @@ def get_qualified_name(module_name: str, full_qualified_name: str) -> str:

 def module_name_from_file_path(file_path: Path, project_root_path: Path, *, traverse_up: bool = False) -> str:
     try:
-        relative_path = file_path.relative_to(project_root_path)
+        relative_path = file_path.resolve().relative_to(project_root_path.resolve())
         return relative_path.with_suffix("").as_posix().replace("/", ".")
     except ValueError:
         if traverse_up:
             parent = file_path.parent
             while parent not in (project_root_path, parent.parent):
                 try:
-                    relative_path = file_path.relative_to(parent)
+                    relative_path = file_path.resolve().relative_to(parent.resolve())
                     return relative_path.with_suffix("").as_posix().replace("/", ".")
                 except ValueError:
                     parent = parent.parent

@@ -245,8 +245,9 @@ def get_run_tmp_file(file_path: Path) -> Path:


 def path_belongs_to_site_packages(file_path: Path) -> bool:
-    site_packages = [Path(p) for p in site.getsitepackages()]
-    return any(file_path.resolve().is_relative_to(site_package_path) for site_package_path in site_packages)
+    file_path_resolved = file_path.resolve()
+    site_packages = [Path(p).resolve() for p in site.getsitepackages()]
+    return any(file_path_resolved.is_relative_to(site_package_path) for site_package_path in site_packages)


 def is_class_defined_in_file(class_name: str, file_path: Path) -> bool:

@@ -268,14 +269,6 @@ def validate_python_code(code: str) -> str:
     return code


-def has_any_async_functions(code: str) -> bool:
-    try:
-        module = ast.parse(code)
-    except SyntaxError:
-        return False
-    return any(isinstance(node, ast.AsyncFunctionDef) for node in ast.walk(module))
-
-
 def cleanup_paths(paths: list[Path]) -> None:
     for path in paths:
         if path and path.exists():
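The effect of resolving both sides before relative_to, sketched with assumed paths (a symlinked checkout is the typical failure case):

    from pathlib import Path

    # Suppose /work is a symlink to /mnt/repo. Without resolve(),
    # Path("/work/src/mod.py").relative_to(Path("/mnt/repo")) raises ValueError
    # even though both point at the same tree. Resolving both sides first
    # normalizes them so the module name computes cleanly:
    p = Path("/work/src/mod.py").resolve()   # -> /mnt/repo/src/mod.py (assumed)
    root = Path("/mnt/repo").resolve()
    print(p.relative_to(root).with_suffix("").as_posix().replace("/", "."))  # -> "src.mod"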
codeflash/code_utils/codeflash_wrap_decorator.py (new file, +167)
@@ -0,0 +1,167 @@
+from __future__ import annotations
+
+import asyncio
+import gc
+import os
+import sqlite3
+from enum import Enum
+from functools import wraps
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import Any, Callable, TypeVar
+
+import dill as pickle
+
+
+class VerificationType(str, Enum):  # moved from codeflash/verification/codeflash_capture.py
+    FUNCTION_CALL = (
+        "function_call"  # Correctness verification for a test function, checks input values and output values)
+    )
+    INIT_STATE_FTO = "init_state_fto"  # Correctness verification for fto class instance attributes after init
+    INIT_STATE_HELPER = "init_state_helper"  # Correctness verification for helper class instance attributes after init
+
+
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+def get_run_tmp_file(file_path: Path) -> Path:  # moved from codeflash/code_utils/code_utils.py
+    if not hasattr(get_run_tmp_file, "tmpdir"):
+        get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
+    return Path(get_run_tmp_file.tmpdir.name) / file_path
+
+
+def extract_test_context_from_env() -> tuple[str, str | None, str]:
+    test_module = os.environ["CODEFLASH_TEST_MODULE"]
+    test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
+    test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
+
+    if test_module and test_function:
+        return (test_module, test_class if test_class else None, test_function)
+
+    raise RuntimeError(
+        "Test context environment variables not set - ensure tests are run through codeflash test runner"
+    )
+
+
+def codeflash_behavior_async(func: F) -> F:
+    @wraps(func)
+    async def async_wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
+        loop = asyncio.get_running_loop()
+        function_name = func.__name__
+        line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
+        loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
+        test_module_name, test_class_name, test_name = extract_test_context_from_env()
+
+        test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
+
+        if not hasattr(async_wrapper, "index"):
+            async_wrapper.index = {}
+        if test_id in async_wrapper.index:
+            async_wrapper.index[test_id] += 1
+        else:
+            async_wrapper.index[test_id] = 0
+
+        codeflash_test_index = async_wrapper.index[test_id]
+        invocation_id = f"{line_id}_{codeflash_test_index}"
+        test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
+
+        print(f"!$######{test_stdout_tag}######$!")
+
+        iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
+        db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
+        codeflash_con = sqlite3.connect(db_path)
+        codeflash_cur = codeflash_con.cursor()
+
+        codeflash_cur.execute(
+            "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, "
+            "test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, "
+            "runtime INTEGER, return_value BLOB, verification_type TEXT)"
+        )
+
+        exception = None
+        counter = loop.time()
+        gc.disable()
+        try:
+            ret = func(*args, **kwargs)  # coroutine creation has some overhead, though it is very small
+            counter = loop.time()
+            return_value = await ret  # let's measure the actual execution time of the code
+            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
+        except Exception as e:
+            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
+            exception = e
+        finally:
+            gc.enable()
+
+        print(f"!######{test_stdout_tag}######!")
+
+        pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps((args, kwargs, return_value))
+        codeflash_cur.execute(
+            "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
+            (
+                test_module_name,
+                test_class_name,
+                test_name,
+                function_name,
+                loop_index,
+                invocation_id,
+                codeflash_duration,
+                pickled_return_value,
+                VerificationType.FUNCTION_CALL.value,
+            ),
+        )
+        codeflash_con.commit()
+        codeflash_con.close()
+
+        if exception:
+            raise exception
+        return return_value
+
+    return async_wrapper
+
+
+def codeflash_performance_async(func: F) -> F:
+    @wraps(func)
+    async def async_wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
+        loop = asyncio.get_running_loop()
+        function_name = func.__name__
+        line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
+        loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
+
+        test_module_name, test_class_name, test_name = extract_test_context_from_env()
+
+        test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
+
+        if not hasattr(async_wrapper, "index"):
+            async_wrapper.index = {}
+        if test_id in async_wrapper.index:
+            async_wrapper.index[test_id] += 1
+        else:
+            async_wrapper.index[test_id] = 0
+
+        codeflash_test_index = async_wrapper.index[test_id]
+        invocation_id = f"{line_id}_{codeflash_test_index}"
+        test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
+
+        print(f"!$######{test_stdout_tag}######$!")
+
+        exception = None
+        counter = loop.time()
+        gc.disable()
+        try:
+            ret = func(*args, **kwargs)
+            counter = loop.time()
+            return_value = await ret
+            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
+        except Exception as e:
+            codeflash_duration = int((loop.time() - counter) * 1_000_000_000)
+            exception = e
+        finally:
+            gc.enable()
+
+        print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
+
+        if exception:
+            raise exception
+        return return_value
+
+    return async_wrapper
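A sketch of the decorator in use; the environment variables are the ones the wrapper reads, and the values here are invented for illustration:

    import asyncio
    import os

    os.environ.update({
        "CODEFLASH_TEST_MODULE": "tests.test_async",      # hypothetical test module
        "CODEFLASH_TEST_FUNCTION": "test_work",           # hypothetical test name
        "CODEFLASH_CURRENT_LINE_ID": "0",
        "CODEFLASH_LOOP_INDEX": "1",
    })

    @codeflash_behavior_async
    async def work(x: int) -> int:
        return x * 2

    # Timing uses loop.time() around only the awaited coroutine, with GC disabled,
    # and (args, kwargs, result) is pickled into the sqlite results table.
    print(asyncio.run(work(21)))  # -> 42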
@@ -3,6 +3,7 @@ INDIVIDUAL_TESTCASE_TIMEOUT = 15
 MAX_FUNCTION_TEST_SECONDS = 60
 N_CANDIDATES = 5
 MIN_IMPROVEMENT_THRESHOLD = 0.05
+MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10  # 10% minimum improvement for async throughput
 MAX_TEST_FUNCTION_RUNS = 50
 MAX_CUMULATIVE_TEST_RUNTIME_NANOSECONDS = 100e6  # 100ms
 N_TESTS_TO_GENERATE = 2
@@ -15,7 +15,9 @@ def extract_dependent_function(main_function: str, code_context: CodeOptimizatio
     dependent_functions = set()
     for code_string in code_context.testgen_context.code_strings:
         ast_tree = ast.parse(code_string.code)
-        dependent_functions.update({node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)})
+        dependent_functions.update(
+            {node.name for node in ast_tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))}
+        )

     if main_function in dependent_functions:
         dependent_functions.discard(main_function)

@@ -43,17 +45,25 @@ def build_fully_qualified_name(function_name: str, code_context: CodeOptimizatio
 def generate_candidates(source_code_path: Path) -> set[str]:
     """Generate all the possible candidates for coverage data based on the source code path."""
     candidates = set()
-    candidates.add(source_code_path.name)
-    current_path = source_code_path.parent
+    # Add the filename as a candidate
+    name = source_code_path.name
+    candidates.add(name)

-    last_added = source_code_path.name
-    while current_path != current_path.parent:
-        candidate_path = str(Path(current_path.name) / last_added)
+    # Precompute parts for efficient candidate path construction
+    parts = source_code_path.parts
+    n = len(parts)
+
+    # Walk up the directory structure without creating Path objects or repeatedly converting to posix
+    last_added = name
+    # Start from the last parent and move up to the root, exclusive (skip the root itself)
+    for i in range(n - 2, 0, -1):
+        # Combine the ith part with the accumulated path (last_added)
+        candidate_path = f"{parts[i]}/{last_added}"
         candidates.add(candidate_path)
         last_added = candidate_path
-        current_path = current_path.parent

-    candidates.add(str(source_code_path))
+    # Add the absolute posix path as a candidate
+    candidates.add(source_code_path.as_posix())
     return candidates
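For a concrete picture of the rewritten walk, the candidate set for an assumed absolute path:

    from pathlib import Path

    # generate_candidates(Path("/repo/src/pkg/mod.py")) yields:
    #   "mod.py", "pkg/mod.py", "src/pkg/mod.py", "repo/src/pkg/mod.py",
    #   "/repo/src/pkg/mod.py"
    # parts = ("/", "repo", "src", "pkg", "mod.py"); the loop runs i = 3, 2, 1,
    # prepending one directory per step with plain string formatting instead of
    # building intermediate Path objects.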
@@ -32,9 +32,11 @@ class CommentMapper(ast.NodeVisitor):

     def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
         self.context_stack.append(node.name)
-        for inner_node in ast.walk(node):
+        for inner_node in node.body:
             if isinstance(inner_node, ast.FunctionDef):
                 self.visit_FunctionDef(inner_node)
+            elif isinstance(inner_node, ast.AsyncFunctionDef):
+                self.visit_AsyncFunctionDef(inner_node)
         self.context_stack.pop()
         return node

@@ -50,6 +52,14 @@ class CommentMapper(ast.NodeVisitor):
         return f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"

     def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
+        self._process_function_def_common(node)
+        return node
+
+    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
+        self._process_function_def_common(node)
+        return node
+
+    def _process_function_def_common(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
         self.context_stack.append(node.name)
         i = len(node.body) - 1
         test_qualified_name = ".".join(self.context_stack)

@@ -60,8 +70,9 @@ class CommentMapper(ast.NodeVisitor):
             j = len(line_node.body) - 1
             while j >= 0:
                 compound_line_node: ast.stmt = line_node.body[j]
-                internal_node: ast.AST
-                for internal_node in ast.walk(compound_line_node):
+                nodes_to_check = [compound_line_node]
+                nodes_to_check.extend(getattr(compound_line_node, "body", []))
+                for internal_node in nodes_to_check:
                     if isinstance(internal_node, (ast.stmt, ast.Assign)):
                         inv_id = str(i) + "_" + str(j)
                         match_key = key + "#" + inv_id

@@ -75,7 +86,6 @@ class CommentMapper(ast.NodeVisitor):
                 self.results[line_node.lineno] = self.get_comment(match_key)
             i -= 1
         self.context_stack.pop()
-        return node


 def get_fn_call_linenos(

@@ -197,23 +207,41 @@ def add_runtime_comments_to_generated_tests(
 def remove_functions_from_generated_tests(
     generated_tests: GeneratedTestsList, test_functions_to_remove: list[str]
 ) -> GeneratedTestsList:
+    # Pre-compile patterns for all function names to remove
+    function_patterns = _compile_function_patterns(test_functions_to_remove)
     new_generated_tests = []
+
     for generated_test in generated_tests.generated_tests:
-        for test_function in test_functions_to_remove:
-            function_pattern = re.compile(
-                rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\ndef\s|$)",
-                re.DOTALL,
-            )
+        source = generated_test.generated_original_test_source

-            match = function_pattern.search(generated_test.generated_original_test_source)
-
-            if match is None or "@pytest.mark.parametrize" in match.group(0):
-                continue
-
-            generated_test.generated_original_test_source = function_pattern.sub(
-                "", generated_test.generated_original_test_source
-            )
+        # Apply all patterns without redundant searches
+        for pattern in function_patterns:
+            # Use finditer and sub only if necessary to avoid unnecessary .search()/.sub() calls
+            for match in pattern.finditer(source):
+                # Skip if "@pytest.mark.parametrize" present
+                # Only the matched function's code is targeted
+                if "@pytest.mark.parametrize" in match.group(0):
+                    continue
+                # Remove function from source
+                # Replace using start/end indices for efficiency
+                start, end = match.span()
+                source = source[:start] + source[end:]
+                # After removal, break since .finditer() is from left to right, and only one match expected per function in source
+                break
+
+        generated_test.generated_original_test_source = source
         new_generated_tests.append(generated_test)

     return GeneratedTestsList(generated_tests=new_generated_tests)
+
+
+# Pre-compile all function removal regexes upfront for efficiency.
+def _compile_function_patterns(test_functions_to_remove: list[str]) -> list[re.Pattern[str]]:
+    return [
+        re.compile(
+            rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?(async\s+)?def\s+{re.escape(func)}\(.*?\):.*?(?=\n(async\s+)?def\s|$)",
+            re.DOTALL,
+        )
+        for func in test_functions_to_remove
+    ]
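A quick check of what one precompiled pattern matches, including the new async-aware prefix (the test source here is made up):

    patterns = _compile_function_patterns(["test_sorter"])
    src = "async def test_sorter():\n    assert True\n\ndef test_other():\n    pass\n"
    match = patterns[0].search(src)
    print(repr(match.group(0)))  # the whole async test_sorter block, up to the next def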
@@ -18,10 +18,9 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
     if formatter_cmds[0] == "disabled":
         return return_code
     tmp_code = """print("hello world")"""
-    with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", suffix=".py") as f:
-        f.write(tmp_code)
-        f.flush()
-        tmp_file = Path(f.name)
+    with tempfile.TemporaryDirectory() as tmpdir:
+        tmp_file = Path(tmpdir) / "test_codeflash_formatter.py"
+        tmp_file.write_text(tmp_code, encoding="utf-8")
         try:
             format_code(formatter_cmds, tmp_file, print_status=False, exit_on_failure=exit_on_failure)
         except Exception:

@@ -29,7 +28,7 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
                 "⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again.",
                 error_on_exit=True,
             )
-            return return_code
+    return return_code


 @lru_cache(maxsize=1)

@@ -121,7 +120,7 @@ def get_cached_gh_event_data() -> dict[str, Any]:
     event_path = os.getenv("GITHUB_EVENT_PATH")
     if not event_path:
         return {}
-    with Path(event_path).open() as f:
+    with Path(event_path).open(encoding="utf-8") as f:
         return json.load(f)  # type: ignore  # noqa
@@ -1,6 +1,5 @@
 from __future__ import annotations

-import json
 import subprocess
 import tempfile
 import time

@@ -9,15 +8,12 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Optional

 import git
-from filelock import FileLock

 from codeflash.cli_cmds.console import logger
 from codeflash.code_utils.compat import codeflash_cache_dir
 from codeflash.code_utils.git_utils import check_running_in_git_repo, git_root_dir

 if TYPE_CHECKING:
-    from typing import Any
-
     from git import Repo

@@ -100,56 +96,15 @@ def get_patches_dir_for_project() -> Path:
     return Path(patches_dir / project_id)


-def get_patches_metadata() -> dict[str, Any]:
-    project_patches_dir = get_patches_dir_for_project()
-    meta_file = project_patches_dir / "metadata.json"
-    if meta_file.exists():
-        with meta_file.open("r", encoding="utf-8") as f:
-            return json.load(f)
-    return {"id": get_git_project_id() or "", "patches": []}
-
-
-def save_patches_metadata(patch_metadata: dict) -> dict:
-    project_patches_dir = get_patches_dir_for_project()
-    meta_file = project_patches_dir / "metadata.json"
-    lock_file = project_patches_dir / "metadata.json.lock"
-
-    # we are not supporting multiple concurrent optimizations within the same process, but keep that in case we decide to do so in the future.
-    with FileLock(lock_file, timeout=10):
-        metadata = get_patches_metadata()
-
-        patch_metadata["id"] = time.strftime("%Y%m%d-%H%M%S")
-        metadata["patches"].append(patch_metadata)
-
-        meta_file.write_text(json.dumps(metadata, indent=2))
-
-    return patch_metadata
-
-
-def overwrite_patch_metadata(patches: list[dict]) -> bool:
-    project_patches_dir = get_patches_dir_for_project()
-    meta_file = project_patches_dir / "metadata.json"
-    lock_file = project_patches_dir / "metadata.json.lock"
-
-    with FileLock(lock_file, timeout=10):
-        metadata = get_patches_metadata()
-        metadata["patches"] = patches
-        meta_file.write_text(json.dumps(metadata, indent=2))
-    return True
-
-
 def create_diff_patch_from_worktree(
-    worktree_dir: Path,
-    files: list[str],
-    fto_name: Optional[str] = None,
-    metadata_input: Optional[dict[str, Any]] = None,
-) -> dict[str, Any]:
+    worktree_dir: Path, files: list[str], fto_name: Optional[str] = None
+) -> Optional[Path]:
     repository = git.Repo(worktree_dir, search_parent_directories=True)
     uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)

     if not uni_diff_text:
         logger.warning("No changes found in worktree.")
-        return {}
+        return None

     if not uni_diff_text.endswith("\n"):
         uni_diff_text += "\n"

@@ -157,14 +112,8 @@ def create_diff_patch_from_worktree(
     project_patches_dir = get_patches_dir_for_project()
     project_patches_dir.mkdir(parents=True, exist_ok=True)

-    final_function_name = fto_name or metadata_input.get("fto_name", "unknown")
-    patch_path = project_patches_dir / f"{worktree_dir.name}.{final_function_name}.patch"
+    patch_path = project_patches_dir / f"{worktree_dir.name}.{fto_name}.patch"
     with patch_path.open("w", encoding="utf8") as f:
         f.write(uni_diff_text)

-    final_metadata = {"patch_path": str(patch_path)}
-    if metadata_input:
-        final_metadata.update(metadata_input)
-    final_metadata = save_patches_metadata(final_metadata)
-
-    return final_metadata
+    return patch_path
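With the metadata plumbing removed, the function's contract is simply "patch path or None"; a caller sketch (the directory and file names are assumed):

    patch_path = create_diff_patch_from_worktree(worktree_dir, ["src/app.py"], fto_name="sorter")
    if patch_path is None:
        print("nothing to save")
    else:
        print(f"patch written to {patch_path}")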
@@ -1,10 +1,12 @@
 from __future__ import annotations

 import ast
+import platform
 from pathlib import Path
 from typing import TYPE_CHECKING

 import isort
+import libcst as cst

 from codeflash.cli_cmds.console import logger
 from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path

@@ -76,6 +78,10 @@ class InjectPerfOnly(ast.NodeTransformer):
             call_node = node
             if isinstance(node.func, ast.Name):
                 function_name = node.func.id
+
+                if self.function_object.is_async:
+                    return [test_node]
+
                 node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
                 node.args = [
                     ast.Name(id=function_name, ctx=ast.Load()),

@@ -97,6 +103,9 @@ class InjectPerfOnly(ast.NodeTransformer):
             if isinstance(node.func, ast.Attribute):
                 function_to_test = node.func.attr
                 if function_to_test == self.function_object.function_name:
+                    if self.function_object.is_async:
+                        return [test_node]
+
                     function_name = ast.unparse(node.func)
                     node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
                     node.args = [

@@ -135,7 +144,10 @@ class InjectPerfOnly(ast.NodeTransformer):
     def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
         if node.name.startswith("test_"):
             did_update = False
-            if self.test_framework == "unittest":
+            if self.test_framework == "unittest" and platform.system() != "Windows":
+                # Only add timeout decorator on non-Windows platforms
+                # Windows doesn't support SIGALRM signal required by timeout_decorator
+
                 node.decorator_list.append(
                     ast.Call(
                         func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),

@@ -212,7 +224,9 @@ class InjectPerfOnly(ast.NodeTransformer):
                         args=[
                             ast.JoinedStr(
                                 values=[
-                                    ast.Constant(value=f"{get_run_tmp_file(Path('test_return_values_'))}"),
+                                    ast.Constant(
+                                        value=f"{get_run_tmp_file(Path('test_return_values_')).as_posix()}"
+                                    ),
                                     ast.FormattedValue(
                                         value=ast.Name(id="codeflash_iteration", ctx=ast.Load()),
                                         conversion=-1,
@@ -283,6 +297,168 @@ class InjectPerfOnly(ast.NodeTransformer):
         return node


+class AsyncCallInstrumenter(ast.NodeTransformer):
+    def __init__(
+        self,
+        function: FunctionToOptimize,
+        module_path: str,
+        test_framework: str,
+        call_positions: list[CodePosition],
+        mode: TestingMode = TestingMode.BEHAVIOR,
+    ) -> None:
+        self.mode = mode
+        self.function_object = function
+        self.class_name = None
+        self.only_function_name = function.function_name
+        self.module_path = module_path
+        self.test_framework = test_framework
+        self.call_positions = call_positions
+        self.did_instrument = False
+        # Track function call count per test function
+        self.async_call_counter: dict[str, int] = {}
+        if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
+            self.class_name = function.top_level_parent_name
+
+    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
+        # Add timeout decorator for unittest test classes if needed
+        if self.test_framework == "unittest":
+            timeout_decorator = ast.Call(
+                func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
+                args=[ast.Constant(value=15)],
+                keywords=[],
+            )
+            for item in node.body:
+                if (
+                    isinstance(item, ast.FunctionDef)
+                    and item.name.startswith("test_")
+                    and not any(
+                        isinstance(d, ast.Call)
+                        and isinstance(d.func, ast.Name)
+                        and d.func.id == "timeout_decorator.timeout"
+                        for d in item.decorator_list
+                    )
+                ):
+                    item.decorator_list.append(timeout_decorator)
+        return self.generic_visit(node)
+
+    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
+        if not node.name.startswith("test_"):
+            return node
+
+        return self._process_test_function(node)
+
+    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
+        # Only process test functions
+        if not node.name.startswith("test_"):
+            return node
+
+        return self._process_test_function(node)
+
+    def _process_test_function(
+        self, node: ast.AsyncFunctionDef | ast.FunctionDef
+    ) -> ast.AsyncFunctionDef | ast.FunctionDef:
+        # Optimize the search for decorator presence
+        if self.test_framework == "unittest":
+            found_timeout = False
+            for d in node.decorator_list:
+                # Avoid isinstance(d.func, ast.Name) if d is not ast.Call
+                if isinstance(d, ast.Call):
+                    f = d.func
+                    # Avoid attribute lookup if f is not ast.Name
+                    if isinstance(f, ast.Name) and f.id == "timeout_decorator.timeout":
+                        found_timeout = True
+                        break
+            if not found_timeout:
+                timeout_decorator = ast.Call(
+                    func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
+                    args=[ast.Constant(value=15)],
+                    keywords=[],
+                )
+                node.decorator_list.append(timeout_decorator)
+
+        # Initialize counter for this test function
+        if node.name not in self.async_call_counter:
+            self.async_call_counter[node.name] = 0
+
+        new_body = []
+
+        # Optimize ast.walk calls inside _instrument_statement, by scanning only relevant nodes
+        for _i, stmt in enumerate(node.body):
+            transformed_stmt, added_env_assignment = self._optimized_instrument_statement(stmt)
+
+            if added_env_assignment:
+                current_call_index = self.async_call_counter[node.name]
+                self.async_call_counter[node.name] += 1
+
+                env_assignment = ast.Assign(
+                    targets=[
+                        ast.Subscript(
+                            value=ast.Attribute(
+                                value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
+                            ),
+                            slice=ast.Constant(value="CODEFLASH_CURRENT_LINE_ID"),
+                            ctx=ast.Store(),
+                        )
+                    ],
+                    value=ast.Constant(value=f"{current_call_index}"),
+                    lineno=stmt.lineno if hasattr(stmt, "lineno") else 1,
+                )
+                new_body.append(env_assignment)
+                self.did_instrument = True
+
+            new_body.append(transformed_stmt)
+
+        node.body = new_body
+        return node
+
+    def _instrument_statement(self, stmt: ast.stmt, _node_name: str) -> tuple[ast.stmt, bool]:
+        for node in ast.walk(stmt):
+            if (
+                isinstance(node, ast.Await)
+                and isinstance(node.value, ast.Call)
+                and self._is_target_call(node.value)
+                and self._call_in_positions(node.value)
+            ):
+                # Check if this call is in one of our target positions
+                return stmt, True  # Return original statement but signal we added env var
+
+        return stmt, False
+
+    def _is_target_call(self, call_node: ast.Call) -> bool:
+        """Check if this call node is calling our target async function."""
+        if isinstance(call_node.func, ast.Name):
+            return call_node.func.id == self.function_object.function_name
+        if isinstance(call_node.func, ast.Attribute):
+            return call_node.func.attr == self.function_object.function_name
+        return False
+
+    def _call_in_positions(self, call_node: ast.Call) -> bool:
+        if not hasattr(call_node, "lineno") or not hasattr(call_node, "col_offset"):
+            return False
+
+        return node_in_call_position(call_node, self.call_positions)
+
+    # Optimized version: only walk child nodes for Await
+    def _optimized_instrument_statement(self, stmt: ast.stmt) -> tuple[ast.stmt, bool]:
+        # Stack-based DFS, manual for relevant Await nodes
+        stack = [stmt]
+        while stack:
+            node = stack.pop()
+            # Favor direct ast.Await detection
+            if isinstance(node, ast.Await):
+                val = node.value
+                if isinstance(val, ast.Call) and self._is_target_call(val) and self._call_in_positions(val):
+                    return stmt, True
+            # Use _fields instead of ast.walk for less allocations
+            for fname in getattr(node, "_fields", ()):
+                child = getattr(node, fname, None)
+                if isinstance(child, list):
+                    stack.extend(child)
+                elif isinstance(child, ast.AST):
+                    stack.append(child)
+        return stmt, False
+
+
 class FunctionImportedAsVisitor(ast.NodeVisitor):
     """Checks if a function has been imported as an alias. We only care about the alias then.
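The observable effect of the instrumenter on a test body is one environment assignment per matched await (a simplified before/after sketch; the test function is hypothetical):

    # before instrumentation
    async def test_async_sorter():
        result = await async_sorter([3, 1, 2])

    # after instrumentation ("0" is this test's per-call index)
    async def test_async_sorter():
        os.environ["CODEFLASH_CURRENT_LINE_ID"] = "0"
        result = await async_sorter([3, 1, 2])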
@@ -310,6 +486,7 @@ class FunctionImportedAsVisitor(ast.NodeVisitor):
                 file_path=self.function.file_path,
                 starting_line=self.function.starting_line,
                 ending_line=self.function.ending_line,
+                is_async=self.function.is_async,
             )
         else:
             self.imported_as = FunctionToOptimize(
@@ -318,9 +495,69 @@ class FunctionImportedAsVisitor(ast.NodeVisitor):
                 file_path=self.function.file_path,
                 starting_line=self.function.starting_line,
                 ending_line=self.function.ending_line,
+                is_async=self.function.is_async,
             )


+def instrument_source_module_with_async_decorators(
+    source_path: Path, function_to_optimize: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
+) -> tuple[bool, str | None]:
+    if not function_to_optimize.is_async:
+        return False, None
+
+    try:
+        with source_path.open(encoding="utf8") as f:
+            source_code = f.read()
+
+        modified_code, decorator_added = add_async_decorator_to_function(source_code, function_to_optimize, mode)
+
+        if decorator_added:
+            return True, modified_code
+
+    except Exception:
+        return False, None
+    else:
+        return False, None
+
+
+def inject_async_profiling_into_existing_test(
+    test_path: Path,
+    call_positions: list[CodePosition],
+    function_to_optimize: FunctionToOptimize,
+    tests_project_root: Path,
+    test_framework: str,
+    mode: TestingMode = TestingMode.BEHAVIOR,
+) -> tuple[bool, str | None]:
+    """Inject profiling for async function calls by setting environment variables before each call."""
+    with test_path.open(encoding="utf8") as f:
+        test_code = f.read()
+
+    try:
+        tree = ast.parse(test_code)
+    except SyntaxError:
+        logger.exception(f"Syntax error in code in file - {test_path}")
+        return False, None
+    # TODO: Pass the full name of function here, otherwise we can run into namespace clashes
+    test_module_path = module_name_from_file_path(test_path, tests_project_root)
+    import_visitor = FunctionImportedAsVisitor(function_to_optimize)
+    import_visitor.visit(tree)
+    func = import_visitor.imported_as
+
+    async_instrumenter = AsyncCallInstrumenter(func, test_module_path, test_framework, call_positions, mode=mode)
+    tree = async_instrumenter.visit(tree)
+
+    if not async_instrumenter.did_instrument:
+        return False, None
+
+    # Add necessary imports
+    new_imports = [ast.Import(names=[ast.alias(name="os")])]
+    if test_framework == "unittest":
+        new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
+
+    tree.body = [*new_imports, *tree.body]
+    return True, isort.code(ast.unparse(tree), float_to_top=True)
+
+
 def inject_profiling_into_existing_test(
     test_path: Path,
     call_positions: list[CodePosition],
@@ -329,6 +566,11 @@ def inject_profiling_into_existing_test(
     test_framework: str,
     mode: TestingMode = TestingMode.BEHAVIOR,
 ) -> tuple[bool, str | None]:
+    if function_to_optimize.is_async:
+        return inject_async_profiling_into_existing_test(
+            test_path, call_positions, function_to_optimize, tests_project_root, test_framework, mode
+        )
+
     with test_path.open(encoding="utf8") as f:
         test_code = f.read()
     try:
@ -336,7 +578,7 @@ def inject_profiling_into_existing_test(
|
|||
except SyntaxError:
|
||||
logger.exception(f"Syntax error in code in file - {test_path}")
|
||||
return False, None
|
||||
# TODO: Pass the full name of function here, otherwise we can run into namespace clashes
|
||||
|
||||
test_module_path = module_name_from_file_path(test_path, tests_project_root)
|
||||
import_visitor = FunctionImportedAsVisitor(function_to_optimize)
|
||||
import_visitor.visit(tree)
|
||||
|
|
@ -352,9 +594,11 @@ def inject_profiling_into_existing_test(
|
|||
new_imports.extend(
|
||||
[ast.Import(names=[ast.alias(name="sqlite3")]), ast.Import(names=[ast.alias(name="dill", asname="pickle")])]
|
||||
)
|
||||
if test_framework == "unittest":
|
||||
if test_framework == "unittest" and platform.system() != "Windows":
|
||||
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
|
||||
tree.body = [*new_imports, create_wrapper_function(mode), *tree.body]
|
||||
additional_functions = [create_wrapper_function(mode)]
|
||||
|
||||
tree.body = [*new_imports, *additional_functions, *tree.body]
|
||||
return True, isort.code(ast.unparse(tree), float_to_top=True)
|
||||
|
||||
|
||||
|
|
@@ -735,3 +979,162 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
        decorator_list=[],
        returns=None,
    )


class AsyncDecoratorAdder(cst.CSTTransformer):
    """Transformer that adds async decorator to async function definitions."""

    def __init__(self, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR) -> None:
        """Initialize the transformer.

        Args:
        ----
            function: The FunctionToOptimize object representing the target async function.
            mode: The testing mode to determine which decorator to apply.

        """
        super().__init__()
        self.function = function
        self.mode = mode
        self.qualified_name_parts = function.qualified_name.split(".")
        self.context_stack = []
        self.added_decorator = False

        # Choose decorator based on mode
        self.decorator_name = (
            "codeflash_behavior_async" if mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
        )

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        # Track when we enter a class
        self.context_stack.append(node.name.value)

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:  # noqa: ARG002
        # Pop the context when we leave a class
        self.context_stack.pop()
        return updated_node

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        # Track when we enter a function
        self.context_stack.append(node.name.value)

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
        # Check if this is an async function and matches our target
        if original_node.asynchronous is not None and self.context_stack == self.qualified_name_parts:
            # Check if the decorator is already present
            has_decorator = any(
                self._is_target_decorator(decorator.decorator) for decorator in original_node.decorators
            )

            # Only add the decorator if it's not already there
            if not has_decorator:
                new_decorator = cst.Decorator(decorator=cst.Name(value=self.decorator_name))

                # Add our new decorator to the existing decorators
                updated_decorators = [new_decorator, *list(updated_node.decorators)]
                updated_node = updated_node.with_changes(decorators=tuple(updated_decorators))
                self.added_decorator = True

        # Pop the context when we leave a function
        self.context_stack.pop()
        return updated_node

    def _is_target_decorator(self, decorator_node: cst.Name | cst.Attribute | cst.Call) -> bool:
        """Check if a decorator matches our target decorator name."""
        if isinstance(decorator_node, cst.Name):
            return decorator_node.value in {
                "codeflash_trace_async",
                "codeflash_behavior_async",
                "codeflash_performance_async",
            }
        if isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name):
            return decorator_node.func.value in {
                "codeflash_trace_async",
                "codeflash_behavior_async",
                "codeflash_performance_async",
            }
        return False
class AsyncDecoratorImportAdder(cst.CSTTransformer):
    """Transformer that adds the import for async decorators."""

    def __init__(self, mode: TestingMode = TestingMode.BEHAVIOR) -> None:
        self.mode = mode
        self.has_import = False

    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
        # Check if the async decorator import is already present
        if (
            isinstance(node.module, cst.Attribute)
            and isinstance(node.module.value, cst.Attribute)
            and isinstance(node.module.value.value, cst.Name)
            and node.module.value.value.value == "codeflash"
            and node.module.value.attr.value == "code_utils"
            and node.module.attr.value == "codeflash_wrap_decorator"
            and not isinstance(node.names, cst.ImportStar)
        ):
            decorator_name = (
                "codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
            )
            for import_alias in node.names:
                if import_alias.name.value == decorator_name:
                    self.has_import = True

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:  # noqa: ARG002
        # If the import is already there, don't add it again
        if self.has_import:
            return updated_node

        # Choose import based on mode
        decorator_name = (
            "codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
        )

        # Parse the import statement into a CST node
        import_node = cst.parse_statement(f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}")

        # Add the import to the module's body
        return updated_node.with_changes(body=[import_node, *list(updated_node.body)])
def add_async_decorator_to_function(
    source_code: str, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
) -> tuple[str, bool]:
    """Add async decorator to an async function definition.

    Args:
    ----
        source_code: The source code to modify.
        function: The FunctionToOptimize object representing the target async function.
        mode: The testing mode to determine which decorator to apply.

    Returns:
    -------
        Tuple of (modified_source_code, was_decorator_added).

    """
    if not function.is_async:
        return source_code, False

    try:
        module = cst.parse_module(source_code)

        # Add the decorator to the function
        decorator_transformer = AsyncDecoratorAdder(function, mode)
        module = module.visit(decorator_transformer)

        # Add the import if decorator was added
        if decorator_transformer.added_decorator:
            import_transformer = AsyncDecoratorImportAdder(mode)
            module = module.visit(import_transformer)

        return isort.code(module.code, float_to_top=True), decorator_transformer.added_decorator
    except Exception as e:
        logger.exception(f"Error adding async decorator to function {function.qualified_name}: {e}")
        return source_code, False


def create_instrumented_source_module_path(source_path: Path, temp_dir: Path) -> Path:
    instrumented_filename = f"instrumented_{source_path.name}"
    return temp_dir / instrumented_filename
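Reviewer note: to make the transform concrete, here is a sketch of the expected round trip under BEHAVIOR mode; `fto` is a hypothetical FunctionToOptimize describing the async function, and the "after" text is inferred from the transformers above, not captured from a real run.

    before = (
        "async def fetch(url):\n"
        "    return await client.get(url)\n"
    )
    after, added = add_async_decorator_to_function(before, fto, TestingMode.BEHAVIOR)
    assert added
    # `after` should now look roughly like:
    #   from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async
    #
    #   @codeflash_behavior_async
    #   async def fetch(url):
    #       return await client.get(url)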
@@ -219,6 +219,6 @@ def add_decorator_imports(function_to_optimize: FunctionToOptimize, code_context
            file.write(modified_code)
    # Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
    file_contents = function_to_optimize.file_path.read_text("utf-8")
    modified_code = add_profile_enable(file_contents, str(line_profile_output_file))
    modified_code = add_profile_enable(file_contents, line_profile_output_file.as_posix())
    function_to_optimize.file_path.write_text(modified_code, "utf-8")
    return line_profile_output_file
@@ -128,13 +128,19 @@ def get_first_top_level_object_def_ast(

def get_first_top_level_function_or_method_ast(
    function_name: str, parents: list[FunctionParent], node: ast.AST
) -> ast.FunctionDef | None:
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
    if not parents:
        return get_first_top_level_object_def_ast(function_name, ast.FunctionDef, node)
        result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, node)
        if result is not None:
            return result
        return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, node)
    if parents[0].type == "ClassDef" and (
        class_node := get_first_top_level_object_def_ast(parents[0].name, ast.ClassDef, node)
    ):
        return get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node)
        result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node)
        if result is not None:
            return result
        return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, class_node)
    return None
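Reviewer note: the practical effect of the new fallback is that lookups which used to return None for `async def` targets now resolve. A small sketch, with illustrative names:

    import ast

    module = ast.parse("class Service:\n    async def handle(self):\n        return 1\n")
    # Before this change only ast.FunctionDef was tried, so this returned None;
    # the AsyncFunctionDef fallback now finds the method.
    node = get_first_top_level_function_or_method_ast(
        "handle", [FunctionParent("Service", "ClassDef")], module
    )
    assert isinstance(node, ast.AsyncFunctionDef)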
@@ -334,7 +334,9 @@ def extract_code_markdown_context_from_files(
                helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())
            ),
        )
        code_string_context = CodeString(code=code_context, file_path=file_path.relative_to(project_root_path))
        code_string_context = CodeString(
            code=code_context, file_path=file_path.resolve().relative_to(project_root_path.resolve())
        )
        code_context_markdown.code_strings.append(code_string_context)
    # Extract code from file paths containing helpers of helpers
    for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items():
@@ -365,7 +367,9 @@ def extract_code_markdown_context_from_files(
            project_root=project_root_path,
            helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())),
        )
        code_string_context = CodeString(code=code_context, file_path=file_path.relative_to(project_root_path))
        code_string_context = CodeString(
            code=code_context, file_path=file_path.resolve().relative_to(project_root_path.resolve())
        )
        code_context_markdown.code_strings.append(code_string_context)
    return code_context_markdown
@@ -279,7 +279,7 @@ class QualifiedFunctionUsageMarker:
        # Find class methods and add their containing classes and dunder methods
        for qualified_name in list(self.qualified_function_names):
            if "." in qualified_name:
                class_name, method_name = qualified_name.split(".", 1)
                class_name, _method_name = qualified_name.split(".", 1)

                # Add the class itself
                expanded.add(class_name)
@@ -511,7 +511,7 @@ def revert_unused_helper_functions(
    if not unused_helpers:
        return

    logger.info(f"Reverting {len(unused_helpers)} unused helper function(s) to original definitions")
    logger.debug(f"Reverting {len(unused_helpers)} unused helper function(s) to original definitions")

    # Group unused helpers by file path
    unused_helpers_by_file = defaultdict(list)
@@ -612,6 +612,34 @@ def _analyze_imports_in_optimized_code(
    return dict(imported_names_map)


def find_target_node(
    root: ast.AST, function_to_optimize: FunctionToOptimize
) -> Optional[ast.FunctionDef | ast.AsyncFunctionDef]:
    parents = function_to_optimize.parents
    node = root
    for parent in parents:
        # Fast loop: directly look for the matching ClassDef in node.body
        body = getattr(node, "body", None)
        if not body:
            return None
        for child in body:
            if isinstance(child, ast.ClassDef) and child.name == parent.name:
                node = child
                break
        else:
            return None

    # Now node is either the root or the target parent class; look for function
    body = getattr(node, "body", None)
    if not body:
        return None
    target_name = function_to_optimize.function_name
    for child in body:
        if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) and child.name == target_name:
            return child
    return None
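Reviewer note: a sketch of why find_target_node improves on the old ast.walk scan it replaces below — it follows the recorded parents, so a same-named function elsewhere in the module no longer shadows the real target. The `fto` instance is hypothetical.

    import ast

    tree = ast.parse(
        "def compute():\n"
        "    return 0\n"
        "class Worker:\n"
        "    async def compute(self):\n"
        "        return 1\n"
    )
    # fto records function_name="compute" with parents=[FunctionParent("Worker", "ClassDef")].
    target = find_target_node(tree, fto)
    # The old walk matched the top-level `compute` first (and never matched async defs);
    # this resolves the method inside the right class, async or not.
    assert isinstance(target, ast.AsyncFunctionDef)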
def detect_unused_helper_functions(
    function_to_optimize: FunctionToOptimize,
    code_context: CodeOptimizationContext,
@@ -641,11 +669,7 @@ def detect_unused_helper_functions(
    optimized_ast = ast.parse(optimized_code)

    # Find the optimized entrypoint function
    entrypoint_function_ast = None
    for node in ast.walk(optimized_ast):
        if isinstance(node, ast.FunctionDef) and node.name == function_to_optimize.function_name:
            entrypoint_function_ast = node
            break
    entrypoint_function_ast = find_target_node(optimized_ast, function_to_optimize)

    if not entrypoint_function_ast:
        logger.debug(f"Could not find entrypoint function {function_to_optimize.function_name} in optimized code")
@@ -2,6 +2,7 @@
from __future__ import annotations

import ast
import enum
import hashlib
import os
import pickle
@@ -11,12 +12,11 @@ import subprocess
import unittest
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional, final

if TYPE_CHECKING:
    from codeflash.discovery.functions_to_optimize import FunctionToOptimize

import pytest
from pydantic.dataclasses import dataclass
from rich.panel import Panel
from rich.text import Text
@@ -35,6 +35,22 @@ if TYPE_CHECKING:
    from codeflash.verification.verification_utils import TestConfig


@final
class PytestExitCode(enum.IntEnum):  # don't need to import entire pytest just for this
    #: Tests passed.
    OK = 0
    #: Tests failed.
    TESTS_FAILED = 1
    #: pytest was interrupted.
    INTERRUPTED = 2
    #: An internal error got in the way.
    INTERNAL_ERROR = 3
    #: pytest was misused.
    USAGE_ERROR = 4
    #: pytest couldn't find tests.
    NO_TESTS_COLLECTED = 5
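Reviewer note: this enum mirrors pytest's documented exit codes, so the name lookup used below works without importing pytest here; it does need to be kept in sync if pytest ever adds codes.

    assert PytestExitCode(5) is PytestExitCode.NO_TESTS_COLLECTED
    assert PytestExitCode(2).name == "INTERRUPTED"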
@dataclass(frozen=True)
class TestFunction:
    function_name: str

@@ -412,7 +428,7 @@ def discover_tests_pytest(
    error_section = match.group(1) if match else result.stdout

    logger.warning(
        f"Failed to collect tests. Pytest Exit code: {exitcode}={pytest.ExitCode(exitcode).name}\n {error_section}"
        f"Failed to collect tests. Pytest Exit code: {exitcode}={PytestExitCode(exitcode).name}\n {error_section}"
    )
    if "ModuleNotFoundError" in result.stdout:
        match = ImportErrorPattern.search(result.stdout).group()
@@ -420,7 +436,7 @@ def discover_tests_pytest(
        console.print(panel)

    elif 0 <= exitcode <= 5:
        logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={pytest.ExitCode(exitcode).name}")
        logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={PytestExitCode(exitcode).name}")
    else:
        logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}")
    console.rule()
@@ -86,6 +86,7 @@ class FunctionVisitor(cst.CSTVisitor):
                    parents=list(reversed(ast_parents)),
                    starting_line=pos.start.line,
                    ending_line=pos.end.line,
                    is_async=bool(node.asynchronous),
                )
            )

@@ -103,6 +104,15 @@ class FunctionWithReturnStatement(ast.NodeVisitor):
                FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
            )

    def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
        # Check if the async function has a return statement and add it to the list
        if function_has_return_statement(node) and not function_is_a_property(node):
            self.functions.append(
                FunctionToOptimize(
                    function_name=node.name, file_path=self.file_path, parents=self.ast_path[:], is_async=True
                )
            )

    def generic_visit(self, node: ast.AST) -> None:
        if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):
            self.ast_path.append(FunctionParent(node.name, node.__class__.__name__))
@@ -122,6 +132,7 @@ class FunctionToOptimize:
        parents: A list of parent scopes, which could be classes or functions.
        starting_line: The starting line number of the function in the file.
        ending_line: The ending line number of the function in the file.
        is_async: Whether this function is defined as async.

    The qualified_name property provides the full name of the function, including
    any parent class or function names. The qualified_name_with_modules_from_root
@@ -134,6 +145,7 @@ class FunctionToOptimize:
    parents: list[FunctionParent]  # list[ClassDef | FunctionDef | AsyncFunctionDef]
    starting_line: Optional[int] = None
    ending_line: Optional[int] = None
    is_async: bool = False

    @property
    def top_level_parent_name(self) -> str:
@@ -147,7 +159,11 @@ class FunctionToOptimize:

    @property
    def qualified_name(self) -> str:
        return self.function_name if self.parents == [] else f"{self.parents[0].name}.{self.function_name}"
        if not self.parents:
            return self.function_name
        # Join all parent names with dots to handle nested classes properly
        parent_path = ".".join(parent.name for parent in self.parents)
        return f"{parent_path}.{self.function_name}"

    def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
        return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
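Reviewer note: the behavioral change is easiest to see with a nested class, since the old expression used only parents[0] and silently dropped inner scopes. A sketch with a hypothetical instance:

    fto = FunctionToOptimize(
        function_name="run",
        file_path=Path("pkg/mod.py"),
        parents=[FunctionParent("Outer", "ClassDef"), FunctionParent("Inner", "ClassDef")],
    )
    # Old behavior: "Outer.run" (Inner dropped). New behavior:
    assert fto.qualified_name == "Outer.Inner.run"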
@@ -163,6 +179,8 @@ def get_functions_to_optimize(
    project_root: Path,
    module_root: Path,
    previous_checkpoint_functions: dict[str, dict[str, str]] | None = None,
    *,
    enable_async: bool = False,
) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
    assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, (
        "Only one of optimize_all, replay_test, or file should be provided"
@@ -216,7 +234,13 @@ def get_functions_to_optimize(
            ph("cli-optimizing-git-diff")
            functions = get_functions_within_git_diff(uncommitted_changes=False)
        filtered_modified_functions, functions_count = filter_functions(
            functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions
            functions,
            test_cfg.tests_root,
            ignore_paths,
            project_root,
            module_root,
            previous_checkpoint_functions,
            enable_async=enable_async,
        )

        logger.info(f"!lsp|Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
@@ -411,11 +435,27 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
            )
        )

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        if self.class_name is None and node.name == self.function_name:
            self.is_top_level = True
            self.function_has_args = any(
                (
                    bool(node.args.args),
                    bool(node.args.kwonlyargs),
                    bool(node.args.kwarg),
                    bool(node.args.posonlyargs),
                    bool(node.args.vararg),
                )
            )

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        # iterate over the class methods
        if node.name == self.class_name:
            for body_node in node.body:
                if isinstance(body_node, ast.FunctionDef) and body_node.name == self.function_name:
                if (
                    isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
                    and body_node.name == self.function_name
                ):
                    self.is_top_level = True
                    if any(
                        isinstance(decorator, ast.Name) and decorator.id == "classmethod"
@@ -433,7 +473,7 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
            # This way, if we don't have the class name, we can still find the static method
            for body_node in node.body:
                if (
                    isinstance(body_node, ast.FunctionDef)
                    isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
                    and body_node.name == self.function_name
                    and body_node.lineno in {self.line_no, self.line_no + 1}
                    and any(
@@ -535,7 +575,9 @@ def filter_functions(
    project_root: Path,
    module_root: Path,
    previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None,
    disable_logs: bool = False,  # noqa: FBT001, FBT002
    *,
    disable_logs: bool = False,
    enable_async: bool = False,
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
    filtered_modified_functions: dict[str, list[FunctionToOptimize]] = {}
    blocklist_funcs = get_blocklisted_functions()
@@ -555,6 +597,7 @@ def filter_functions(
    submodule_ignored_paths_count: int = 0
    blocklist_funcs_removed_count: int = 0
    previous_checkpoint_functions_removed_count: int = 0
    async_functions_removed_count: int = 0
    tests_root_str = str(tests_root)
    module_root_str = str(module_root)

@@ -610,6 +653,15 @@ def filter_functions(
            functions_tmp.append(function)
        _functions = functions_tmp

        if not enable_async:
            functions_tmp = []
            for function in _functions:
                if function.is_async:
                    async_functions_removed_count += 1
                    continue
                functions_tmp.append(function)
            _functions = functions_tmp

        filtered_modified_functions[file_path] = _functions
        functions_count += len(_functions)

@@ -623,6 +675,7 @@ def filter_functions(
        "Files from ignored submodules": (submodule_ignored_paths_count, "bright_black"),
        "Blocklisted functions removed": (blocklist_funcs_removed_count, "bright_red"),
        "Functions skipped from checkpoint": (previous_checkpoint_functions_removed_count, "green"),
        "Async functions removed": (async_functions_removed_count, "bright_magenta"),
    }
    tree = Tree(Text("Ignored functions and files", style="bold"))
    for label, (count, color) in log_info.items():
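Reviewer note: a sketch of the observable behavior of the new keyword-only flag; the argument values are illustrative. With the default, async candidates are dropped and counted; opting in keeps them.

    # Default: async functions are filtered out and reported under
    # "Async functions removed" in the "Ignored functions and files" tree.
    funcs_by_file, count = filter_functions(
        discovered, tests_root, ignore_paths, project_root, module_root, enable_async=False
    )

    # Opt-in keeps async functions in the optimization set.
    funcs_by_file, count = filter_functions(
        discovered, tests_root, ignore_paths, project_root, module_root, enable_async=True
    )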
@@ -12,11 +12,8 @@ from pygls import uris
from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
from codeflash.cli_cmds.cli import process_pyproject_config
from codeflash.cli_cmds.console import code_print
from codeflash.code_utils.git_worktree_utils import (
    create_diff_patch_from_worktree,
    get_patches_metadata,
    overwrite_patch_metadata,
)
from codeflash.code_utils.git_utils import git_root_dir
from codeflash.code_utils.git_worktree_utils import create_diff_patch_from_worktree
from codeflash.code_utils.shell_utils import save_api_key_to_rc
from codeflash.discovery.functions_to_optimize import (
    filter_functions,
@@ -39,10 +36,17 @@ class OptimizableFunctionsParams:
    textDocument: types.TextDocumentIdentifier  # noqa: N815


@dataclass
class FunctionOptimizationInitParams:
    textDocument: types.TextDocumentIdentifier  # noqa: N815
    functionName: str  # noqa: N815


@dataclass
class FunctionOptimizationParams:
    textDocument: types.TextDocumentIdentifier  # noqa: N815
    functionName: str  # noqa: N815
    task_id: str


@dataclass
@@ -59,7 +63,7 @@ class ValidateProjectParams:

@dataclass
class OnPatchAppliedParams:
    patch_id: str
    task_id: str


@dataclass
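Reviewer note: with the params split, a client drives optimization in two requests and correlates them with task_id. A sketch of the payloads as Python dicts; the URI and id values are made up.

    init_params = {  # initializeFunctionOptimization -> FunctionOptimizationInitParams
        "textDocument": {"uri": "file:///workspace/pkg/mod.py"},
        "functionName": "Outer.run",
    }
    perform_params = {  # performFunctionOptimization -> FunctionOptimizationParams
        "textDocument": {"uri": "file:///workspace/pkg/mod.py"},
        "functionName": "Outer.run",
        "task_id": "task-42",  # echoed back in the success response
    }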
@@ -111,7 +115,6 @@ def get_optimizable_functions(
    server: CodeflashLanguageServer, params: OptimizableFunctionsParams
) -> dict[str, list[str]]:
    file_path = Path(uris.to_fs_path(params.textDocument.uri))
    server.show_message_log(f"Getting optimizable functions for: {file_path}", "Info")
    if not server.optimizer:
        return {"status": "error", "message": "optimizer not initialized"}

@@ -119,55 +122,15 @@ def get_optimizable_functions(
    server.optimizer.args.function = None  # Always get ALL functions, not just one
    server.optimizer.args.previous_checkpoint_functions = False

    server.show_message_log(f"Calling get_optimizable_functions for {server.optimizer.args.file}...", "Info")
    optimizable_funcs, _, _ = server.optimizer.get_optimizable_functions()

    path_to_qualified_names = {}
    for functions in optimizable_funcs.values():
        path_to_qualified_names[file_path] = [func.qualified_name for func in functions]

    server.show_message_log(
        f"Found {len(path_to_qualified_names)} files with functions: {path_to_qualified_names}", "Info"
    )
    return path_to_qualified_names


@server.feature("initializeFunctionOptimization")
def initialize_function_optimization(
    server: CodeflashLanguageServer, params: FunctionOptimizationParams
) -> dict[str, str]:
    file_path = Path(uris.to_fs_path(params.textDocument.uri))
    server.show_message_log(f"Initializing optimization for function: {params.functionName} in {file_path}", "Info")

    if server.optimizer is None:
        _initialize_optimizer_if_api_key_is_valid(server)

    server.optimizer.worktree_mode()

    original_args, _ = server.optimizer.original_args_and_test_cfg

    server.optimizer.args.function = params.functionName
    original_relative_file_path = file_path.relative_to(original_args.project_root)
    server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
    server.optimizer.args.previous_checkpoint_functions = False

    server.show_message_log(
        f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
    )

    optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()

    if count == 0:
        server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
        server.cleanup_the_optimizer()
        return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}

    fto = optimizable_funcs.popitem()[1][0]
    server.optimizer.current_function_being_optimized = fto
    server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
    return {"functionName": params.functionName, "status": "success"}


def _find_pyproject_toml(workspace_path: str) -> Path | None:
    workspace_path_obj = Path(workspace_path)
    max_depth = 2
@@ -207,13 +170,18 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
    if pyproject_toml_path:
        server.prepare_optimizer_arguments(pyproject_toml_path)
    else:
        return {
            "status": "error",
            "message": "No pyproject.toml found in workspace.",
        }  # TODO: enhancec this message to say there is not tool.codeflash in pyproject.toml or smth
        return {"status": "error", "message": "No pyproject.toml found in workspace."}

    # since we are using worktrees, optimization diffs are generated with respect to the root of the repo.
    root = str(git_root_dir())

    if getattr(params, "skip_validation", False):
        return {"status": "success", "moduleRoot": server.args.module_root, "pyprojectPath": pyproject_toml_path}
        return {
            "status": "success",
            "moduleRoot": server.args.module_root,
            "pyprojectPath": pyproject_toml_path,
            "root": root,
        }

    server.show_message_log("Validating project...", "Info")
    config = is_valid_pyproject_toml(pyproject_toml_path)
@@ -234,7 +202,7 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
    except Exception:
        return {"status": "error", "message": "Repository has no commits (unborn HEAD)"}

    return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path}
    return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path, "root": root}


def _initialize_optimizer_if_api_key_is_valid(
@@ -296,78 +264,85 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
        return {"status": "error", "message": "something went wrong while saving the api key"}


@server.feature("retrieveSuccessfulOptimizations")
def retrieve_successful_optimizations(_server: CodeflashLanguageServer, _params: any) -> dict[str, str]:
    metadata = get_patches_metadata()
    return {"status": "success", "patches": metadata["patches"]}
@server.feature("initializeFunctionOptimization")
def initialize_function_optimization(
    server: CodeflashLanguageServer, params: FunctionOptimizationInitParams
) -> dict[str, str]:
    file_path = Path(uris.to_fs_path(params.textDocument.uri))
    server.show_message_log(f"Initializing optimization for function: {params.functionName} in {file_path}", "Info")

    if server.optimizer is None:
        _initialize_optimizer_if_api_key_is_valid(server)

@server.feature("onPatchApplied")
def on_patch_applied(_server: CodeflashLanguageServer, params: OnPatchAppliedParams) -> dict[str, str]:
    # first remove the patch from the metadata
    metadata = get_patches_metadata()
    server.optimizer.worktree_mode()

    deleted_patch_file = None
    new_patches = []
    for patch in metadata["patches"]:
        if patch["id"] == params.patch_id:
            deleted_patch_file = patch["patch_path"]
            continue
        new_patches.append(patch)
    original_args, _ = server.optimizer.original_args_and_test_cfg

    # then remove the patch file
    if deleted_patch_file:
        overwrite_patch_metadata(new_patches)
        patch_path = Path(deleted_patch_file)
        patch_path.unlink(missing_ok=True)
        return {"status": "success"}
    return {"status": "error", "message": "Patch not found"}
    server.optimizer.args.function = params.functionName
    original_relative_file_path = file_path.relative_to(original_args.project_root)
    server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
    server.optimizer.args.previous_checkpoint_functions = False

    server.show_message_log(
        f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
    )

    optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()

    if count == 0:
        server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
        server.cleanup_the_optimizer()
        return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}

    fto = optimizable_funcs.popitem()[1][0]

    module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
    if not module_prep_result:
        return {
            "functionName": params.functionName,
            "status": "error",
            "message": "Failed to prepare module for optimization",
        }

    validated_original_code, original_module_ast = module_prep_result

    function_optimizer = server.optimizer.create_function_optimizer(
        fto,
        function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
        original_module_ast=original_module_ast,
        original_module_path=fto.file_path,
        function_to_tests={},
    )

    server.optimizer.current_function_optimizer = function_optimizer
    if not function_optimizer:
        return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}

    initialization_result = function_optimizer.can_be_optimized()
    if not is_successful(initialization_result):
        return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}

    server.current_optimization_init_result = initialization_result.unwrap()
    server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")

    files = [function_optimizer.function_to_optimize.file_path]

    _, _, original_helpers = server.current_optimization_init_result
    files.extend([str(helper_path) for helper_path in original_helpers])

    return {"functionName": params.functionName, "status": "success", "files_inside_context": files}
@server.feature("performFunctionOptimization")
|
||||
@server.thread()
|
||||
def perform_function_optimization( # noqa: PLR0911
|
||||
def perform_function_optimization(
|
||||
server: CodeflashLanguageServer, params: FunctionOptimizationParams
|
||||
) -> dict[str, str]:
|
||||
try:
|
||||
server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
|
||||
current_function = server.optimizer.current_function_being_optimized
|
||||
|
||||
if not current_function:
|
||||
server.show_message_log(f"No current function being optimized for {params.functionName}", "Error")
|
||||
return {
|
||||
"functionName": params.functionName,
|
||||
"status": "error",
|
||||
"message": "No function currently being optimized",
|
||||
}
|
||||
|
||||
module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)
|
||||
if not module_prep_result:
|
||||
return {
|
||||
"functionName": params.functionName,
|
||||
"status": "error",
|
||||
"message": "Failed to prepare module for optimization",
|
||||
}
|
||||
|
||||
validated_original_code, original_module_ast = module_prep_result
|
||||
|
||||
function_optimizer = server.optimizer.create_function_optimizer(
|
||||
current_function,
|
||||
function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code,
|
||||
original_module_ast=original_module_ast,
|
||||
original_module_path=current_function.file_path,
|
||||
function_to_tests={},
|
||||
)
|
||||
|
||||
server.optimizer.current_function_optimizer = function_optimizer
|
||||
if not function_optimizer:
|
||||
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
|
||||
|
||||
initialization_result = function_optimizer.can_be_optimized()
|
||||
if not is_successful(initialization_result):
|
||||
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
|
||||
|
||||
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
|
||||
should_run_experiment, code_context, original_helper_code = server.current_optimization_init_result
|
||||
function_optimizer = server.optimizer.current_function_optimizer
|
||||
current_function = function_optimizer.function_to_optimize
|
||||
|
||||
code_print(
|
||||
code_context.read_writable_code.flat,
|
||||
|
|
@ -447,22 +422,17 @@ def perform_function_optimization( # noqa: PLR0911
|
|||
|
||||
speedup = original_code_baseline.runtime / best_optimization.runtime
|
||||
|
||||
# get the original file path in the actual project (not in the worktree)
|
||||
original_args, _ = server.optimizer.original_args_and_test_cfg
|
||||
relative_file_path = current_function.file_path.relative_to(server.optimizer.current_worktree)
|
||||
original_file_path = Path(original_args.project_root / relative_file_path).resolve()
|
||||
|
||||
metadata = create_diff_patch_from_worktree(
|
||||
server.optimizer.current_worktree,
|
||||
relative_file_paths,
|
||||
metadata_input={
|
||||
"fto_name": function_to_optimize_qualified_name,
|
||||
"explanation": best_optimization.explanation_v2,
|
||||
"file_path": str(original_file_path),
|
||||
"speedup": speedup,
|
||||
},
|
||||
patch_path = create_diff_patch_from_worktree(
|
||||
server.optimizer.current_worktree, relative_file_paths, function_to_optimize_qualified_name
|
||||
)
|
||||
|
||||
if not patch_path:
|
||||
return {
|
||||
"functionName": params.functionName,
|
||||
"status": "error",
|
||||
"message": "Failed to create a patch for optimization",
|
||||
}
|
||||
|
||||
server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")
|
||||
|
||||
return {
|
||||
|
|
@ -470,8 +440,8 @@ def perform_function_optimization( # noqa: PLR0911
|
|||
"status": "success",
|
||||
"message": "Optimization completed successfully",
|
||||
"extra": f"Speedup: {speedup:.2f}x faster",
|
||||
"patch_file": metadata["patch_path"],
|
||||
"patch_id": metadata["id"],
|
||||
"patch_file": str(patch_path),
|
||||
"task_id": params.task_id,
|
||||
"explanation": best_optimization.explanation_v2,
|
||||
}
|
||||
finally:
|
||||
|
|
|
|||
|
|
@@ -29,15 +29,15 @@ def tree_to_markdown(tree: Tree, level: int = 0) -> str:


def report_to_markdown_table(report: dict[TestType, dict[str, int]], title: str) -> str:
    lines = ["| Test Type | Passed ✅ | Failed ❌ |", "|-----------|--------|--------|"]
    lines = ["| Test Type | Passed ✅ |", "|-----------|--------|"]
    for test_type in TestType:
        if test_type is TestType.INIT_STATE_TEST:
            continue
        passed = report[test_type]["passed"]
        failed = report[test_type]["failed"]
        if passed == 0 and failed == 0:
        # failed = report[test_type]["failed"]
        if passed == 0:
            continue
        lines.append(f"| {test_type.to_name()} | {passed} | {failed} |")
        lines.append(f"| {test_type.to_name()} | {passed} |")
    table = "\n".join(lines)
    if title:
        return f"### {title}\n{table}"
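Reviewer note: with the Failed column dropped, a rendered report now looks like this; the title and counts are sample values, and the test-type label depends on TestType.to_name().

    ### Existing tests
    | Test Type | Passed ✅ |
    |-----------|--------|
    | Existing Unit Tests | 12 |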
@@ -3,10 +3,10 @@ from __future__ import annotations
import logging
import sys
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Any, Callable

from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.lsp.lsp_message import LspTextMessage
from codeflash.lsp.lsp_message import LspTextMessage, message_delimiter

root_logger = None
@@ -43,16 +43,19 @@ def add_heading_tags(msg: str, tags: LspMessageTags) -> str:
    return msg


def extract_tags(msg: str) -> tuple[Optional[LspMessageTags], str]:
def extract_tags(msg: str) -> tuple[LspMessageTags, str]:
    delimiter = "|"
    parts = msg.split(delimiter)
    if len(parts) == 2:
    first_delim_idx = msg.find(delimiter)
    if first_delim_idx != -1 and msg.count(delimiter) == 1:
        tags_str = msg[:first_delim_idx]
        content = msg[first_delim_idx + 1 :]
        tags = {tag.strip() for tag in tags_str.split(",")}
        message_tags = LspMessageTags()
        tags = {tag.strip() for tag in parts[0].split(",")}
        if "!lsp" in tags:
            message_tags.not_lsp = True
        # manually check and set to avoid repeated membership tests
        if "lsp" in tags:
            message_tags.lsp = True
        if "!lsp" in tags:
            message_tags.not_lsp = True
        if "force_lsp" in tags:
            message_tags.force_lsp = True
        if "loading" in tags:
@@ -67,9 +70,9 @@ def extract_tags(msg: str) -> tuple[Optional[LspMessageTags], str]:
            message_tags.h3 = True
        if "h4" in tags:
            message_tags.h4 = True
        return message_tags, delimiter.join(parts[1:])
        return message_tags, content

    return None, msg
    return LspMessageTags(), msg
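Reviewer note: the visible changes are that "lsp" is now a recognized tag and that untagged messages come back with a default LspMessageTags instead of None, so callers can drop their None checks. A sketch:

    tags, text = extract_tags("!lsp|Found 3 functions to optimize")
    assert tags.not_lsp is True and text == "Found 3 functions to optimize"

    tags, text = extract_tags("plain message")  # no single "|" delimiter
    assert isinstance(tags, LspMessageTags) and text == "plain message"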
supported_lsp_log_levels = ("info", "debug")

@@ -86,31 +89,32 @@ def enhanced_log(
        actual_log_fn(msg, *args, **kwargs)
        return

    is_lsp_json_message = msg.startswith('{"type"')
    is_lsp_json_message = msg.startswith(message_delimiter) and msg.endswith(message_delimiter)
    is_normal_text_message = not is_lsp_json_message

    # extract tags only from the text messages (not the json ones)
    tags, clean_msg = extract_tags(msg) if is_normal_text_message else (None, msg)
    # Extract tags only from text messages
    tags, clean_msg = extract_tags(msg) if is_normal_text_message else (LspMessageTags(), msg)

    lsp_enabled = is_LSP_enabled()
    lsp_only = tags and tags.lsp

    if not lsp_enabled and not lsp_only:
        # normal logging
        actual_log_fn(clean_msg, *args, **kwargs)
        return

    #### LSP mode ####
    final_tags = tags if tags else LspMessageTags()

    unsupported_level = level not in supported_lsp_log_levels
    if not final_tags.force_lsp and (final_tags.not_lsp or unsupported_level):
        return

    # ---- Normal logging path ----
    if not tags.lsp:
        if not lsp_enabled:  # LSP disabled
            actual_log_fn(clean_msg, *args, **kwargs)
            return
        if tags.not_lsp:  # explicitly marked as not for LSP
            actual_log_fn(clean_msg, *args, **kwargs)
            return
        if unsupported_level and not tags.force_lsp:  # unsupported level
            actual_log_fn(clean_msg, *args, **kwargs)
            return

    # ---- LSP logging path ----
    if is_normal_text_message:
        clean_msg = add_heading_tags(clean_msg, final_tags)
        clean_msg = add_highlight_tags(clean_msg, final_tags)
        clean_msg = LspTextMessage(text=clean_msg, takes_time=final_tags.loading).serialize()
        clean_msg = add_heading_tags(clean_msg, tags)
        clean_msg = add_highlight_tags(clean_msg, tags)
        clean_msg = LspTextMessage(text=clean_msg, takes_time=tags.loading).serialize()

    actual_log_fn(clean_msg, *args, **kwargs)
@@ -10,6 +10,9 @@ from codeflash.lsp.helpers import replace_quotes_with_backticks, simplify_worktr
json_primitive_types = (str, float, int, bool)
max_code_lines_before_collapse = 45

# \u241F is the message delimiter because more than one message can arrive over the same channel, so we need something to separate each message
message_delimiter = "\u241f"


@dataclass
class LspMessage:

@@ -32,12 +35,8 @@ class LspMessage:

    def serialize(self) -> str:
        data = self._loop_through(asdict(self))
        # Important: keep type as the first key, for making it easy and fast for the client to know if this is a lsp message before parsing it
        ordered = {"type": self.type(), **data}
        return (
            json.dumps(ordered)
            + "\u241f"  # \u241F is the message delimiter becuase it can be more than one message sent over the same message, so we need something to separate each message
        )
        return message_delimiter + json.dumps(ordered) + message_delimiter


@dataclass
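Reviewer note: wrapping both ends of the payload means a receiver can frame a concatenated stream just by splitting on the delimiter. A minimal client-side sketch, assuming serialized messages arrive back to back:

    def split_lsp_messages(stream_text: str) -> list[str]:
        # Adjacent frames produce empty chunks between their delimiters; drop them.
        return [chunk for chunk in stream_text.split(message_delimiter) if chunk.strip()]

    frames = split_lsp_messages(msg1.serialize() + msg2.serialize())
    # Each frame is a standalone JSON document with "type" as its first key.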
@@ -9,6 +9,7 @@ from pygls.server import LanguageServer
if TYPE_CHECKING:
    from pathlib import Path

    from codeflash.models.models import CodeOptimizationContext
    from codeflash.optimization.optimizer import Optimizer


@@ -22,6 +23,7 @@ class CodeflashLanguageServer(LanguageServer):
        self.optimizer: Optimizer | None = None
        self.args_processed_before: bool = False
        self.args = None
        self.current_optimization_init_result: tuple[bool, CodeOptimizationContext, dict[Path, str]] | None = None

    def prepare_optimizer_arguments(self, config_file: Path) -> None:
        from codeflash.cli_cmds.cli import parse_args
@@ -57,6 +59,7 @@ class CodeflashLanguageServer(LanguageServer):
        self.lsp.notify("window/logMessage", log_params)

    def cleanup_the_optimizer(self) -> None:
        self.current_optimization_init_result = None
        if not self.optimizer:
            return
        try:
@@ -103,6 +103,7 @@ class BestOptimization(BaseModel):
    winning_benchmarking_test_results: TestResults
    winning_replay_benchmarking_test_results: Optional[TestResults] = None
    line_profiler_test_results: dict
    async_throughput: Optional[int] = None


@dataclass(frozen=True)
@@ -208,7 +209,7 @@ class CodeStringsMarkdown(BaseModel):
        """
        return "\n".join(
            [
                f"```python{':' + str(code_string.file_path) if code_string.file_path else ''}\n{code_string.code.strip()}\n```"
                f"```python{':' + code_string.file_path.as_posix() if code_string.file_path else ''}\n{code_string.code.strip()}\n```"
                for code_string in self.code_strings
            ]
        )
@@ -277,6 +278,7 @@ class OptimizedCandidateResult(BaseModel):
    replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None
    optimization_candidate_index: int
    total_candidate_timing: int
    async_throughput: Optional[int] = None


class GeneratedTests(BaseModel):
@@ -383,6 +385,7 @@ class OriginalCodeBaseline(BaseModel):
    line_profile_results: dict
    runtime: int
    coverage_results: Optional[CoverageData]
    async_throughput: Optional[int] = None


class CoverageStatus(Enum):
@@ -545,6 +548,7 @@ class TestResults(BaseModel):  # noqa: PLW1641
    # also we don't support deletion of test results elements - caution is advised
    test_results: list[FunctionTestInvocation] = []
    test_result_idx: dict[str, int] = {}
    perf_stdout: Optional[str] = None

    def add(self, function_test_invocation: FunctionTestInvocation) -> None:
        unique_id = function_test_invocation.unique_invocation_loop_id
@@ -36,7 +36,6 @@ from codeflash.code_utils.code_utils import (
    diff_length,
    file_name_from_test_module_name,
    get_run_tmp_file,
    has_any_async_functions,
    module_name_from_file_path,
    restore_conftest,
    unified_diff_strings,
@@ -85,14 +84,20 @@ from codeflash.models.models import (
    TestType,
)
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
from codeflash.result.critic import (
    coverage_critic,
    performance_gain,
    quantity_of_tests_critic,
    speedup_critic,
    throughput_gain,
)
from codeflash.result.explanation import Explanation
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.concolic_testing import generate_concolic_tests
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
from codeflash.verification.parse_test_output import parse_test_results
from codeflash.verification.parse_test_output import calculate_function_throughput_from_test_results, parse_test_results
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_line_profile_tests
from codeflash.verification.verification_utils import get_test_file_path
from codeflash.verification.verifier import generate_tests
@@ -199,7 +204,7 @@ class FunctionOptimizer:
        test_cfg: TestConfig,
        function_to_optimize_source_code: str = "",
        function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
        function_to_optimize_ast: ast.FunctionDef | None = None,
        function_to_optimize_ast: ast.FunctionDef | ast.AsyncFunctionDef | None = None,
        aiservice_client: AiServiceClient | None = None,
        function_benchmark_timings: dict[BenchmarkKey, int] | None = None,
        total_benchmark_timings: dict[BenchmarkKey, int] | None = None,
@@ -259,11 +264,6 @@ class FunctionOptimizer:
                helper_code = f.read()
                original_helper_code[helper_function_path] = helper_code

        async_code = any(
            has_any_async_functions(code_string.code) for code_string in code_context.read_writable_code.code_strings
        )
        if async_code:
            return Failure("Codeflash does not support async functions in the code to optimize.")
        # Random here means that we still attempt optimization with a fractional chance to see if
        # last time we could not find an optimization, maybe this time we do.
        # Random is before as a performance optimization, swapping the two 'and' statements has the same effect
@@ -304,7 +304,7 @@ class FunctionOptimizer:
        ]

        with progress_bar(
            f"Generating new tests and optimizations for function {self.function_to_optimize.function_name}",
            f"Generating new tests and optimizations for function '{self.function_to_optimize.function_name}'",
            transient=True,
            revert_to_print=bool(get_pr_number()),
        ):
@@ -587,7 +587,11 @@ class FunctionOptimizer:
        tree = Tree(f"Candidate #{candidate_index} - Runtime Information ⌛")
        benchmark_tree = None
        if speedup_critic(
            candidate_result, original_code_baseline.runtime, best_runtime_until_now=None
            candidate_result,
            original_code_baseline.runtime,
            best_runtime_until_now=None,
            original_async_throughput=original_code_baseline.async_throughput,
            best_throughput_until_now=None,
        ) and quantity_of_tests_critic(candidate_result):
            tree.add("This candidate is faster than the original code. 🚀")  # TODO: Change this description
            tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")
@@ -598,6 +602,17 @@ class FunctionOptimizer:
            )
            tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
            tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
            if (
                original_code_baseline.async_throughput is not None
                and candidate_result.async_throughput is not None
            ):
                throughput_gain_value = throughput_gain(
                    original_throughput=original_code_baseline.async_throughput,
                    optimized_throughput=candidate_result.async_throughput,
                )
                tree.add(f"Original async throughput: {original_code_baseline.async_throughput} executions")
                tree.add(f"Optimized async throughput: {candidate_result.async_throughput} executions")
                tree.add(f"Throughput improvement: {throughput_gain_value * 100:.1f}%")
            line_profile_test_results = self.line_profiler_step(
                code_context=code_context,
                original_helper_code=original_helper_code,
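Reviewer note: throughput_gain itself is not shown in this diff; from how its result is multiplied by 100 for display, it presumably reports a relative increase, by analogy with performance_gain. A hedged sketch of the assumed shape:

    def throughput_gain(*, original_throughput: int, optimized_throughput: int) -> float:
        # Assumed semantics: 150 -> 180 completed executions yields 0.2 (shown as 20.0%).
        if original_throughput == 0:
            return 0.0
        return (optimized_throughput - original_throughput) / original_throughput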
@@ -633,6 +648,7 @@ class FunctionOptimizer:
                replay_performance_gain=replay_perf_gain if self.args.benchmark else None,
                winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
                winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results,
                async_throughput=candidate_result.async_throughput,
            )
            valid_optimizations.append(best_optimization)
            # queue corresponding refined optimization for best optimization
@@ -657,6 +673,15 @@ class FunctionOptimizer:
            )
            tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
            tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
            if (
                original_code_baseline.async_throughput is not None
                and candidate_result.async_throughput is not None
            ):
                throughput_gain_value = throughput_gain(
                    original_throughput=original_code_baseline.async_throughput,
                    optimized_throughput=candidate_result.async_throughput,
                )
                tree.add(f"Throughput gain: {throughput_gain_value * 100:.1f}%")

        if is_LSP_enabled():
            lsp_log(LspMarkdownMessage(markdown=tree_to_markdown(tree)))
@@ -700,6 +725,7 @@ class FunctionOptimizer:
                replay_performance_gain=valid_opt.replay_performance_gain,
                winning_benchmarking_test_results=valid_opt.winning_benchmarking_test_results,
                winning_replay_benchmarking_test_results=valid_opt.winning_replay_benchmarking_test_results,
                async_throughput=valid_opt.async_throughput,
            )
            valid_candidates_with_shorter_code.append(new_best_opt)
            diff_lens_list.append(
@@ -1079,6 +1105,7 @@ class FunctionOptimizer:
            self.function_trace_id[:-4] + "EXP0" if run_experiment else self.function_trace_id,
            n_candidates,
            ExperimentMetadata(id=self.experiment_id, group="control") if run_experiment else None,
            is_async=self.function_to_optimize.is_async,
        )
        future_candidates_exp = None

@@ -1094,6 +1121,7 @@ class FunctionOptimizer:
            self.function_trace_id[:-4] + "EXP1",
            n_candidates,
            ExperimentMetadata(id=self.experiment_id, group="experiment"),
            is_async=self.function_to_optimize.is_async,
        )
        futures.append(future_candidates_exp)

@@ -1280,6 +1308,8 @@ class FunctionOptimizer:
            function_name=function_to_optimize_qualified_name,
            file_path=self.function_to_optimize.file_path,
            benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None,
            original_async_throughput=original_code_baseline.async_throughput,
            best_async_throughput=best_optimization.async_throughput,
        )

        self.replace_function_and_helpers_with_optimized_code(
@@ -1362,6 +1392,23 @@ class FunctionOptimizer:
            original_runtimes_all=original_runtime_by_test,
            optimized_runtimes_all=optimized_runtime_by_test,
        )
        original_throughput_str = None
        optimized_throughput_str = None
        throughput_improvement_str = None

        if (
            self.function_to_optimize.is_async
            and original_code_baseline.async_throughput is not None
            and best_optimization.async_throughput is not None
        ):
            original_throughput_str = f"{original_code_baseline.async_throughput} operations/second"
            optimized_throughput_str = f"{best_optimization.async_throughput} operations/second"
            throughput_improvement_value = throughput_gain(
                original_throughput=original_code_baseline.async_throughput,
                optimized_throughput=best_optimization.async_throughput,
            )
            throughput_improvement_str = f"{throughput_improvement_value * 100:.1f}%"

        new_explanation_raw_str = self.aiservice_client.get_new_explanation(
            source_code=code_context.read_writable_code.flat,
            dependency_code=code_context.read_only_context_code,
@@ -1375,6 +1422,9 @@ class FunctionOptimizer:
            annotated_tests=generated_tests_str,
            optimization_id=best_optimization.candidate.optimization_id,
            original_explanation=best_optimization.candidate.explanation,
            original_throughput=original_throughput_str,
            optimized_throughput=optimized_throughput_str,
            throughput_improvement=throughput_improvement_str,
        )
        new_explanation = Explanation(
            raw_explanation_message=new_explanation_raw_str or explanation.raw_explanation_message,
@@ -1385,6 +1435,8 @@ class FunctionOptimizer:
            function_name=explanation.function_name,
            file_path=explanation.file_path,
            benchmark_details=explanation.benchmark_details,
            original_async_throughput=explanation.original_async_throughput,
            best_async_throughput=explanation.best_async_throughput,
        )
        self.log_successful_optimization(new_explanation, generated_tests, exp_type)
@ -1475,6 +1527,17 @@ class FunctionOptimizer:
|
|||
|
||||
test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)
|
||||
|
||||
if self.function_to_optimize.is_async:
|
||||
from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators
|
||||
|
||||
success, instrumented_source = instrument_source_module_with_async_decorators(
|
||||
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
|
||||
)
|
||||
if success and instrumented_source:
|
||||
with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
|
||||
f.write(instrumented_source)
|
||||
logger.debug(f"Applied async instrumentation to {self.function_to_optimize.file_path}")
|
||||
|
||||
# Instrument codeflash capture
|
||||
with progress_bar("Running tests to establish original code behavior..."):
|
||||
try:
|
||||
|
|
@@ -1514,15 +1577,38 @@ class FunctionOptimizer:
            )
            console.rule()
            with progress_bar("Running performance benchmarks..."):
                benchmarking_results, _ = self.run_and_parse_tests(
                    testing_type=TestingMode.PERFORMANCE,
                    test_env=test_env,
                    test_files=self.test_files,
                    optimization_iteration=0,
                    testing_time=total_looping_time,
                    enable_coverage=False,
                    code_context=code_context,
                )
                if self.function_to_optimize.is_async:
                    from codeflash.code_utils.instrument_existing_tests import (
                        instrument_source_module_with_async_decorators,
                    )

                    success, instrumented_source = instrument_source_module_with_async_decorators(
                        self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
                    )
                    if success and instrumented_source:
                        with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
                            f.write(instrumented_source)
                        logger.debug(
                            f"Applied async performance instrumentation to {self.function_to_optimize.file_path}"
                        )

                try:
                    benchmarking_results, _ = self.run_and_parse_tests(
                        testing_type=TestingMode.PERFORMANCE,
                        test_env=test_env,
                        test_files=self.test_files,
                        optimization_iteration=0,
                        testing_time=total_looping_time,
                        enable_coverage=False,
                        code_context=code_context,
                    )
                finally:
                    if self.function_to_optimize.is_async:
                        self.write_code_and_helpers(
                            self.function_to_optimize_source_code,
                            original_helper_code,
                            self.function_to_optimize.file_path,
                        )
        else:
            benchmarking_results = TestResults()
        start_time: float = time.time()
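The hunk above follows an instrument → run → restore discipline that recurs throughout this commit. The same pattern in miniature, as a sketch with hypothetical helpers (run_benchmarks stands in for run_and_parse_tests; original_source is the pristine file content kept aside earlier):

# Instrument the async function's module, benchmark it, and always undo the
# rewrite, even when the test run raises.
success, instrumented = instrument_source_module_with_async_decorators(
    path, function_to_optimize, TestingMode.PERFORMANCE
)
if success and instrumented:
    path.write_text(instrumented, encoding="utf8")
try:
    results = run_benchmarks()  # hypothetical stand-in for run_and_parse_tests
finally:
    path.write_text(original_source, encoding="utf8")  # restore pristine source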
@@ -1570,12 +1656,20 @@ class FunctionOptimizer:

        loop_count = max([int(result.loop_index) for result in benchmarking_results.test_results])
        logger.info(
            f"h2|⌚ Original code summed runtime measured over {loop_count} loop{'s' if loop_count > 1 else ''}: "
            f"{humanize_runtime(total_timing)} per full loop"
            f"h3|⌚ Original code summed runtime measured over '{loop_count}' loop{'s' if loop_count > 1 else ''}: "
            f"'{humanize_runtime(total_timing)}' per full loop"
        )
        console.rule()
        logger.debug(f"Total original code runtime (ns): {total_timing}")

        async_throughput = None
        if self.function_to_optimize.is_async:
            async_throughput = calculate_function_throughput_from_test_results(
                benchmarking_results, self.function_to_optimize.function_name
            )
            logger.debug(f"Original async function throughput: {async_throughput} calls/second")
            console.rule()

        if self.args.benchmark:
            replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
                self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root

@@ -1589,6 +1683,7 @@ class FunctionOptimizer:
                runtime=total_timing,
                coverage_results=coverage_results,
                line_profile_results=line_profile_results,
                async_throughput=async_throughput,
            ),
            functions_to_remove,
        )

@@ -1617,6 +1712,21 @@ class FunctionOptimizer:
        candidate_helper_code = {}
        for module_abspath in original_helper_code:
            candidate_helper_code[module_abspath] = Path(module_abspath).read_text("utf-8")
        if self.function_to_optimize.is_async:
            from codeflash.code_utils.instrument_existing_tests import (
                instrument_source_module_with_async_decorators,
            )

            success, instrumented_source = instrument_source_module_with_async_decorators(
                self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR
            )
            if success and instrumented_source:
                with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
                    f.write(instrumented_source)
                logger.debug(
                    f"Applied async behavioral instrumentation to {self.function_to_optimize.file_path} for candidate {optimization_candidate_index}"
                )

        try:
            instrument_codeflash_capture(
                self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root

@@ -1654,14 +1764,37 @@ class FunctionOptimizer:
                logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")

                if test_framework == "pytest":
                    candidate_benchmarking_results, _ = self.run_and_parse_tests(
                        testing_type=TestingMode.PERFORMANCE,
                        test_env=test_env,
                        test_files=self.test_files,
                        optimization_iteration=optimization_candidate_index,
                        testing_time=total_looping_time,
                        enable_coverage=False,
                    )
                    # For async functions, instrument at definition site for performance benchmarking
                    if self.function_to_optimize.is_async:
                        from codeflash.code_utils.instrument_existing_tests import (
                            instrument_source_module_with_async_decorators,
                        )

                        success, instrumented_source = instrument_source_module_with_async_decorators(
                            self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
                        )
                        if success and instrumented_source:
                            with self.function_to_optimize.file_path.open("w", encoding="utf8") as f:
                                f.write(instrumented_source)
                            logger.debug(
                                f"Applied async performance instrumentation to {self.function_to_optimize.file_path} for candidate {optimization_candidate_index}"
                            )

                    try:
                        candidate_benchmarking_results, _ = self.run_and_parse_tests(
                            testing_type=TestingMode.PERFORMANCE,
                            test_env=test_env,
                            test_files=self.test_files,
                            optimization_iteration=optimization_candidate_index,
                            testing_time=total_looping_time,
                            enable_coverage=False,
                        )
                    finally:
                        # Restore original source if we instrumented it
                        if self.function_to_optimize.is_async:
                            self.write_code_and_helpers(
                                candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
                            )
                    loop_count = (
                        max(all_loop_indices)
                        if (

@@ -1697,6 +1830,14 @@ class FunctionOptimizer:
                    console.rule()

                    logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")

                    candidate_async_throughput = None
                    if self.function_to_optimize.is_async:
                        candidate_async_throughput = calculate_function_throughput_from_test_results(
                            candidate_benchmarking_results, self.function_to_optimize.function_name
                        )
                        logger.debug(f"Candidate async function throughput: {candidate_async_throughput} calls/second")

                    if self.args.benchmark:
                        candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(
                            self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root

@@ -1716,6 +1857,7 @@ class FunctionOptimizer:
                            else None,
                            optimization_candidate_index=optimization_candidate_index,
                            total_candidate_timing=total_candidate_timing,
                            async_throughput=candidate_async_throughput,
                        )
                    )

@@ -1807,8 +1949,10 @@ class FunctionOptimizer:
                coverage_database_file=coverage_database_file,
                coverage_config_file=coverage_config_file,
            )
        else:
            results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
            if testing_type == TestingMode.PERFORMANCE:
                results.perf_stdout = run_result.stdout
            return results, coverage_results
        results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
        return results, coverage_results

    def submit_test_generation_tasks(

@@ -1864,7 +2008,6 @@ class FunctionOptimizer:
        self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
    ) -> dict:
        try:
            logger.info("Running line profiling to identify performance bottlenecks…")
            console.rule()

            test_env = self.get_test_env(
@@ -15,7 +15,7 @@ from codeflash.cli_cmds.console import console, logger, progress_bar
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file
from codeflash.code_utils.env_utils import get_pr_number, is_pr_draft
from codeflash.code_utils.git_utils import check_running_in_git_repo
from codeflash.code_utils.git_utils import check_running_in_git_repo, git_root_dir
from codeflash.code_utils.git_worktree_utils import (
    create_detached_worktree,
    create_diff_patch_from_worktree,

@@ -134,12 +134,13 @@ class Optimizer:
            project_root=self.args.project_root,
            module_root=self.args.module_root,
            previous_checkpoint_functions=self.args.previous_checkpoint_functions,
            enable_async=getattr(self.args, "async", False),
        )

    def create_function_optimizer(
        self,
        function_to_optimize: FunctionToOptimize,
        function_to_optimize_ast: ast.FunctionDef | None = None,
        function_to_optimize_ast: ast.FunctionDef | ast.AsyncFunctionDef | None = None,
        function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
        function_to_optimize_source_code: str | None = "",
        function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,

@@ -349,13 +350,12 @@ class Optimizer:
                        relative_file_paths = [
                            code_string.file_path for code_string in read_writable_code.code_strings
                        ]
                        metadata = create_diff_patch_from_worktree(
                        patch_path = create_diff_patch_from_worktree(
                            self.current_worktree,
                            relative_file_paths,
                            fto_name=function_to_optimize.qualified_name,
                            metadata_input={},
                        )
                        self.patch_files.append(metadata["patch_path"])
                        self.patch_files.append(patch_path)
                    if i < len(functions_to_optimize) - 1:
                        create_worktree_snapshot_commit(
                            self.current_worktree,

@@ -442,39 +442,50 @@ class Optimizer:
            logger.warning("Failed to create worktree. Skipping optimization.")
            return
        self.current_worktree = worktree_dir
        self.mutate_args_for_worktree_mode(worktree_dir)
        self.mirror_paths_for_worktree_mode(worktree_dir)
        # make sure the tests dir is created in the worktree, this can happen if the original tests dir is empty
        Path(self.args.tests_root).mkdir(parents=True, exist_ok=True)

    def mutate_args_for_worktree_mode(self, worktree_dir: Path) -> None:
        saved_args = copy.deepcopy(self.args)
        saved_test_cfg = copy.deepcopy(self.test_cfg)
        self.original_args_and_test_cfg = (saved_args, saved_test_cfg)
    def mirror_paths_for_worktree_mode(self, worktree_dir: Path) -> None:
        original_args = copy.deepcopy(self.args)
        original_test_cfg = copy.deepcopy(self.test_cfg)
        self.original_args_and_test_cfg = (original_args, original_test_cfg)

        project_root = self.args.project_root
        module_root = self.args.module_root
        relative_module_root = module_root.relative_to(project_root)
        relative_optimized_file = self.args.file.relative_to(project_root) if self.args.file else None
        relative_tests_root = self.test_cfg.tests_root.relative_to(project_root)
        relative_benchmarks_root = (
            self.args.benchmarks_root.relative_to(project_root) if self.args.benchmarks_root else None
        )
        original_git_root = git_root_dir()

        # mirror project_root
        self.args.project_root = mirror_path(self.args.project_root, original_git_root, worktree_dir)
        self.test_cfg.project_root_path = mirror_path(self.test_cfg.project_root_path, original_git_root, worktree_dir)

        # mirror module_root
        self.args.module_root = mirror_path(self.args.module_root, original_git_root, worktree_dir)

        # mirror target file
        if self.args.file:
            self.args.file = mirror_path(self.args.file, original_git_root, worktree_dir)

        # mirror tests root
        self.args.tests_root = mirror_path(self.args.tests_root, original_git_root, worktree_dir)
        self.test_cfg.tests_root = mirror_path(self.test_cfg.tests_root, original_git_root, worktree_dir)

        # mirror tests project root
        self.args.test_project_root = mirror_path(self.args.test_project_root, original_git_root, worktree_dir)
        self.test_cfg.tests_project_rootdir = mirror_path(
            self.test_cfg.tests_project_rootdir, original_git_root, worktree_dir
        )

        self.args.module_root = worktree_dir / relative_module_root
        self.args.project_root = worktree_dir
        self.args.test_project_root = worktree_dir
        self.args.tests_root = worktree_dir / relative_tests_root
        if relative_benchmarks_root:
            self.args.benchmarks_root = worktree_dir / relative_benchmarks_root
        # mirror benchmarks root paths
        if self.args.benchmarks_root:
            self.args.benchmarks_root = mirror_path(self.args.benchmarks_root, original_git_root, worktree_dir)
        if self.test_cfg.benchmark_tests_root:
            self.test_cfg.benchmark_tests_root = mirror_path(
                self.test_cfg.benchmark_tests_root, original_git_root, worktree_dir
            )

        self.test_cfg.project_root_path = worktree_dir
        self.test_cfg.tests_project_rootdir = worktree_dir
        self.test_cfg.tests_root = worktree_dir / relative_tests_root
        if relative_benchmarks_root:
            self.test_cfg.benchmark_tests_root = worktree_dir / relative_benchmarks_root

        if relative_optimized_file is not None:
            self.args.file = worktree_dir / relative_optimized_file
def mirror_path(path: Path, src_root: Path, dest_root: Path) -> Path:
    relative_path = path.relative_to(src_root)
    return dest_root / relative_path
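mirror_path re-roots a path while preserving its position relative to the original git root, which is what lets every args/test_cfg path above land in the equivalent spot inside the worktree. A quick illustration with hypothetical paths:

from pathlib import Path

# /repo is the original git root, /tmp/wt the detached worktree
mirror_path(Path("/repo/src/pkg/mod.py"), Path("/repo"), Path("/tmp/wt"))
# -> Path("/tmp/wt/src/pkg/mod.py")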

def run_with_args(args: Namespace) -> None:
@@ -85,7 +85,7 @@ def existing_tests_source_for(
        ):
            print_optimized_runtime = format_time(optimized_tests_to_runtimes[filename][qualified_name])
            print_original_runtime = format_time(original_tests_to_runtimes[filename][qualified_name])
            print_filename = filename.relative_to(tests_root)
            print_filename = filename.resolve().relative_to(tests_root.resolve()).as_posix()
            greater = (
                optimized_tests_to_runtimes[filename][qualified_name]
                > original_tests_to_runtimes[filename][qualified_name]
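The added resolve() calls matter because Path.relative_to() is purely lexical: it raises ValueError when either side still contains ".." segments or goes through a symlink, even though both point into the same tree. A minimal illustration:

from pathlib import Path

p = Path("/repo/sub/../tests/unit/test_x.py")
# p.relative_to("/repo/tests")  # raises ValueError: the ".." is kept lexically
p.resolve().relative_to(Path("/repo/tests").resolve())  # PosixPath("unit/test_x.py")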
@@ -192,9 +192,9 @@ def check_create_pr(
    if pr_number is not None:
        logger.info(f"Suggesting changes to PR #{pr_number} ...")
        owner, repo = get_repo_owner_and_name(git_repo)
        relative_path = explanation.file_path.relative_to(root_dir).as_posix()
        relative_path = explanation.file_path.resolve().relative_to(root_dir.resolve()).as_posix()
        build_file_changes = {
            Path(p).relative_to(root_dir).as_posix(): FileDiffContent(
            Path(p).resolve().relative_to(root_dir.resolve()).as_posix(): FileDiffContent(
                oldContent=original_code[p], newContent=new_code[p]
            )
            for p in original_code

@@ -243,10 +243,10 @@ def check_create_pr(
        if not check_and_push_branch(git_repo, git_remote, wait_for_push=True):
            logger.warning("⏭️ Branch is not pushed, skipping PR creation...")
            return
        relative_path = explanation.file_path.relative_to(root_dir).as_posix()
        relative_path = explanation.file_path.resolve().relative_to(root_dir.resolve()).as_posix()
        base_branch = get_current_branch()
        build_file_changes = {
            Path(p).relative_to(root_dir).as_posix(): FileDiffContent(
            Path(p).resolve().relative_to(root_dir.resolve()).as_posix(): FileDiffContent(
                oldContent=original_code[p], newContent=new_code[p]
            )
            for p in original_code
@@ -8,8 +8,9 @@ from codeflash.code_utils.config_consts import (
    COVERAGE_THRESHOLD,
    MIN_IMPROVEMENT_THRESHOLD,
    MIN_TESTCASE_PASSED_THRESHOLD,
    MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD,
)
from codeflash.models.test_type import TestType
from codeflash.models import models

if TYPE_CHECKING:
    from codeflash.models.models import CoverageData, OptimizedCandidateResult, OriginalCodeBaseline

@@ -25,20 +26,41 @@ def performance_gain(*, original_runtime_ns: int, optimized_runtime_ns: int) ->
    return (original_runtime_ns - optimized_runtime_ns) / optimized_runtime_ns


def throughput_gain(*, original_throughput: int, optimized_throughput: int) -> float:
    """Calculate the throughput gain of optimized code over the original code.

    This value multiplied by 100 gives the percentage improvement in throughput.
    For throughput, higher values are better (more executions per time period).
    """
    if original_throughput == 0:
        return 0.0
    return (optimized_throughput - original_throughput) / original_throughput
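A worked example of throughput_gain: going from 100 to 150 operations/second yields (150 - 100) / 100 = 0.5, which is reported as a 50% throughput improvement (values hypothetical):

gain = throughput_gain(original_throughput=100, optimized_throughput=150)
assert gain == 0.5
print(f"{gain * 100:.1f}% more operations per second")  # 50.0%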
def speedup_critic(
    candidate_result: OptimizedCandidateResult,
    original_code_runtime: int,
    best_runtime_until_now: int | None,
    *,
    disable_gh_action_noise: bool = False,
    original_async_throughput: int | None = None,
    best_throughput_until_now: int | None = None,
) -> bool:
    """Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.

    Ensure that the optimization is actually faster than the original code, above the noise floor.
    The noise floor is a function of the original code runtime. Currently, the noise floor is 2xMIN_IMPROVEMENT_THRESHOLD
    when the original runtime is less than 10 microseconds, and becomes MIN_IMPROVEMENT_THRESHOLD for any higher runtime.
    The noise floor is doubled when benchmarking on a (noisy) GitHub Action virtual instance, also we want to be more confident there.
    Evaluates both runtime performance and async throughput improvements.

    For runtime performance:
    - Ensures the optimization is actually faster than the original code, above the noise floor.
    - The noise floor is a function of the original code runtime. Currently, the noise floor is 2xMIN_IMPROVEMENT_THRESHOLD
      when the original runtime is less than 10 microseconds, and becomes MIN_IMPROVEMENT_THRESHOLD for any higher runtime.
    - The noise floor is doubled when benchmarking on a (noisy) GitHub Action virtual instance.

    For async throughput (when available):
    - Evaluates throughput improvements using MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD
    - Throughput improvements complement runtime improvements for async functions
    """
    # Runtime performance evaluation
    noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
    if not disable_gh_action_noise and env_utils.is_ci():
        noise_floor = noise_floor * 2  # Increase the noise floor in GitHub Actions mode
@@ -46,10 +68,31 @@ def speedup_critic(
    perf_gain = performance_gain(
        original_runtime_ns=original_code_runtime, optimized_runtime_ns=candidate_result.best_test_runtime
    )
    if best_runtime_until_now is None:
        # collect all optimizations with this
        return bool(perf_gain > noise_floor)
    return bool(perf_gain > noise_floor and candidate_result.best_test_runtime < best_runtime_until_now)
    runtime_improved = perf_gain > noise_floor

    # Check runtime comparison with best so far
    runtime_is_best = best_runtime_until_now is None or candidate_result.best_test_runtime < best_runtime_until_now

    throughput_improved = True  # Default to True if no throughput data
    throughput_is_best = True  # Default to True if no throughput data

    if original_async_throughput is not None and candidate_result.async_throughput is not None:
        if original_async_throughput > 0:
            throughput_gain_value = throughput_gain(
                original_throughput=original_async_throughput, optimized_throughput=candidate_result.async_throughput
            )
            throughput_improved = throughput_gain_value > MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD

        throughput_is_best = (
            best_throughput_until_now is None or candidate_result.async_throughput > best_throughput_until_now
        )

    if original_async_throughput is not None and candidate_result.async_throughput is not None:
        # When throughput data is available, accept if EITHER throughput OR runtime improves significantly
        throughput_acceptance = throughput_improved and throughput_is_best
        runtime_acceptance = runtime_improved and runtime_is_best
        return throughput_acceptance or runtime_acceptance
    return runtime_improved and runtime_is_best
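The either/or acceptance means an async candidate can be surfaced on throughput alone. A sketch (thresholds and values hypothetical; candidate_result is assumed to carry best_test_runtime and async_throughput as above):

# Runtime barely moves (below the noise floor), but throughput rises 30%:
# the throughput branch accepts the candidate.
accepted = speedup_critic(
    candidate_result,                # async_throughput == 130, runtime ~ unchanged
    original_code_runtime=50_000,    # 50 microseconds
    best_runtime_until_now=None,
    original_async_throughput=100,
    best_throughput_until_now=None,
)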

def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | OriginalCodeBaseline) -> bool:

@@ -63,7 +106,7 @@ def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | OriginalCodeBaseline) -> bool:
    if pass_count >= MIN_TESTCASE_PASSED_THRESHOLD:
        return True
    # If one or more tests passed, check if at least one of them was a successful REPLAY_TEST
    return bool(pass_count >= 1 and report[TestType.REPLAY_TEST]["passed"] >= 1)
    return bool(pass_count >= 1 and report[models.TestType.REPLAY_TEST]["passed"] >= 1)  # type: ignore # noqa: PGH003


def coverage_critic(original_code_coverage: CoverageData | None, test_framework: str) -> bool:
@@ -12,6 +12,7 @@ from rich.table import Table
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.models.models import BenchmarkDetail, TestResults
from codeflash.result.critic import throughput_gain


@dataclass(frozen=True, config={"arbitrary_types_allowed": True})

@@ -24,9 +25,28 @@ class Explanation:
    function_name: str
    file_path: Path
    benchmark_details: Optional[list[BenchmarkDetail]] = None
    original_async_throughput: Optional[int] = None
    best_async_throughput: Optional[int] = None

    @property
    def perf_improvement_line(self) -> str:
        runtime_improvement = self.speedup

        if (
            self.original_async_throughput is not None
            and self.best_async_throughput is not None
            and self.original_async_throughput > 0
        ):
            throughput_improvement = throughput_gain(
                original_throughput=self.original_async_throughput, optimized_throughput=self.best_async_throughput
            )

            # Use throughput metrics if throughput improvement is better or runtime got worse
            if throughput_improvement > runtime_improvement or runtime_improvement <= 0:
                throughput_pct = f"{throughput_improvement * 100:,.0f}%"
                throughput_x = f"{throughput_improvement + 1:,.2f}x"
                return f"{throughput_pct} improvement ({throughput_x} faster)."

        return f"{self.speedup_pct} improvement ({self.speedup_x} faster)."
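For a hypothetical async result that goes from 100 to 250 operations/second while runtime stays flat, throughput_gain returns 1.5 and the property renders the throughput metric instead of the runtime one:

# throughput_improvement = (250 - 100) / 100 = 1.5
f"{1.5 * 100:,.0f}% improvement ({1.5 + 1:,.2f}x faster)."
# -> "150% improvement (2.50x faster)."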
    @property

@@ -46,6 +66,23 @@ class Explanation:
        # TODO: Sometimes the explanation says something similar to "This is the code that was optimized", remove such parts
        original_runtime_human = humanize_runtime(self.original_runtime_ns)
        best_runtime_human = humanize_runtime(self.best_runtime_ns)

        # Determine if we're showing throughput or runtime improvements
        runtime_improvement = self.speedup
        is_using_throughput_metric = False

        if (
            self.original_async_throughput is not None
            and self.best_async_throughput is not None
            and self.original_async_throughput > 0
        ):
            throughput_improvement = throughput_gain(
                original_throughput=self.original_async_throughput, optimized_throughput=self.best_async_throughput
            )

            if throughput_improvement > runtime_improvement or runtime_improvement <= 0:
                is_using_throughput_metric = True

        benchmark_info = ""

        if self.benchmark_details:

@@ -86,13 +123,18 @@ class Explanation:
            console.print(table)
            benchmark_info = cast("StringIO", console.file).getvalue() + "\n"  # Cast for mypy

        test_report = self.winning_behavior_test_results.get_test_pass_fail_report_by_type()
        test_report_str = TestResults.report_to_string(test_report)
        if is_using_throughput_metric:
            performance_description = (
                f"Throughput improved from {self.original_async_throughput} to {self.best_async_throughput} operations/second "
                f"(runtime: {original_runtime_human} → {best_runtime_human})\n\n"
            )
        else:
            performance_description = f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n"

        return (
            f"Optimized {self.function_name} in {self.file_path}\n"
            f"{self.perf_improvement_line}\n"
            f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n"
            + performance_description
            + (benchmark_info if benchmark_info else "")
            + self.raw_explanation_message
            + " \n\n"

@@ -101,7 +143,7 @@ class Explanation:
                ""
                if is_LSP_enabled()
                else "The new optimized code was tested for correctness. The results are listed below.\n"
                + test_report_str
                f"{TestResults.report_to_string(self.winning_behavior_test_results.get_test_pass_fail_report_by_type())}\n"
            )
        )
@@ -13,6 +13,7 @@ import sys
import threading
import time
from collections import defaultdict
from importlib.util import find_spec
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ClassVar

@@ -47,6 +48,17 @@ class FakeFrame:
        self.f_locals: dict = {}


def patch_ap_scheduler() -> None:
    if find_spec("apscheduler"):
        import apscheduler.schedulers.background as bg
        import apscheduler.schedulers.blocking as bb
        from apscheduler.schedulers import base

        bg.BackgroundScheduler.start = lambda _, *_a, **_k: None
        bb.BlockingScheduler.start = lambda _, *_a, **_k: None
        base.BaseScheduler.add_job = lambda _, *_a, **_k: None
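After patch_ap_scheduler() runs, APScheduler start-up and job registration become no-ops, so a traced script that spins up a scheduler will not keep background jobs running during tracing. A sketch of the effect (requires apscheduler to be installed):

patch_ap_scheduler()
from apscheduler.schedulers.background import BackgroundScheduler

sched = BackgroundScheduler()
sched.add_job(print, "interval", seconds=1)  # silently dropped by the patch
sched.start()  # returns immediately; nothing is scheduled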
# Debug this file by simply adding print statements. This file is not meant to be debugged by the debugger.
class Tracer:
    """Use this class as a 'with' context manager to trace a function call.

@@ -820,6 +832,7 @@ class Tracer:
if __name__ == "__main__":
    args_dict = json.loads(sys.argv[-1])
    sys.argv = sys.argv[1:-1]
    patch_ap_scheduler()
    if args_dict["module"]:
        import runpy
@@ -1,4 +1,3 @@
# ruff: noqa: PGH003
import array
import ast
import datetime

@@ -8,6 +7,7 @@ import math
import re
import types
from collections import ChainMap, OrderedDict, deque
from importlib.util import find_spec
from typing import Any

import sentry_sdk

@@ -15,51 +15,14 @@ import sentry_sdk
from codeflash.cli_cmds.console import logger
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError

try:
    import numpy as np

    HAS_NUMPY = True
except ImportError:
    HAS_NUMPY = False
try:
    import sqlalchemy  # type: ignore

    HAS_SQLALCHEMY = True
except ImportError:
    HAS_SQLALCHEMY = False
try:
    import scipy  # type: ignore

    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False

try:
    import pandas  # type: ignore # noqa: ICN001

    HAS_PANDAS = True
except ImportError:
    HAS_PANDAS = False

try:
    import pyrsistent  # type: ignore

    HAS_PYRSISTENT = True
except ImportError:
    HAS_PYRSISTENT = False
try:
    import torch  # type: ignore

    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False
try:
    import jax  # type: ignore
    import jax.numpy as jnp  # type: ignore

    HAS_JAX = True
except ImportError:
    HAS_JAX = False
HAS_NUMPY = find_spec("numpy") is not None
HAS_SQLALCHEMY = find_spec("sqlalchemy") is not None
HAS_SCIPY = find_spec("scipy") is not None
HAS_PANDAS = find_spec("pandas") is not None
HAS_PYRSISTENT = find_spec("pyrsistent") is not None
HAS_TORCH = find_spec("torch") is not None
HAS_JAX = find_spec("jax") is not None
HAS_XARRAY = find_spec("xarray") is not None
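The find_spec probe only checks that a module is importable; the heavy import itself now happens lazily inside comparator, the first time a value of that type is actually compared. The same guard-then-import pattern in isolation:

from importlib.util import find_spec

HAS_NUMPY = find_spec("numpy") is not None

def is_same_array(a, b) -> bool:
    if HAS_NUMPY:
        import numpy as np  # cached in sys.modules after the first call

        if isinstance(a, np.ndarray):
            return a.shape == b.shape and bool(np.allclose(a, b, equal_nan=True))
    return a == b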

def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001, ANN401, FBT002, PLR0911

@@ -115,15 +78,28 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001
        new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")}
        return comparator(orig_dict, new_dict, superset_obj)

    # Handle JAX arrays first to avoid boolean context errors in other conditions
    if HAS_JAX and isinstance(orig, jax.Array):
        if orig.dtype != new.dtype:
            return False
        if orig.shape != new.shape:
            return False
        return bool(jnp.allclose(orig, new, equal_nan=True))
    if HAS_JAX:
        import jax  # type: ignore # noqa: PGH003
        import jax.numpy as jnp  # type: ignore # noqa: PGH003

        # Handle JAX arrays first to avoid boolean context errors in other conditions
        if isinstance(orig, jax.Array):
            if orig.dtype != new.dtype:
                return False
            if orig.shape != new.shape:
                return False
            return bool(jnp.allclose(orig, new, equal_nan=True))

    # Handle xarray objects before numpy to avoid boolean context errors
    if HAS_XARRAY:
        import xarray  # type: ignore # noqa: PGH003

        if isinstance(orig, (xarray.Dataset, xarray.DataArray)):
            return orig.identical(new)

    if HAS_SQLALCHEMY:
        import sqlalchemy  # type: ignore # noqa: PGH003

        try:
            insp = sqlalchemy.inspection.inspect(orig)
            insp = sqlalchemy.inspection.inspect(new)  # noqa: F841

@@ -138,6 +114,9 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001

        except sqlalchemy.exc.NoInspectionAvailable:
            pass

    if HAS_SCIPY:
        import scipy  # type: ignore # noqa: PGH003
    # scipy condition because dok_matrix type is also an instance of dict, but dict comparison doesn't work for it
    if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)):
        if superset_obj:

@@ -151,27 +130,30 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001
                return False
        return True

    if HAS_NUMPY and isinstance(orig, np.ndarray):
        if orig.dtype != new.dtype:
            return False
        if orig.shape != new.shape:
            return False
        try:
            return np.allclose(orig, new, equal_nan=True)
        except Exception:
            # fails at "ufunc 'isfinite' not supported for the input types"
            return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)])
    if HAS_NUMPY:
        import numpy as np  # type: ignore # noqa: PGH003

    if HAS_NUMPY and isinstance(orig, (np.floating, np.complex64, np.complex128)):
        return np.isclose(orig, new)
        if isinstance(orig, np.ndarray):
            if orig.dtype != new.dtype:
                return False
            if orig.shape != new.shape:
                return False
            try:
                return np.allclose(orig, new, equal_nan=True)
            except Exception:
                # fails at "ufunc 'isfinite' not supported for the input types"
                return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)])

    if HAS_NUMPY and isinstance(orig, (np.integer, np.bool_, np.byte)):
        return orig == new
        if isinstance(orig, (np.floating, np.complex64, np.complex128)):
            return np.isclose(orig, new)

    if HAS_NUMPY and isinstance(orig, np.void):
        if orig.dtype != new.dtype:
            return False
        return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields)
        if isinstance(orig, (np.integer, np.bool_, np.byte)):
            return orig == new

        if isinstance(orig, np.void):
            if orig.dtype != new.dtype:
                return False
            return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields)

    if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix):
        if orig.dtype != new.dtype:

@@ -180,15 +162,18 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001
            return False
        return (orig != new).nnz == 0

    if HAS_PANDAS and isinstance(
        orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
    ):
        return orig.equals(new)
    if HAS_PANDAS:
        import pandas  # type: ignore # noqa: ICN001, PGH003

    if HAS_PANDAS and isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)):
        return orig == new
    if HAS_PANDAS and pandas.isna(orig) and pandas.isna(new):
        return True
        if isinstance(
            orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
        ):
            return orig.equals(new)

        if isinstance(orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)):
            return orig == new
        if pandas.isna(orig) and pandas.isna(new):
            return True

    if isinstance(orig, array.array):
        if orig.typecode != new.typecode:

@@ -209,31 +194,58 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001
        except Exception:  # noqa: S110
            pass

    if HAS_TORCH and isinstance(orig, torch.Tensor):
        if orig.dtype != new.dtype:
            return False
        if orig.shape != new.shape:
            return False
        if orig.requires_grad != new.requires_grad:
            return False
        if orig.device != new.device:
            return False
        return torch.allclose(orig, new, equal_nan=True)
    if HAS_TORCH:
        import torch  # type: ignore # noqa: PGH003

    if HAS_PYRSISTENT and isinstance(
        orig,
        (
            pyrsistent.PMap,
            pyrsistent.PVector,
            pyrsistent.PSet,
            pyrsistent.PRecord,
            pyrsistent.PClass,
            pyrsistent.PBag,
            pyrsistent.PList,
            pyrsistent.PDeque,
        ),
    ):
        return orig == new
        if isinstance(orig, torch.Tensor):
            if orig.dtype != new.dtype:
                return False
            if orig.shape != new.shape:
                return False
            if orig.requires_grad != new.requires_grad:
                return False
            if orig.device != new.device:
                return False
            return torch.allclose(orig, new, equal_nan=True)

    if HAS_PYRSISTENT:
        import pyrsistent  # type: ignore # noqa: PGH003

        if isinstance(
            orig,
            (
                pyrsistent.PMap,
                pyrsistent.PVector,
                pyrsistent.PSet,
                pyrsistent.PRecord,
                pyrsistent.PClass,
                pyrsistent.PBag,
                pyrsistent.PList,
                pyrsistent.PDeque,
            ),
        ):
            return orig == new

    if hasattr(orig, "__attrs_attrs__") and hasattr(new, "__attrs_attrs__"):
        orig_dict = {}
        new_dict = {}

        for attr in orig.__attrs_attrs__:
            if attr.eq:
                attr_name = attr.name
                orig_dict[attr_name] = getattr(orig, attr_name, None)
                new_dict[attr_name] = getattr(new, attr_name, None)

        if superset_obj:
            new_attrs_dict = {}
            for attr in new.__attrs_attrs__:
                if attr.eq:
                    attr_name = attr.name
                    new_attrs_dict[attr_name] = getattr(new, attr_name, None)
            return all(
                k in new_attrs_dict and comparator(v, new_attrs_dict[k], superset_obj) for k, v in orig_dict.items()
            )
        return comparator(orig_dict, new_dict, superset_obj)

    # re.Pattern can be made better by DFA Minimization and then comparing
    if isinstance(
@@ -38,20 +38,20 @@ class CoverageUtils:

        cov = Coverage(data_file=database_path, config_file=config_path, data_suffix=True, auto_data=True, branch=True)

        if not database_path.stat().st_size or not database_path.exists():
        if not database_path.exists() or not database_path.stat().st_size:
            logger.debug(f"Coverage database {database_path} is empty or does not exist")
            sentry_sdk.capture_message(f"Coverage database {database_path} is empty or does not exist")
            return CoverageUtils.create_empty(source_code_path, function_name, code_context)
            return CoverageData.create_empty(source_code_path, function_name, code_context)
        cov.load()

        reporter = JsonReporter(cov)
        temp_json_file = database_path.with_suffix(".report.json")
        with temp_json_file.open("w") as f:
        with temp_json_file.open("w", encoding="utf-8") as f:
            try:
                reporter.report(morfs=[source_code_path.as_posix()], outfile=f)
            except NoDataError:
                sentry_sdk.capture_message(f"No coverage data found for {function_name} in {source_code_path}")
                return CoverageUtils.create_empty(source_code_path, function_name, code_context)
                return CoverageData.create_empty(source_code_path, function_name, code_context)
        with temp_json_file.open() as f:
            original_coverage_data = json.load(f)

@@ -92,7 +92,7 @@ class CoverageUtils:
    def _parse_coverage_file(
        coverage_file_path: Path, source_code_path: Path
    ) -> tuple[dict[str, dict[str, Any]], CoverageStatus]:
        with coverage_file_path.open() as f:
        with coverage_file_path.open(encoding="utf-8") as f:
            coverage_data = json.load(f)

        candidates = generate_candidates(source_code_path)
@@ -33,7 +33,7 @@ def instrument_codeflash_capture(
        modified_code = add_codeflash_capture_to_init(
            target_classes={class_parent.name},
            fto_name=function_to_optimize.function_name,
            tmp_dir_path=str(get_run_tmp_file(Path("test_return_values"))),
            tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
            code=original_code,
            tests_root=tests_root,
            is_fto=True,

@@ -46,7 +46,7 @@ def instrument_codeflash_capture(
        modified_code = add_codeflash_capture_to_init(
            target_classes=helper_classes,
            fto_name=function_to_optimize.function_name,
            tmp_dir_path=str(get_run_tmp_file(Path("test_return_values"))),
            tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
            code=original_code,
            tests_root=tests_root,
            is_fto=False,

@@ -92,6 +92,27 @@ class InitDecorator(ast.NodeTransformer):
        self.tests_root = tests_root
        self.inserted_decorator = False

        # Precompute decorator components to avoid reconstructing on every node visit
        # Only the `function_name` field changes per class
        self._base_decorator_keywords = [
            ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)),
            ast.keyword(arg="tests_root", value=ast.Constant(value=self.tests_root.as_posix())),
            ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)),
        ]
        self._base_decorator_func = ast.Name(id="codeflash_capture", ctx=ast.Load())

        # Preconstruct starred/kwargs for super init injection for perf
        self._super_starred = ast.Starred(value=ast.Name(id="args", ctx=ast.Load()))
        self._super_kwarg = ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))
        self._super_func = ast.Attribute(
            value=ast.Call(func=ast.Name(id="super", ctx=ast.Load()), args=[], keywords=[]),
            attr="__init__",
            ctx=ast.Load(),
        )
        self._init_vararg = ast.arg(arg="args")
        self._init_kwarg = ast.arg(arg="kwargs")
        self._init_self_arg = ast.arg(arg="self", annotation=None)
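The precompute trick in isolation: invariant AST fragments are built once and splatted into each per-class decorator call, so only the function_name keyword is allocated per visit. A minimal, self-contained sketch:

import ast

shared_keywords = [ast.keyword(arg="is_fto", value=ast.Constant(value=True))]
decorator_func = ast.Name(id="codeflash_capture", ctx=ast.Load())

def make_decorator(class_name: str) -> ast.Call:
    per_class = ast.keyword(arg="function_name", value=ast.Constant(value=f"{class_name}.__init__"))
    return ast.Call(func=decorator_func, args=[], keywords=[per_class, *shared_keywords])

print(ast.unparse(make_decorator("Widget")))
# codeflash_capture(function_name='Widget.__init__', is_fto=True)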
    def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
        # Check if our import already exists
        if node.module == "codeflash.verification.codeflash_capture" and any(

@@ -114,21 +135,18 @@ class InitDecorator(ast.NodeTransformer):
        if node.name not in self.target_classes:
            return node

        # Look for __init__ method
        has_init = False

        # Create the decorator
        # Build decorator node ONCE for each class, not per loop iteration
        decorator = ast.Call(
            func=ast.Name(id="codeflash_capture", ctx=ast.Load()),
            func=self._base_decorator_func,
            args=[],
            keywords=[
                ast.keyword(arg="function_name", value=ast.Constant(value=f"{node.name}.__init__")),
                ast.keyword(arg="tmp_dir_path", value=ast.Constant(value=self.tmp_dir_path)),
                ast.keyword(arg="tests_root", value=ast.Constant(value=str(self.tests_root))),
                ast.keyword(arg="is_fto", value=ast.Constant(value=self.is_fto)),
                *self._base_decorator_keywords,
            ],
        )

        # Only scan node.body once for both __init__ and decorator check
        for item in node.body:
            if (
                isinstance(item, ast.FunctionDef)

@@ -139,35 +157,28 @@ class InitDecorator(ast.NodeTransformer):
            ):
                has_init = True

                # Add decorator at the start of the list if not already present
                if not any(
                    isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "codeflash_capture"
                    for d in item.decorator_list
                ):
                # Check for existing decorator in-place, stop after finding one
                for d in item.decorator_list:
                    if isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "codeflash_capture":
                        break
                else:
                    # No decorator found
                    item.decorator_list.insert(0, decorator)
                    self.inserted_decorator = True

        if not has_init:
            # Create super().__init__(*args, **kwargs) call
            # Create super().__init__(*args, **kwargs) call (use prebuilt AST fragments)
            super_call = ast.Expr(
                value=ast.Call(
                    func=ast.Attribute(
                        value=ast.Call(func=ast.Name(id="super", ctx=ast.Load()), args=[], keywords=[]),
                        attr="__init__",
                        ctx=ast.Load(),
                    ),
                    args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()))],
                    keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
                )
                value=ast.Call(func=self._super_func, args=[self._super_starred], keywords=[self._super_kwarg])
            )
            # Create function arguments: self, *args, **kwargs
            # Create function arguments: self, *args, **kwargs (reuse arg nodes)
            arguments = ast.arguments(
                posonlyargs=[],
                args=[ast.arg(arg="self", annotation=None)],
                vararg=ast.arg(arg="args"),
                args=[self._init_self_arg],
                vararg=self._init_vararg,
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=ast.arg(arg="kwargs"),
                kwarg=self._init_kwarg,
                defaults=[],
            )
@@ -40,6 +40,30 @@
matches_re_start = re.compile(r"!\$######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######\$!")
matches_re_end = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!")


start_pattern = re.compile(r"!\$######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+)######\$!")
end_pattern = re.compile(r"!######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+):([^:]+)######!")


def calculate_function_throughput_from_test_results(test_results: TestResults, function_name: str) -> int:
    """Calculate function throughput from TestResults by extracting performance stdout.

    A completed execution is defined as having both a start tag and matching end tag from performance wrappers.
    Start: !$######test_module:test_function:function_name:loop_index:iteration_id######$!
    End: !######test_module:test_function:function_name:loop_index:iteration_id:duration######!
    """
    start_matches = start_pattern.findall(test_results.perf_stdout or "")
    end_matches = end_pattern.findall(test_results.perf_stdout or "")

    end_matches_truncated = [end_match[:5] for end_match in end_matches]
    end_matches_set = set(end_matches_truncated)

    function_throughput = 0
    for start_match in start_matches:
        if start_match in end_matches_set and len(start_match) > 2 and start_match[2] == function_name:
            function_throughput += 1
    return function_throughput
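A worked sketch of the counting rule: two completed executions of a hypothetical async function "fetch" and one dangling start tag give a throughput of 2 (assumes TestResults exposes a writable perf_stdout, as used above):

fake = TestResults()
fake.perf_stdout = (
    "!$######tests.test_m:test_f:fetch:1:0######$!"
    "!######tests.test_m:test_f:fetch:1:0:1200######!"
    "!$######tests.test_m:test_f:fetch:1:1######$!"
    "!######tests.test_m:test_f:fetch:1:1:1100######!"
    "!$######tests.test_m:test_f:fetch:1:2######$!"  # start tag with no matching end
)
assert calculate_function_throughput_from_test_results(fake, "fetch") == 2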

def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
    test_results = TestResults()
    if not file_location.exists():
@@ -450,3 +450,26 @@ class PytestLoops:
            metafunc.parametrize(
                "__pytest_loop_step_number", range(count), indirect=True, ids=make_progress_id, scope=scope
            )

    @pytest.hookimpl(tryfirst=True)
    def pytest_runtest_setup(self, item: pytest.Item) -> None:
        """Set test context environment variables before each test."""
        test_module_name = item.module.__name__ if item.module else "unknown_module"

        test_class_name = None
        if item.cls:
            test_class_name = item.cls.__name__

        test_function_name = item.name
        if "[" in test_function_name:
            test_function_name = test_function_name.split("[", 1)[0]

        os.environ["CODEFLASH_TEST_MODULE"] = test_module_name
        os.environ["CODEFLASH_TEST_CLASS"] = test_class_name or ""
        os.environ["CODEFLASH_TEST_FUNCTION"] = test_function_name

    @pytest.hookimpl(trylast=True)
    def pytest_runtest_teardown(self, item: pytest.Item) -> None:  # noqa: ARG002
        """Clean up test context environment variables after each test."""
        for var in ["CODEFLASH_TEST_MODULE", "CODEFLASH_TEST_CLASS", "CODEFLASH_TEST_FUNCTION"]:
            os.environ.pop(var, None)
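Instrumented code can then recover the running test's identity from the environment these hooks publish. A small consumer sketch (variable names as set above):

import os

def current_test_id() -> str:
    parts = (
        os.environ.get("CODEFLASH_TEST_MODULE", ""),
        os.environ.get("CODEFLASH_TEST_CLASS", ""),
        os.environ.get("CODEFLASH_TEST_FUNCTION", ""),
    )
    return ":".join(p for p in parts if p)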

@@ -1,2 +1,2 @@
# These version placeholders will be replaced by uv-dynamic-versioning during build.
__version__ = "0.17.0"
__version__ = "0.17.2"
@@ -52,8 +52,14 @@ Homepage = "https://codeflash.ai"
[project.scripts]
codeflash = "codeflash.main:main"

[project.optional-dependencies]
asyncio = [
    "pytest-asyncio>=1.2.0",
]

[dependency-groups]
dev = [
    {include-group = "asyncio"},
    "ipython>=8.12.0",
    "mypy>=1.13",
    "ruff>=0.7.0",

@@ -76,6 +82,9 @@ dev = [
    "uv>=0.6.2",
    "pre-commit>=4.2.0,<5",
]
asyncio = [
    "pytest-asyncio>=1.2.0",
]

[tool.hatch.build.targets.sdist]
include = ["codeflash"]
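With the asyncio extra declared under [project.optional-dependencies], standard extras semantics let users pull in pytest-asyncio via pip install "codeflash[asyncio]", while the dev dependency group inherits the same pin through {include-group = "asyncio"} (PEP 735 dependency groups).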
28
tests/scripts/end_to_end_test_async.py
Normal file

@@ -0,0 +1,28 @@
import os
import pathlib

from end_to_end_test_utilities import CoverageExpectation, TestConfig, run_codeflash_command, run_with_retries


def run_test(expected_improvement_pct: int) -> bool:
    config = TestConfig(
        file_path="main.py",
        expected_unit_tests=0,
        min_improvement_x=0.1,
        enable_async=True,
        coverage_expectations=[
            CoverageExpectation(
                function_name="retry_with_backoff",
                expected_coverage=100.0,
                expected_lines=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
            )
        ],
    )
    cwd = (
        pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "code_directories" / "async_e2e"
    ).resolve()
    return run_codeflash_command(cwd, config, expected_improvement_pct)


if __name__ == "__main__":
    exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10))))
@@ -11,7 +11,7 @@ def run_test(expected_improvement_pct: int) -> bool:
        function_name="sorter",
        benchmarks_root=cwd / "tests" / "pytest" / "benchmarks",
        test_framework="pytest",
        min_improvement_x=1.0,
        min_improvement_x=0.70,
        coverage_expectations=[
            CoverageExpectation(
                function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8, 9, 10]

@@ -9,7 +9,7 @@ def run_test(expected_improvement_pct: int) -> bool:
        file_path="bubble_sort.py",
        function_name="sorter",
        test_framework="pytest",
        min_improvement_x=1.0,
        min_improvement_x=0.70,
        coverage_expectations=[
            CoverageExpectation(
                function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8, 9, 10]

@@ -6,7 +6,7 @@ from end_to_end_test_utilities import TestConfig, run_codeflash_command, run_with_retries

def run_test(expected_improvement_pct: int) -> bool:
    config = TestConfig(
        file_path="bubble_sort.py", function_name="sorter", test_framework="unittest", min_improvement_x=3.0
        file_path="bubble_sort.py", function_name="sorter", test_framework="unittest", min_improvement_x=0.30
    )
    cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
    return run_codeflash_command(cwd, config, expected_improvement_pct)

@@ -27,6 +27,7 @@ class TestConfig:
    trace_mode: bool = False
    coverage_expectations: list[CoverageExpectation] = field(default_factory=list)
    benchmarks_root: Optional[pathlib.Path] = None
    enable_async: bool = False


def clear_directory(directory_path: str | pathlib.Path) -> None:

@@ -88,8 +89,10 @@ def run_codeflash_command(
    test_root = cwd / "tests" / (config.test_framework or "")

    command = build_command(cwd, config, test_root, config.benchmarks_root if config.benchmarks_root else None)
    env = os.environ.copy()
    env['PYTHONIOENCODING'] = 'utf-8'
    process = subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=os.environ.copy()
        command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=env, encoding='utf-8'
    )

    output = []

@@ -122,7 +125,7 @@ def build_command(
) -> list[str]:
    python_path = "../../../codeflash/main.py" if "code_directories" in str(cwd) else "../codeflash/main.py"

    base_command = ["python", python_path, "--file", config.file_path, "--no-pr"]
    base_command = ["uv", "run", "--no-project", python_path, "--file", config.file_path, "--no-pr"]

    if config.function_name:
        base_command.extend(["--function", config.function_name])

@@ -132,6 +135,8 @@ def build_command(
    )
    if benchmarks_root:
        base_command.extend(["--benchmark", "--benchmarks-root", str(benchmarks_root)])
    if config.enable_async:
        base_command.append("--async")
    return base_command


@@ -187,9 +192,11 @@ def validate_stdout_in_candidate(stdout: str, expected_in_stdout: list[str]) ->
def run_trace_test(cwd: pathlib.Path, config: TestConfig, expected_improvement_pct: int) -> bool:
    test_root = cwd / "tests" / (config.test_framework or "")
    clear_directory(test_root)
    command = ["python", "-m", "codeflash.main", "optimize", "workload.py"]
    command = ["uv", "run", "--no-project", "-m", "codeflash.main", "optimize", "workload.py"]
    env = os.environ.copy()
    env['PYTHONIOENCODING'] = 'utf-8'
    process = subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=os.environ.copy()
        command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=env, encoding='utf-8'
    )

    output = []
@ -3,6 +3,10 @@ from pathlib import Path
|
|||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
|
||||
import tempfile
|
||||
from codeflash.code_utils.code_extractor import resolve_star_import, DottedImportCollector
|
||||
import libcst as cst
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
def test_add_needed_imports_from_module0() -> None:
|
||||
src_module = '''import ast
|
||||
|
|
@ -349,3 +353,141 @@ class DbtAdapter(BaseAdapter):
|
|||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
||||
|
||||
|
||||
def test_resolve_star_import_with_all_defined():
|
||||
"""Test resolve_star_import when __all__ is explicitly defined."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
test_module = project_root / 'test_module.py'
|
||||
|
||||
# Create a test module with __all__ definition
|
||||
test_module.write_text('''
|
||||
__all__ = ['public_function', 'PublicClass']
|
||||
|
||||
def public_function():
|
||||
pass
|
||||
|
||||
def _private_function():
|
||||
pass
|
||||
|
||||
class PublicClass:
|
||||
pass
|
||||
|
||||
class AnotherPublicClass:
|
||||
"""Not in __all__ so should be excluded."""
|
||||
pass
|
||||
''')
|
||||
|
||||
symbols = resolve_star_import('test_module', project_root)
|
||||
expected_symbols = {'public_function', 'PublicClass'}
|
||||
assert symbols == expected_symbols
|
||||
|
||||
|
||||
def test_resolve_star_import_without_all_defined():
|
||||
"""Test resolve_star_import when __all__ is not defined - should include all public symbols."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
test_module = project_root / 'test_module.py'
|
||||
|
||||
# Create a test module without __all__ definition
|
||||
test_module.write_text('''
|
||||
def public_func():
|
||||
pass
|
||||
|
||||
def _private_func():
|
||||
pass
|
||||
|
||||
class PublicClass:
|
||||
pass
|
||||
|
||||
PUBLIC_VAR = 42
|
||||
_private_var = 'secret'
|
||||
''')
|
||||
|
||||
symbols = resolve_star_import('test_module', project_root)
|
||||
expected_symbols = {'public_func', 'PublicClass', 'PUBLIC_VAR'}
|
||||
assert symbols == expected_symbols
|
||||
|
||||
|
||||
def test_resolve_star_import_nonexistent_module():
|
||||
"""Test resolve_star_import with non-existent module - should return empty set."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
|
||||
symbols = resolve_star_import('nonexistent_module', project_root)
|
||||
assert symbols == set()
|
||||
|
||||
|
||||
def test_dotted_import_collector_skips_star_imports():
|
||||
"""Test that DottedImportCollector correctly skips star imports."""
|
||||
code_with_star_import = '''
|
||||
from typing import *
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
import os
|
||||
'''
|
||||
|
||||
module = cst.parse_module(code_with_star_import)
|
||||
collector = DottedImportCollector()
|
||||
module.visit(collector)
|
||||
|
||||
# Should collect regular imports but skip the star import
|
||||
expected_imports = {'collections.defaultdict', 'os', 'pathlib.Path'}
|
||||
assert collector.imports == expected_imports
|
||||
|
||||
|
||||
def test_add_needed_imports_with_star_import_resolution():
|
||||
"""Test add_needed_imports_from_module correctly handles star imports by resolving them."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
|
||||
# Create a source module that exports symbols
|
||||
src_module = project_root / 'source_module.py'
|
||||
src_module.write_text('''
|
||||
__all__ = ['UtilFunction', 'HelperClass']
|
||||
|
||||
def UtilFunction():
|
||||
pass
|
||||
|
||||
class HelperClass:
|
||||
pass
|
||||
''')
|
||||
|
||||
# Create source code that uses star import
|
||||
src_code = '''
|
||||
from source_module import *
|
||||
|
||||
def my_function():
|
||||
helper = HelperClass()
|
||||
UtilFunction()
|
||||
return helper
|
||||
'''
|
||||
|
||||
# Destination code that needs the imports resolved
|
||||
dst_code = '''
|
||||
def my_function():
|
||||
helper = HelperClass()
|
||||
UtilFunction()
|
||||
return helper
|
||||
'''
|
||||
|
||||
src_path = project_root / 'src.py'
|
||||
dst_path = project_root / 'dst.py'
|
||||
src_path.write_text(src_code)
|
||||
|
||||
result = add_needed_imports_from_module(
|
||||
src_code, dst_code, src_path, dst_path, project_root
|
||||
)
|
||||
|
||||
# The result should have individual imports instead of star import
|
||||
expected_result = '''from source_module import HelperClass, UtilFunction
|
||||
|
||||
def my_function():
|
||||
helper = HelperClass()
|
||||
UtilFunction()
|
||||
return helper
|
||||
'''
|
||||
assert result == expected_result
|
||||
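The tests above pin down the contract of resolve_star_import without showing its body; a plausible shape, consistent with those assertions (the real implementation may differ, and `resolve_star_import_sketch` is a hypothetical name):

```python
import ast
from pathlib import Path

def resolve_star_import_sketch(module_name: str, project_root: Path) -> set[str]:
    # `from module import *` exposes __all__ when defined, otherwise every
    # public (non-underscore) top-level name.
    module_path = project_root / (module_name.replace(".", "/") + ".py")
    if not module_path.exists():
        return set()  # unresolvable star import contributes nothing
    tree = ast.parse(module_path.read_text(encoding="utf-8"))
    for node in tree.body:
        if isinstance(node, ast.Assign):
            for target in node.targets:
                if (
                    isinstance(target, ast.Name)
                    and target.id == "__all__"
                    and isinstance(node.value, (ast.List, ast.Tuple))
                ):
                    return {elt.value for elt in node.value.elts if isinstance(elt, ast.Constant)}
    symbols: set[str] = set()
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) and not node.name.startswith("_"):
            symbols.add(node.name)
        elif isinstance(node, ast.Assign):
            symbols.update(t.id for t in node.targets if isinstance(t, ast.Name) and not t.id.startswith("_"))
    return symbols
```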
|
|
|
|||
|
|
@@ -1902,4 +1902,210 @@ def test_bubble_sort(input, expected_output):

        # Check that comments were added
        modified_source = result.generated_tests[0].generated_original_test_source
        assert modified_source == expected

||||
def test_async_basic_runtime_comment_addition(self, test_config):
|
||||
"""Test basic functionality of adding runtime comments to async test functions."""
|
||||
os.chdir(test_config.project_root_path)
|
||||
test_source = """async def test_async_bubble_sort():
|
||||
codeflash_output = await async_bubble_sort([3, 1, 2])
|
||||
assert codeflash_output == [1, 2, 3]
|
||||
"""
|
||||
|
||||
generated_test = GeneratedTests(
|
||||
generated_original_test_source=test_source,
|
||||
instrumented_behavior_test_source="",
|
||||
instrumented_perf_test_source="",
|
||||
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
|
||||
perf_file_path=test_config.tests_root / "test_perf.py",
|
||||
)
|
||||
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
|
||||
|
||||
original_test_results = TestResults()
|
||||
optimized_test_results = TestResults()
|
||||
|
||||
original_invocation = self.create_test_invocation("test_async_bubble_sort", 500_000, iteration_id='0') # 500μs
|
||||
optimized_invocation = self.create_test_invocation("test_async_bubble_sort", 300_000, iteration_id='0') # 300μs
|
||||
|
||||
original_test_results.add(original_invocation)
|
||||
optimized_test_results.add(optimized_invocation)
|
||||
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
|
||||
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
|
||||
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
|
||||
|
||||
modified_source = result.generated_tests[0].generated_original_test_source
|
||||
assert "# 500μs -> 300μs" in modified_source
|
||||
assert "codeflash_output = await async_bubble_sort([3, 1, 2]) # 500μs -> 300μs" in modified_source
|
||||
|
||||
def test_async_multiple_test_functions(self, test_config):
|
||||
os.chdir(test_config.project_root_path)
|
||||
test_source = """async def test_async_bubble_sort():
|
||||
codeflash_output = await async_quick_sort([3, 1, 2])
|
||||
assert codeflash_output == [1, 2, 3]
|
||||
|
||||
async def test_async_quick_sort():
|
||||
codeflash_output = await async_quick_sort([5, 2, 8])
|
||||
assert codeflash_output == [2, 5, 8]
|
||||
|
||||
def helper_function():
|
||||
return "not a test"
|
||||
"""
|
||||
generated_test = GeneratedTests(
|
||||
generated_original_test_source=test_source,
|
||||
instrumented_behavior_test_source="",
|
||||
instrumented_perf_test_source="",
|
||||
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
|
||||
perf_file_path=test_config.tests_root / "test_perf.py"
|
||||
)
|
||||
|
||||
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
|
||||
|
||||
original_test_results = TestResults()
|
||||
optimized_test_results = TestResults()
|
||||
|
||||
original_test_results.add(self.create_test_invocation("test_async_bubble_sort", 500_000, iteration_id='0'))
|
||||
original_test_results.add(self.create_test_invocation("test_async_quick_sort", 800_000, iteration_id='0'))
|
||||
|
||||
optimized_test_results.add(self.create_test_invocation("test_async_bubble_sort", 300_000, iteration_id='0'))
|
||||
optimized_test_results.add(self.create_test_invocation("test_async_quick_sort", 600_000, iteration_id='0'))
|
||||
|
||||
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
|
||||
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
|
||||
|
||||
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
|
||||
|
||||
modified_source = result.generated_tests[0].generated_original_test_source
|
||||
|
||||
assert "# 500μs -> 300μs" in modified_source
|
||||
assert "# 800μs -> 600μs" in modified_source
|
||||
assert (
|
||||
"helper_function():" in modified_source
|
||||
and "# " not in modified_source.split("helper_function():")[1].split("\n")[0]
|
||||
)
|
||||
|
||||
def test_async_class_method(self, test_config):
|
||||
os.chdir(test_config.project_root_path)
|
||||
test_source = '''class TestAsyncClass:
|
||||
async def test_async_function(self):
|
||||
codeflash_output = await some_async_function()
|
||||
assert codeflash_output == expected
|
||||
'''
|
||||
generated_test = GeneratedTests(
|
||||
generated_original_test_source=test_source,
|
||||
instrumented_behavior_test_source="",
|
||||
instrumented_perf_test_source="",
|
||||
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
|
||||
perf_file_path=test_config.tests_root / "test_perf.py"
|
||||
)
|
||||
|
||||
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
|
||||
|
||||
invocation_id = InvocationId(
|
||||
test_module_path="tests.test_module__unit_test_0",
|
||||
test_class_name="TestAsyncClass",
|
||||
test_function_name="test_async_function",
|
||||
function_getting_tested="some_async_function",
|
||||
iteration_id="0",
|
||||
)
|
||||
|
||||
original_runtimes = {invocation_id: [2000000000]} # 2s in nanoseconds
|
||||
optimized_runtimes = {invocation_id: [1000000000]} # 1s in nanoseconds
|
||||
|
||||
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
|
||||
|
||||
expected_source = '''class TestAsyncClass:
|
||||
async def test_async_function(self):
|
||||
codeflash_output = await some_async_function() # 2.00s -> 1.00s (100% faster)
|
||||
assert codeflash_output == expected
|
||||
'''
|
||||
|
||||
assert len(result.generated_tests) == 1
|
||||
assert result.generated_tests[0].generated_original_test_source == expected_source
|
||||
|
||||
def test_async_mixed_sync_and_async_functions(self, test_config):
|
||||
os.chdir(test_config.project_root_path)
|
||||
test_source = """def test_sync_function():
|
||||
codeflash_output = sync_function([1, 2, 3])
|
||||
assert codeflash_output == [1, 2, 3]
|
||||
|
||||
async def test_async_function():
|
||||
codeflash_output = await async_function([4, 5, 6])
|
||||
assert codeflash_output == [4, 5, 6]
|
||||
|
||||
def test_another_sync():
|
||||
result = another_sync_func()
|
||||
assert result is True
|
||||
"""
|
||||
generated_test = GeneratedTests(
|
||||
generated_original_test_source=test_source,
|
||||
instrumented_behavior_test_source="",
|
||||
instrumented_perf_test_source="",
|
||||
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
|
||||
perf_file_path=test_config.tests_root / "test_perf.py"
|
||||
)
|
||||
|
||||
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
|
||||
|
||||
original_test_results = TestResults()
|
||||
optimized_test_results = TestResults()
|
||||
|
||||
# Add test invocations for all test functions
|
||||
original_test_results.add(self.create_test_invocation("test_sync_function", 400_000, iteration_id='0'))
|
||||
original_test_results.add(self.create_test_invocation("test_async_function", 600_000, iteration_id='0'))
|
||||
original_test_results.add(self.create_test_invocation("test_another_sync", 200_000, iteration_id='0'))
|
||||
|
||||
optimized_test_results.add(self.create_test_invocation("test_sync_function", 200_000, iteration_id='0'))
|
||||
optimized_test_results.add(self.create_test_invocation("test_async_function", 300_000, iteration_id='0'))
|
||||
optimized_test_results.add(self.create_test_invocation("test_another_sync", 100_000, iteration_id='0'))
|
||||
|
||||
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
|
||||
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
|
||||
|
||||
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
|
||||
|
||||
modified_source = result.generated_tests[0].generated_original_test_source
|
||||
|
||||
assert "# 400μs -> 200μs" in modified_source
|
||||
assert "# 600μs -> 300μs" in modified_source
|
||||
assert "# 200μs -> 100μs" in modified_source
|
||||
|
||||
assert "async def test_async_function():" in modified_source
|
||||
assert "await async_function([4, 5, 6])" in modified_source
|
||||
|
||||
def test_async_complex_await_patterns(self, test_config):
|
||||
os.chdir(test_config.project_root_path)
|
||||
test_source = """async def test_complex_async():
|
||||
# Multiple await calls
|
||||
result1 = await async_func1()
|
||||
codeflash_output = await async_func2(result1)
|
||||
result3 = await async_func3(codeflash_output)
|
||||
assert result3 == expected
|
||||
|
||||
# Await in context manager
|
||||
async with async_context() as ctx:
|
||||
final_result = await ctx.process()
|
||||
assert final_result is not None
|
||||
"""
|
||||
generated_test = GeneratedTests(
|
||||
generated_original_test_source=test_source,
|
||||
instrumented_behavior_test_source="",
|
||||
instrumented_perf_test_source="",
|
||||
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
|
||||
perf_file_path=test_config.tests_root / "test_perf.py"
|
||||
)
|
||||
|
||||
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
|
||||
|
||||
original_test_results = TestResults()
|
||||
optimized_test_results = TestResults()
|
||||
|
||||
original_test_results.add(self.create_test_invocation("test_complex_async", 750_000, iteration_id='1')) # 750μs
|
||||
optimized_test_results.add(self.create_test_invocation("test_complex_async", 450_000, iteration_id='1')) # 450μs
|
||||
|
||||
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
|
||||
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
|
||||
|
||||
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
|
||||
|
||||
modified_source = result.generated_tests[0].generated_original_test_source
|
||||
assert "# 750μs -> 450μs" in modified_source
|
||||
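The `# 500μs -> 300μs` style comments asserted throughout this file imply a nanosecond humanizer plus a speedup calculation. A sketch consistent with the assertions (rounding rules beyond what the tests check are assumptions, and both function names here are hypothetical):

```python
def humanize_runtime(ns: int) -> str:
    # Whole microseconds print without decimals ("500μs"); seconds print
    # with two ("2.00s"), matching the strings the tests assert.
    if ns < 1_000:
        return f"{ns}ns"
    if ns < 1_000_000:
        return f"{ns // 1_000}μs" if ns % 1_000 == 0 else f"{ns / 1_000:.2f}μs"
    if ns < 1_000_000_000:
        return f"{ns // 1_000_000}ms" if ns % 1_000_000 == 0 else f"{ns / 1_000_000:.2f}ms"
    return f"{ns / 1_000_000_000:.2f}s"

def runtime_comment(original_ns: int, optimized_ns: int) -> str:
    speedup_pct = round((original_ns - optimized_ns) / optimized_ns * 100)
    return f"# {humanize_runtime(original_ns)} -> {humanize_runtime(optimized_ns)} ({speedup_pct}% faster)"

assert runtime_comment(2_000_000_000, 1_000_000_000) == "# 2.00s -> 1.00s (100% faster)"
```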
347
tests/test_async_function_discovery.py
Normal file
@@ -0,0 +1,347 @@
|
|||
import tempfile
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import pytest
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import (
|
||||
find_all_functions_in_file,
|
||||
get_functions_to_optimize,
|
||||
inspect_top_level_functions_or_methods,
|
||||
)
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
with tempfile.TemporaryDirectory() as temp:
|
||||
yield Path(temp)
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
def test_async_function_detection(temp_dir):
|
||||
async_function = """
|
||||
async def async_function_with_return():
|
||||
await some_async_operation()
|
||||
return 42
|
||||
|
||||
async def async_function_without_return():
|
||||
await some_async_operation()
|
||||
print("No return")
|
||||
|
||||
def regular_function():
|
||||
return 10
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(async_function)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
function_names = [fn.function_name for fn in functions_found[file_path]]
|
||||
|
||||
assert "async_function_with_return" in function_names
|
||||
assert "regular_function" in function_names
|
||||
assert "async_function_without_return" not in function_names
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
def test_async_method_in_class(temp_dir):
|
||||
code_with_async_method = """
|
||||
class AsyncClass:
|
||||
async def async_method(self):
|
||||
await self.do_something()
|
||||
return "result"
|
||||
|
||||
async def async_method_no_return(self):
|
||||
await self.do_something()
|
||||
pass
|
||||
|
||||
def sync_method(self):
|
||||
return "sync result"
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(code_with_async_method)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
found_functions = functions_found[file_path]
|
||||
function_names = [fn.function_name for fn in found_functions]
|
||||
qualified_names = [fn.qualified_name for fn in found_functions]
|
||||
|
||||
assert "async_method" in function_names
|
||||
assert "AsyncClass.async_method" in qualified_names
|
||||
|
||||
assert "sync_method" in function_names
|
||||
assert "AsyncClass.sync_method" in qualified_names
|
||||
|
||||
assert "async_method_no_return" not in function_names
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
def test_nested_async_functions(temp_dir):
|
||||
nested_async = """
|
||||
async def outer_async():
|
||||
async def inner_async():
|
||||
return "inner"
|
||||
|
||||
result = await inner_async()
|
||||
return result
|
||||
|
||||
def outer_sync():
|
||||
async def inner_async():
|
||||
return "inner from sync"
|
||||
|
||||
return inner_async
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(nested_async)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
function_names = [fn.function_name for fn in functions_found[file_path]]
|
||||
|
||||
assert "outer_async" in function_names
|
||||
assert "outer_sync" in function_names
|
||||
assert "inner_async" not in function_names
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
def test_async_staticmethod_and_classmethod(temp_dir):
|
||||
async_decorators = """
|
||||
class MyClass:
|
||||
@staticmethod
|
||||
async def async_static_method():
|
||||
await some_operation()
|
||||
return "static result"
|
||||
|
||||
@classmethod
|
||||
async def async_class_method(cls):
|
||||
await cls.some_operation()
|
||||
return "class result"
|
||||
|
||||
@property
|
||||
async def async_property(self):
|
||||
return await self.get_value()
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(async_decorators)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
function_names = [fn.function_name for fn in functions_found[file_path]]
|
||||
|
||||
assert "async_static_method" in function_names
|
||||
assert "async_class_method" in function_names
|
||||
|
||||
assert "async_property" not in function_names
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
def test_async_generator_functions(temp_dir):
|
||||
async_generators = """
|
||||
async def async_generator_with_return():
|
||||
for i in range(10):
|
||||
yield i
|
||||
return "done"
|
||||
|
||||
async def async_generator_no_return():
|
||||
for i in range(10):
|
||||
yield i
|
||||
|
||||
async def regular_async_with_return():
|
||||
result = await compute()
|
||||
return result
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(async_generators)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
function_names = [fn.function_name for fn in functions_found[file_path]]
|
||||
|
||||
assert "async_generator_with_return" in function_names
|
||||
assert "regular_async_with_return" in function_names
|
||||
assert "async_generator_no_return" not in function_names
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
def test_inspect_async_top_level_functions(temp_dir):
|
||||
code = """
|
||||
async def top_level_async():
|
||||
return 42
|
||||
|
||||
class AsyncContainer:
|
||||
async def async_method(self):
|
||||
async def nested_async():
|
||||
return 1
|
||||
return await nested_async()
|
||||
|
||||
@staticmethod
|
||||
async def async_static():
|
||||
return "static"
|
||||
|
||||
@classmethod
|
||||
async def async_classmethod(cls):
|
||||
return "classmethod"
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(code)
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "top_level_async")
|
||||
assert result.is_top_level
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "async_method", class_name="AsyncContainer")
|
||||
assert result.is_top_level
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "nested_async", class_name="AsyncContainer")
|
||||
assert not result.is_top_level
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "async_static", class_name="AsyncContainer")
|
||||
assert result.is_top_level
|
||||
assert result.is_staticmethod
|
||||
|
||||
result = inspect_top_level_functions_or_methods(file_path, "async_classmethod", class_name="AsyncContainer")
|
||||
assert result.is_top_level
|
||||
assert result.is_classmethod
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
def test_get_functions_to_optimize_with_async(temp_dir):
|
||||
mixed_code = """
|
||||
async def async_func_one():
|
||||
return await operation_one()
|
||||
|
||||
def sync_func_one():
|
||||
return operation_one()
|
||||
|
||||
async def async_func_two():
|
||||
print("no return")
|
||||
|
||||
class MixedClass:
|
||||
async def async_method(self):
|
||||
return await self.operation()
|
||||
|
||||
def sync_method(self):
|
||||
return self.operation()
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(mixed_code)
|
||||
|
||||
test_config = TestConfig(
|
||||
tests_root="tests",
|
||||
project_root_path=".",
|
||||
test_framework="pytest",
|
||||
tests_project_rootdir=Path()
|
||||
)
|
||||
|
||||
functions, functions_count, _ = get_functions_to_optimize(
|
||||
optimize_all=None,
|
||||
replay_test=None,
|
||||
file=file_path,
|
||||
only_get_this_function=None,
|
||||
test_cfg=test_config,
|
||||
ignore_paths=[],
|
||||
project_root=file_path.parent,
|
||||
module_root=file_path.parent,
|
||||
enable_async=True,
|
||||
)
|
||||
|
||||
assert functions_count == 4
|
||||
|
||||
function_names = [fn.function_name for fn in functions[file_path]]
|
||||
assert "async_func_one" in function_names
|
||||
assert "sync_func_one" in function_names
|
||||
assert "async_method" in function_names
|
||||
assert "sync_method" in function_names
|
||||
|
||||
assert "async_func_two" not in function_names
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
def test_no_async_functions_finding(temp_dir):
|
||||
mixed_code = """
|
||||
async def async_func_one():
|
||||
return await operation_one()
|
||||
|
||||
def sync_func_one():
|
||||
return operation_one()
|
||||
|
||||
async def async_func_two():
|
||||
print("no return")
|
||||
|
||||
class MixedClass:
|
||||
async def async_method(self):
|
||||
return await self.operation()
|
||||
|
||||
def sync_method(self):
|
||||
return self.operation()
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(mixed_code)
|
||||
|
||||
test_config = TestConfig(
|
||||
tests_root="tests",
|
||||
project_root_path=".",
|
||||
test_framework="pytest",
|
||||
tests_project_rootdir=Path()
|
||||
)
|
||||
|
||||
functions, functions_count, _ = get_functions_to_optimize(
|
||||
optimize_all=None,
|
||||
replay_test=None,
|
||||
file=file_path,
|
||||
only_get_this_function=None,
|
||||
test_cfg=test_config,
|
||||
ignore_paths=[],
|
||||
project_root=file_path.parent,
|
||||
module_root=file_path.parent,
|
||||
enable_async=False,
|
||||
)
|
||||
|
||||
assert functions_count == 2
|
||||
|
||||
function_names = [fn.function_name for fn in functions[file_path]]
|
||||
assert "sync_func_one" in function_names
|
||||
assert "sync_method" in function_names
|
||||
assert "async_func_one" not in function_names
|
||||
assert "async_method" not in function_names
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
def test_async_function_parents(temp_dir):
|
||||
complex_structure = """
|
||||
class OuterClass:
|
||||
async def outer_method(self):
|
||||
return 1
|
||||
|
||||
class InnerClass:
|
||||
async def inner_method(self):
|
||||
return 2
|
||||
|
||||
async def module_level_async():
|
||||
class LocalClass:
|
||||
async def local_method(self):
|
||||
return 3
|
||||
return LocalClass()
|
||||
"""
|
||||
|
||||
file_path = temp_dir / "test_file.py"
|
||||
file_path.write_text(complex_structure)
|
||||
functions_found = find_all_functions_in_file(file_path)
|
||||
|
||||
found_functions = functions_found[file_path]
|
||||
|
||||
for fn in found_functions:
|
||||
if fn.function_name == "outer_method":
|
||||
assert len(fn.parents) == 1
|
||||
assert fn.parents[0].name == "OuterClass"
|
||||
assert fn.qualified_name == "OuterClass.outer_method"
|
||||
elif fn.function_name == "inner_method":
|
||||
assert len(fn.parents) == 2
|
||||
assert fn.parents[0].name == "OuterClass"
|
||||
assert fn.parents[1].name == "InnerClass"
|
||||
elif fn.function_name == "module_level_async":
|
||||
assert len(fn.parents) == 0
|
||||
assert fn.qualified_name == "module_level_async"
|
||||
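These discovery tests encode one rule consistently: an async def qualifies only when some `return <expr>` appears in it (async generators with a value-returning `return` still qualify). A sketch of the check they imply (simplified; the real collector also tracks parents and decorators):

```python
import ast

def async_def_returns_value(fn: ast.AsyncFunctionDef) -> bool:
    # Candidates need an observable return value to verify against, so
    # `async_function_without_return`-style defs are skipped. Note that
    # ast.walk also descends into nested defs, which this sketch ignores.
    return any(
        isinstance(node, ast.Return) and node.value is not None
        for node in ast.walk(fn)
    )
```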
1050
tests/test_async_run_and_parse_tests.py
Normal file
File diff suppressed because it is too large
287
tests/test_async_wrapper_sqlite_validation.py
Normal file
@@ -0,0 +1,287 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import dill as pickle
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import (
|
||||
codeflash_behavior_async,
|
||||
codeflash_performance_async,
|
||||
)
|
||||
from codeflash.verification.codeflash_capture import VerificationType
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
class TestAsyncWrapperSQLiteValidation:
|
||||
|
||||
@pytest.fixture
|
||||
def test_env_setup(self, request):
|
||||
original_env = {}
|
||||
test_env = {
|
||||
"CODEFLASH_LOOP_INDEX": "1",
|
||||
"CODEFLASH_TEST_ITERATION": "0",
|
||||
"CODEFLASH_TEST_MODULE": __name__,
|
||||
"CODEFLASH_TEST_CLASS": "TestAsyncWrapperSQLiteValidation",
|
||||
"CODEFLASH_TEST_FUNCTION": request.node.name,
|
||||
"CODEFLASH_CURRENT_LINE_ID": "test_unit",
|
||||
}
|
||||
|
||||
for key, value in test_env.items():
|
||||
original_env[key] = os.environ.get(key)
|
||||
os.environ[key] = value
|
||||
|
||||
yield test_env
|
||||
|
||||
for key, original_value in original_env.items():
|
||||
if original_value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = original_value
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db_path(self, test_env_setup):
|
||||
iteration = test_env_setup["CODEFLASH_TEST_ITERATION"]
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file
|
||||
db_path = get_run_tmp_file(Path(f"test_return_values_{iteration}.sqlite"))
|
||||
|
||||
yield db_path
|
||||
|
||||
if db_path.exists():
|
||||
db_path.unlink()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_behavior_async_basic_function(self, test_env_setup, temp_db_path):
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def simple_async_add(a: int, b: int) -> int:
|
||||
await asyncio.sleep(0.001)
|
||||
return a + b
|
||||
|
||||
os.environ['CODEFLASH_CURRENT_LINE_ID'] = 'simple_async_add_59'
|
||||
result = await simple_async_add(5, 3)
|
||||
|
||||
assert result == 8
|
||||
|
||||
assert temp_db_path.exists()
|
||||
|
||||
con = sqlite3.connect(temp_db_path)
|
||||
cur = con.cursor()
|
||||
|
||||
cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test_results'")
|
||||
assert cur.fetchone() is not None
|
||||
|
||||
cur.execute("SELECT * FROM test_results")
|
||||
rows = cur.fetchall()
|
||||
|
||||
assert len(rows) == 1
|
||||
row = rows[0]
|
||||
|
||||
(test_module_path, test_class_name, test_function_name, function_getting_tested,
|
||||
loop_index, iteration_id, runtime, return_value_blob, verification_type) = row
|
||||
|
||||
assert test_module_path == __name__
|
||||
assert test_class_name == "TestAsyncWrapperSQLiteValidation"
|
||||
assert test_function_name == "test_behavior_async_basic_function"
|
||||
assert function_getting_tested == "simple_async_add"
|
||||
assert loop_index == 1
|
||||
# Line ID will be the actual line number from the source code, not a simple counter
|
||||
assert iteration_id.startswith("simple_async_add_") and iteration_id.endswith("_0")
|
||||
assert runtime > 0
|
||||
assert verification_type == VerificationType.FUNCTION_CALL.value
|
||||
|
||||
unpickled_data = pickle.loads(return_value_blob)
|
||||
args, kwargs, return_val = unpickled_data
|
||||
|
||||
assert args == (5, 3)
|
||||
assert kwargs == {}
|
||||
assert return_val == 8
|
||||
|
||||
con.close()
|
||||
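For readers following the row layout asserted above: the decorator pattern under test boils down to timing the awaited call and pickling (args, kwargs, result) into SQLite. A stripped-down sketch, assuming a simplified three-column table instead of the real nine; the actual decorator also reads test context from CODEFLASH_* env vars and captures exceptions:

```python
import functools
import sqlite3
import time

import dill as pickle

def behavior_capture_async(db_path):
    def decorator(fn):
        @functools.wraps(fn)
        async def wrapper(*args, **kwargs):
            start = time.perf_counter_ns()
            ret = await fn(*args, **kwargs)
            runtime = time.perf_counter_ns() - start
            con = sqlite3.connect(db_path)
            con.execute(
                "CREATE TABLE IF NOT EXISTS test_results "
                "(function_getting_tested TEXT, runtime INTEGER, return_value BLOB)"
            )
            con.execute(
                "INSERT INTO test_results VALUES (?, ?, ?)",
                (fn.__name__, runtime, pickle.dumps((args, kwargs, ret))),
            )
            con.commit()
            con.close()
            return ret
        return wrapper
    return decorator
```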
|
||||
@pytest.mark.asyncio
|
||||
async def test_behavior_async_exception_handling(self, test_env_setup, temp_db_path):
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_divide(a: int, b: int) -> float:
|
||||
await asyncio.sleep(0.001)
|
||||
if b == 0:
|
||||
raise ValueError("Cannot divide by zero")
|
||||
return a / b
|
||||
|
||||
result = await async_divide(10, 2)
|
||||
assert result == 5.0
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot divide by zero"):
|
||||
await async_divide(10, 0)
|
||||
|
||||
con = sqlite3.connect(temp_db_path)
|
||||
cur = con.cursor()
|
||||
cur.execute("SELECT * FROM test_results ORDER BY iteration_id")
|
||||
rows = cur.fetchall()
|
||||
|
||||
assert len(rows) == 2
|
||||
|
||||
success_row = rows[0]
|
||||
success_data = pickle.loads(success_row[7]) # return_value_blob
|
||||
args, kwargs, return_val = success_data
|
||||
assert args == (10, 2)
|
||||
assert return_val == 5.0
|
||||
|
||||
# Check exception record
|
||||
exception_row = rows[1]
|
||||
exception_data = pickle.loads(exception_row[7]) # return_value_blob
|
||||
assert isinstance(exception_data, ValueError)
|
||||
assert str(exception_data) == "Cannot divide by zero"
|
||||
|
||||
con.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_performance_async_no_database_storage(self, test_env_setup, temp_db_path, capsys):
|
||||
"""Test performance async decorator doesn't store to database."""
|
||||
|
||||
@codeflash_performance_async
|
||||
async def async_multiply(a: int, b: int) -> int:
|
||||
"""Async function for performance testing."""
|
||||
await asyncio.sleep(0.002)
|
||||
return a * b
|
||||
|
||||
result = await async_multiply(4, 7)
|
||||
|
||||
assert result == 28
|
||||
|
||||
assert not temp_db_path.exists()
|
||||
|
||||
captured = capsys.readouterr()
|
||||
output_lines = captured.out.strip().split('\n')
|
||||
|
||||
assert len([line for line in output_lines if "!$######" in line]) == 1
|
||||
assert len([line for line in output_lines if "!######" in line and "######!" in line]) == 1
|
||||
|
||||
closing_tag = [line for line in output_lines if "!######" in line and "######!" in line][0]
|
||||
assert "async_multiply" in closing_tag
|
||||
|
||||
timing_part = closing_tag.split(":")[-1].replace("######!", "")
|
||||
timing_value = int(timing_part)
|
||||
assert timing_value > 0 # Should have positive timing
|
||||
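The marker strings parsed in this test are the perf wrapper's stdout protocol: an opening `!$######...` line and a closing `!######...######!` line whose last `:`-separated field is the runtime. The test's own split/replace logic is equivalent to:

```python
def parse_timing_ns(closing_tag: str) -> int:
    # e.g. "!######...:async_multiply:...:2150000######!" -> 2150000
    # (the exact field layout of the tag is an assumption beyond what the
    # test asserts)
    assert closing_tag.endswith("######!")
    return int(closing_tag.split(":")[-1].removesuffix("######!"))
```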
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_calls_indexing(self, test_env_setup, temp_db_path):
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_increment(value: int) -> int:
|
||||
await asyncio.sleep(0.001)
|
||||
return value + 1
|
||||
|
||||
# Call the function multiple times
|
||||
results = []
|
||||
for i in range(3):
|
||||
result = await async_increment(i)
|
||||
results.append(result)
|
||||
|
||||
assert results == [1, 2, 3]
|
||||
|
||||
con = sqlite3.connect(temp_db_path)
|
||||
cur = con.cursor()
|
||||
cur.execute("SELECT iteration_id, return_value FROM test_results ORDER BY iteration_id")
|
||||
rows = cur.fetchall()
|
||||
|
||||
assert len(rows) == 3
|
||||
|
||||
actual_ids = [row[0] for row in rows]
|
||||
assert len(actual_ids) == 3
|
||||
|
||||
base_pattern = actual_ids[0].rsplit('_', 1)[0] # e.g., "async_increment_199"
|
||||
expected_pattern = [f"{base_pattern}_{i}" for i in range(3)]
|
||||
assert actual_ids == expected_pattern
|
||||
|
||||
for i, (_, return_value_blob) in enumerate(rows):
|
||||
args, kwargs, return_val = pickle.loads(return_value_blob)
|
||||
assert args == (i,)
|
||||
assert return_val == i + 1
|
||||
|
||||
con.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complex_async_function_with_kwargs(self, test_env_setup, temp_db_path):
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def complex_async_func(
|
||||
pos_arg: str,
|
||||
*args: int,
|
||||
keyword_arg: str = "default",
|
||||
**kwargs: str
|
||||
) -> dict:
|
||||
await asyncio.sleep(0.001)
|
||||
return {
|
||||
"pos_arg": pos_arg,
|
||||
"args": args,
|
||||
"keyword_arg": keyword_arg,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
|
||||
result = await complex_async_func(
|
||||
"hello",
|
||||
1, 2, 3,
|
||||
keyword_arg="custom",
|
||||
extra1="value1",
|
||||
extra2="value2"
|
||||
)
|
||||
|
||||
expected_result = {
|
||||
"pos_arg": "hello",
|
||||
"args": (1, 2, 3),
|
||||
"keyword_arg": "custom",
|
||||
"kwargs": {"extra1": "value1", "extra2": "value2"}
|
||||
}
|
||||
|
||||
assert result == expected_result
|
||||
|
||||
con = sqlite3.connect(temp_db_path)
|
||||
cur = con.cursor()
|
||||
cur.execute("SELECT return_value FROM test_results")
|
||||
row = cur.fetchone()
|
||||
|
||||
stored_args, stored_kwargs, stored_result = pickle.loads(row[0])
|
||||
|
||||
assert stored_args == ("hello", 1, 2, 3)
|
||||
assert stored_kwargs == {"keyword_arg": "custom", "extra1": "value1", "extra2": "value2"}
|
||||
assert stored_result == expected_result
|
||||
|
||||
con.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_database_schema_validation(self, test_env_setup, temp_db_path):
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def schema_test_func() -> str:
|
||||
return "schema_test"
|
||||
|
||||
await schema_test_func()
|
||||
|
||||
con = sqlite3.connect(temp_db_path)
|
||||
cur = con.cursor()
|
||||
|
||||
cur.execute("PRAGMA table_info(test_results)")
|
||||
columns = cur.fetchall()
|
||||
|
||||
expected_columns = [
|
||||
(0, 'test_module_path', 'TEXT', 0, None, 0),
|
||||
(1, 'test_class_name', 'TEXT', 0, None, 0),
|
||||
(2, 'test_function_name', 'TEXT', 0, None, 0),
|
||||
(3, 'function_getting_tested', 'TEXT', 0, None, 0),
|
||||
(4, 'loop_index', 'INTEGER', 0, None, 0),
|
||||
(5, 'iteration_id', 'TEXT', 0, None, 0),
|
||||
(6, 'runtime', 'INTEGER', 0, None, 0),
|
||||
(7, 'return_value', 'BLOB', 0, None, 0),
|
||||
(8, 'verification_type', 'TEXT', 0, None, 0)
|
||||
]
|
||||
|
||||
assert columns == expected_columns
|
||||
con.close()
|
||||
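The PRAGMA tuple list above is just a column dump; the same schema as DDL, for reference (reconstructed from the assertion, not copied from the source):

```python
CREATE_TEST_RESULTS = """
CREATE TABLE test_results (
    test_module_path        TEXT,
    test_class_name         TEXT,
    test_function_name      TEXT,
    function_getting_tested TEXT,
    loop_index              INTEGER,
    iteration_id            TEXT,
    runtime                 INTEGER,
    return_value            BLOB,
    verification_type       TEXT
)
"""
```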
|
||||
File diff suppressed because it is too large
@@ -12,6 +12,7 @@ from codeflash.code_utils.code_replacer import (
|
|||
is_zero_diff,
|
||||
replace_functions_and_add_imports,
|
||||
replace_functions_in_file,
|
||||
OptimFunctionCollector,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent
|
||||
|
|
@@ -3449,156 +3450,173 @@ def hydrate_input_text_actions_with_field_names(

    assert new_code == expected


def test_duplicate_global_assignments_when_reverting_helpers():
    root_dir = Path(__file__).parent.parent.resolve()
    main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()

    original_code = '''"""Chunking objects not specific to a particular chunking strategy."""
from __future__ import annotations
import collections
import copy
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
import regex
from typing_extensions import Self, TypeAlias
from unstructured.utils import lazyproperty
from unstructured.documents.elements import Element
# ================================================================================================
# MODEL
# ================================================================================================
CHUNK_MAX_CHARS_DEFAULT: int = 500
# ================================================================================================
# PRE-CHUNKER
# ================================================================================================
class PreChunker:
    """Gathers sequential elements into pre-chunks as length constraints allow.
    The pre-chunker's responsibilities are:
    - **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
      either side of those boundaries into different sections. In this case, the primary indicator
      of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
      semantic boundary when `multipage_sections` is `False`.
    - **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
      into sections as big as possible without exceeding the chunk window size.
    - **Minimize chunks that must be split mid-text.** Precompute the text length of each section
      and only produce a section that exceeds the chunk window size when there is a single element
      with text longer than that window.
    A Table element is placed into a section by itself. CheckBox elements are dropped.
    The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
    a new "section", hence the "by-title" designation.
    """
    def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
        self._elements = elements
        self._opts = opts
    @lazyproperty
    def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
        """The semantic-boundary detectors to be applied to break pre-chunks."""
        return self._opts.boundary_predicates
    def _is_in_new_semantic_unit(self, element: Element) -> bool:
        """True when `element` begins a new semantic unit such as a section or page."""
        # -- all detectors need to be called to update state and avoid double counting
        # -- boundaries that happen to coincide, like Table and new section on same element.
        # -- Using `any()` would short-circuit on first True.
        semantic_boundaries = [pred(element) for pred in self._boundary_predicates]
        return any(semantic_boundaries)
'''
    main_file.write_text(original_code, encoding="utf-8")
    optim_code = f'''```python:{main_file.relative_to(root_dir)}
# ================================================================================================
# PRE-CHUNKER
# ================================================================================================
from __future__ import annotations
from typing import Iterable
from unstructured.documents.elements import Element
from unstructured.utils import lazyproperty
class PreChunker:
    def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
        self._elements = elements
        self._opts = opts
    @lazyproperty
    def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
        """The semantic-boundary detectors to be applied to break pre-chunks."""
        return self._opts.boundary_predicates
    def _is_in_new_semantic_unit(self, element: Element) -> bool:
        """True when `element` begins a new semantic unit such as a section or page."""
        # Use generator expression for lower memory usage and avoid building intermediate list
        for pred in self._boundary_predicates:
            if pred(element):
                return True
        return False
```
'''
    func = FunctionToOptimize(function_name="_is_in_new_semantic_unit", parents=[FunctionParent("PreChunker", "ClassDef")], file_path=main_file)
    test_config = TestConfig(
        tests_root=root_dir / "tests/pytest",
        tests_project_rootdir=root_dir,
        project_root_path=root_dir,
        test_framework="pytest",
        pytest_cmd="pytest",
    )
    func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
    code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
    original_helper_code: dict[Path, str] = {}
    helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
    for helper_function_path in helper_function_paths:
        with helper_function_path.open(encoding="utf8") as f:
            helper_code = f.read()
            original_helper_code[helper_function_path] = helper_code
    func_optimizer.args = Args()
    func_optimizer.replace_function_and_helpers_with_optimized_code(
        code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code
    )
    new_code = main_file.read_text(encoding="utf-8")
    main_file.unlink(missing_ok=True)
    expected = '''"""Chunking objects not specific to a particular chunking strategy."""
from __future__ import annotations
import collections
import copy
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
import regex
from typing_extensions import Self, TypeAlias
from unstructured.utils import lazyproperty
from unstructured.documents.elements import Element
# ================================================================================================
# MODEL
# ================================================================================================
CHUNK_MAX_CHARS_DEFAULT: int = 500
# ================================================================================================
# PRE-CHUNKER
# ================================================================================================
class PreChunker:
    """Gathers sequential elements into pre-chunks as length constraints allow.
    The pre-chunker's responsibilities are:
    - **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
      either side of those boundaries into different sections. In this case, the primary indicator
      of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
      semantic boundary when `multipage_sections` is `False`.
    - **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
      into sections as big as possible without exceeding the chunk window size.
    - **Minimize chunks that must be split mid-text.** Precompute the text length of each section
      and only produce a section that exceeds the chunk window size when there is a single element
      with text longer than that window.
    A Table element is placed into a section by itself. CheckBox elements are dropped.
    The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
    a new "section", hence the "by-title" designation.
    """
    def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
        self._elements = elements
        self._opts = opts
    @lazyproperty
    def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
        """The semantic-boundary detectors to be applied to break pre-chunks."""
        return self._opts.boundary_predicates
    def _is_in_new_semantic_unit(self, element: Element) -> bool:
        """True when `element` begins a new semantic unit such as a section or page."""
        # Use generator expression for lower memory usage and avoid building intermediate list
        for pred in self._boundary_predicates:
            if pred(element):
                return True
        return False
'''
    assert new_code == expected


# OptimFunctionCollector async function tests
def test_optim_function_collector_with_async_functions():
    """Test OptimFunctionCollector correctly collects async functions."""
    import libcst as cst

    source_code = """
def sync_function():
    return "sync"

async def async_function():
    return "async"

class TestClass:
    def sync_method(self):
        return "sync_method"

    async def async_method(self):
        return "async_method"
"""

    tree = cst.parse_module(source_code)
    collector = OptimFunctionCollector(
        function_names={(None, "sync_function"), (None, "async_function"), ("TestClass", "sync_method"), ("TestClass", "async_method")},
        preexisting_objects=None
    )
    tree.visit(collector)

    # Should collect both sync and async functions
    assert len(collector.modified_functions) == 4
    assert (None, "sync_function") in collector.modified_functions
    assert (None, "async_function") in collector.modified_functions
    assert ("TestClass", "sync_method") in collector.modified_functions
    assert ("TestClass", "async_method") in collector.modified_functions


def test_optim_function_collector_new_async_functions():
    """Test OptimFunctionCollector identifies new async functions not in preexisting objects."""
    import libcst as cst

    source_code = """
def existing_function():
    return "existing"

async def new_async_function():
    return "new_async"

def new_sync_function():
    return "new_sync"

class ExistingClass:
    async def new_class_async_method(self):
        return "new_class_async"
"""

    # Only existing_function is in preexisting objects
    preexisting_objects = {("existing_function", ())}

    tree = cst.parse_module(source_code)
    collector = OptimFunctionCollector(
        function_names=set(),  # Not looking for specific functions
        preexisting_objects=preexisting_objects
    )
    tree.visit(collector)

    # Should identify new functions (both sync and async)
    assert len(collector.new_functions) == 2
    function_names = [func.name.value for func in collector.new_functions]
    assert "new_async_function" in function_names
    assert "new_sync_function" in function_names

    # Should identify new class methods
    assert "ExistingClass" in collector.new_class_functions
    assert len(collector.new_class_functions["ExistingClass"]) == 1
    assert collector.new_class_functions["ExistingClass"][0].name.value == "new_class_async_method"


def test_optim_function_collector_mixed_scenarios():
    """Test OptimFunctionCollector with complex mix of sync/async functions and classes."""
    import libcst as cst

    source_code = """
# Global functions
def global_sync():
    pass

async def global_async():
    pass

class ParentClass:
    def __init__(self):
        pass

    def sync_method(self):
        pass

    async def async_method(self):
        pass

class ChildClass:
    async def child_async_method(self):
        pass

    def child_sync_method(self):
        pass
"""

    # Looking for specific functions
    function_names = {
        (None, "global_sync"),
        (None, "global_async"),
        ("ParentClass", "sync_method"),
        ("ParentClass", "async_method"),
        ("ChildClass", "child_async_method")
    }

    tree = cst.parse_module(source_code)
    collector = OptimFunctionCollector(
        function_names=function_names,
        preexisting_objects=None
    )
    tree.visit(collector)

    # Should collect all specified functions (mix of sync and async)
    assert len(collector.modified_functions) == 5
    assert (None, "global_sync") in collector.modified_functions
    assert (None, "global_async") in collector.modified_functions
    assert ("ParentClass", "sync_method") in collector.modified_functions
    assert ("ParentClass", "async_method") in collector.modified_functions
    assert ("ChildClass", "child_async_method") in collector.modified_functions

    # Should collect __init__ method
    assert "ParentClass" in collector.modified_init_functions


def test_is_zero_diff_async_sleep():
    original_code = '''
import time

async def task():
    time.sleep(1)
    return "done"
'''
    optimized_code = '''
import asyncio

async def task():
    await asyncio.sleep(1)
    return "done"
'''
    assert not is_zero_diff(original_code, optimized_code)


def test_is_zero_diff_with_equivalent_code():
    original_code = '''
import asyncio

async def task():
    await asyncio.sleep(1)
    return "done"
'''
    optimized_code = '''
import asyncio

async def task():
    """A task that does something."""
    await asyncio.sleep(1)
    return "done"
'''
    assert is_zero_diff(original_code, optimized_code)
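is_zero_diff's contract per these two tests: docstring and formatting changes don't count as a diff, semantic changes do. One plausible implementation under those constraints (the `_sketch` name is hypothetical; the real helper may normalize differently):

```python
import ast

def is_zero_diff_sketch(original: str, optimized: str) -> bool:
    # Compare the two sources after dropping docstrings and dumping the
    # ASTs; ast.parse already discards comments and formatting.
    def normalized(src: str) -> str:
        tree = ast.parse(src)
        for node in ast.walk(tree):
            if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                body = node.body
                if (
                    body
                    and isinstance(body[0], ast.Expr)
                    and isinstance(body[0].value, ast.Constant)
                    and isinstance(body[0].value.value, str)
                ):
                    node.body = body[1:] or [ast.Pass()]
        return ast.dump(tree)
    return normalized(original) == normalized(optimized)
```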
|
||||
|
|
@@ -17,10 +17,10 @@ from codeflash.code_utils.code_utils import (
|
|||
is_class_defined_in_file,
|
||||
module_name_from_file_path,
|
||||
path_belongs_to_site_packages,
|
||||
has_any_async_functions,
|
||||
validate_python_code,
|
||||
)
|
||||
from codeflash.code_utils.concolic_utils import clean_concolic_tests
|
||||
from codeflash.code_utils.coverage_utils import generate_candidates, prepare_coverage_files
|
||||
from codeflash.code_utils.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@@ -254,7 +254,7 @@ def test_get_run_tmp_file_reuses_temp_directory() -> None:
|
|||
|
||||
|
||||
def test_path_belongs_to_site_packages_with_site_package_path(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
site_packages = [Path("/usr/local/lib/python3.9/site-packages")]
|
||||
site_packages = [Path("/usr/local/lib/python3.9/site-packages").resolve()]
|
||||
monkeypatch.setattr(site, "getsitepackages", lambda: site_packages)
|
||||
|
||||
file_path = Path("/usr/local/lib/python3.9/site-packages/some_package")
|
||||
|
|
@@ -277,6 +277,66 @@ def test_path_belongs_to_site_packages_with_relative_path(monkeypatch: pytest.Mo
|
|||
assert path_belongs_to_site_packages(file_path) is False
|
||||
|
||||
|
||||
def test_path_belongs_to_site_packages_with_symlinked_site_packages(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||
real_site_packages = tmp_path / "real_site_packages"
|
||||
real_site_packages.mkdir()
|
||||
|
||||
symlinked_site_packages = tmp_path / "symlinked_site_packages"
|
||||
symlinked_site_packages.symlink_to(real_site_packages)
|
||||
|
||||
package_file = real_site_packages / "some_package" / "__init__.py"
|
||||
package_file.parent.mkdir()
|
||||
package_file.write_text("# package file")
|
||||
|
||||
monkeypatch.setattr(site, "getsitepackages", lambda: [str(symlinked_site_packages)])
|
||||
|
||||
assert path_belongs_to_site_packages(package_file) is True
|
||||
|
||||
symlinked_package_file = symlinked_site_packages / "some_package" / "__init__.py"
|
||||
assert path_belongs_to_site_packages(symlinked_package_file) is True
|
||||
|
||||
|
||||
def test_path_belongs_to_site_packages_with_complex_symlinks(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||
real_site_packages = tmp_path / "real" / "lib" / "python3.9" / "site-packages"
|
||||
real_site_packages.mkdir(parents=True)
|
||||
|
||||
link1 = tmp_path / "link1"
|
||||
link1.symlink_to(real_site_packages.parent.parent.parent)
|
||||
|
||||
link2 = tmp_path / "link2"
|
||||
link2.symlink_to(link1)
|
||||
|
||||
package_file = real_site_packages / "test_package" / "module.py"
|
||||
package_file.parent.mkdir()
|
||||
package_file.write_text("# test module")
|
||||
|
||||
site_packages_via_links = link2 / "lib" / "python3.9" / "site-packages"
|
||||
monkeypatch.setattr(site, "getsitepackages", lambda: [str(site_packages_via_links)])
|
||||
|
||||
assert path_belongs_to_site_packages(package_file) is True
|
||||
|
||||
file_via_links = site_packages_via_links / "test_package" / "module.py"
|
||||
assert path_belongs_to_site_packages(file_via_links) is True
|
||||
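The symlink and `..`-normalization tests above all reduce to one property: both sides must be resolved before the containment check. A sketch of that check (hypothetical `_sketch` name; assumes Python 3.9+ for `is_relative_to`):

```python
import site
from pathlib import Path

def path_belongs_to_site_packages_sketch(file_path: Path) -> bool:
    # resolve() collapses symlinks and ".." segments on both sides, which is
    # exactly what these tests exercise.
    resolved_file = file_path.resolve()
    return any(
        resolved_file.is_relative_to(Path(p).resolve())
        for p in site.getsitepackages()
    )
```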
|
||||
|
||||
def test_path_belongs_to_site_packages_resolved_paths_normalization(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||
site_packages_dir = tmp_path / "lib" / "python3.9" / "site-packages"
|
||||
site_packages_dir.mkdir(parents=True)
|
||||
|
||||
package_dir = site_packages_dir / "mypackage"
|
||||
package_dir.mkdir()
|
||||
package_file = package_dir / "module.py"
|
||||
package_file.write_text("# module")
|
||||
|
||||
complex_site_packages_path = tmp_path / "lib" / "python3.9" / "other" / ".." / "site-packages" / "."
|
||||
monkeypatch.setattr(site, "getsitepackages", lambda: [str(complex_site_packages_path)])
|
||||
|
||||
assert path_belongs_to_site_packages(package_file) is True
|
||||
|
||||
complex_file_path = tmp_path / "lib" / "python3.9" / "site-packages" / "other" / ".." / "mypackage" / "module.py"
|
||||
assert path_belongs_to_site_packages(complex_file_path) is True
|
||||
|
||||
|
||||
# tests for is_class_defined_in_file
|
||||
def test_is_class_defined_in_file_with_existing_class(tmp_path: Path) -> None:
|
||||
test_file = tmp_path / "test_file.py"
|
||||
|
|
@@ -308,6 +368,86 @@ def my_function():
|
|||
assert is_class_defined_in_file("MyClass", test_file) is False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_code_context():
|
||||
"""Mock CodeOptimizationContext for testing extract_dependent_function."""
|
||||
from unittest.mock import MagicMock
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
|
||||
context = MagicMock(spec=CodeOptimizationContext)
|
||||
context.preexisting_objects = []
|
||||
return context
|
||||
|
||||
|
||||
def test_extract_dependent_function_sync_and_async(mock_code_context):
|
||||
"""Test extract_dependent_function with both sync and async functions."""
|
||||
# Test sync function extraction
|
||||
mock_code_context.testgen_context_code = """
|
||||
def main_function():
|
||||
pass
|
||||
|
||||
def helper_function():
|
||||
pass
|
||||
"""
|
||||
assert extract_dependent_function("main_function", mock_code_context) == "helper_function"
|
||||
|
||||
# Test async function extraction
|
||||
mock_code_context.testgen_context_code = """
|
||||
def main_function():
|
||||
pass
|
||||
|
||||
async def async_helper_function():
|
||||
pass
|
||||
"""
|
||||
assert extract_dependent_function("main_function", mock_code_context) == "async_helper_function"
|
||||
|
||||
|
||||
def test_extract_dependent_function_edge_cases(mock_code_context):
|
||||
"""Test extract_dependent_function edge cases."""
|
||||
# No dependent functions
|
||||
mock_code_context.testgen_context_code = """
|
||||
def main_function():
|
||||
pass
|
||||
"""
|
||||
assert extract_dependent_function("main_function", mock_code_context) is False
|
||||
|
||||
# Multiple dependent functions
|
||||
mock_code_context.testgen_context_code = """
|
||||
def main_function():
|
||||
pass
|
||||
|
||||
def helper1():
|
||||
pass
|
||||
|
||||
async def helper2():
|
||||
pass
|
||||
"""
|
||||
assert extract_dependent_function("main_function", mock_code_context) is False
|
||||
|
||||
|
||||
def test_extract_dependent_function_mixed_scenarios(mock_code_context):
|
||||
"""Test extract_dependent_function with mixed sync/async scenarios."""
|
||||
# Async main with sync helper
|
||||
mock_code_context.testgen_context_code = """
|
||||
async def async_main():
|
||||
pass
|
||||
|
||||
def sync_helper():
|
||||
pass
|
||||
"""
|
||||
assert extract_dependent_function("async_main", mock_code_context) == "sync_helper"
|
||||
|
||||
# Only async functions
|
||||
mock_code_context.testgen_context_code = """
|
||||
async def async_main():
|
||||
pass
|
||||
|
||||
async def async_helper():
|
||||
pass
|
||||
"""
|
||||
assert extract_dependent_function("async_main", mock_code_context) == "async_helper"
|
||||
|
||||
|
||||
def test_is_class_defined_in_file_with_non_existing_file() -> None:
|
||||
non_existing_file = Path("/non/existing/file.py")
|
||||
|
||||
|
|
@@ -445,25 +585,41 @@ def test_Grammar_copy():
|
|||
assert cleaned_code == expected_cleaned_code.strip()
|
||||
|
||||
|
||||
def test_has_any_async_functions_with_async_code() -> None:
|
||||
def test_validate_python_code_valid() -> None:
|
||||
code = "def hello():\n return 'world'"
|
||||
result = validate_python_code(code)
|
||||
assert result == code
|
||||
|
||||
|
||||
def test_validate_python_code_invalid() -> None:
|
||||
code = "def hello(:\n return 'world'"
|
||||
with pytest.raises(ValueError, match="Invalid Python code"):
|
||||
validate_python_code(code)
|
||||
|
||||
|
||||
def test_validate_python_code_empty() -> None:
|
||||
code = ""
|
||||
result = validate_python_code(code)
|
||||
assert result == code
|
||||
|
||||
|
||||
def test_validate_python_code_complex_invalid() -> None:
|
||||
code = "if True\n print('missing colon')"
|
||||
with pytest.raises(ValueError, match="Invalid Python code.*line 1.*column 8"):
|
||||
validate_python_code(code)
|
||||
|
||||
|
||||
def test_validate_python_code_valid_complex() -> None:
|
||||
code = """
|
||||
def normal_function():
|
||||
pass
|
||||
|
||||
async def async_function():
|
||||
pass
|
||||
def calculate(a, b):
|
||||
if a > b:
|
||||
return a + b
|
||||
else:
|
||||
return a * b
|
||||
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.value = 42
|
||||
"""
|
||||
result = has_any_async_functions(code)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_has_any_async_functions_without_async_code() -> None:
|
||||
code = """
|
||||
def normal_function():
|
||||
pass
|
||||
|
||||
def another_function():
|
||||
pass
|
||||
"""
|
||||
result = has_any_async_functions(code)
|
||||
assert result is False
|
||||
result = validate_python_code(code)
|
||||
assert result == code
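
# The tests above fully pin down both helpers: validate_python_code returns syntactically
# valid source unchanged (including the empty string) and raises ValueError carrying the
# parse location on bad input, while has_any_async_functions reports whether any `async
# def` appears anywhere. Minimal sketches consistent with those assertions (assumptions,
# not necessarily the project's actual implementations):
import ast

def validate_python_code_sketch(code: str) -> str:
    try:
        ast.parse(code)
    except SyntaxError as e:
        # Include line/column so messages match "Invalid Python code.*line 1.*column 8"
        raise ValueError(f"Invalid Python code: {e.msg} (line {e.lineno}, column {e.offset})") from e
    return code

def has_any_async_functions_sketch(code: str) -> bool:
    # ast.walk also visits methods nested inside classes
    return any(isinstance(node, ast.AsyncFunctionDef) for node in ast.walk(ast.parse(code)))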
@@ -42,7 +42,7 @@ from codeflash.verification.codeflash_capture import get_test_info_from_stack
class MyClass:
    def __init__(self):
        self.x = 2
        print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir!s}')}}|TEST_INFO_END")
        print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir.as_posix()}')}}|TEST_INFO_END")
"""
    test_file_name = "test_stack_info_temp.py"

@@ -54,7 +54,7 @@ class MyClass:
    with sample_code_path.open("w") as f:
        f.write(sample_code)
    result = execute_test_subprocess(
        cwd=test_dir, env={}, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"]
        cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
    )
    assert not result.stderr
    assert result.returncode == 0

@@ -117,7 +117,7 @@ from codeflash.verification.codeflash_capture import get_test_info_from_stack
class MyClass:
    def __init__(self):
        self.x = 2
        print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir!s}')}}|TEST_INFO_END")
        print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir.as_posix()}')}}|TEST_INFO_END")
"""
    test_file_name = "test_stack_info_temp.py"

@@ -129,7 +129,7 @@ class MyClass:
    with sample_code_path.open("w") as f:
        f.write(sample_code)
    result = execute_test_subprocess(
        cwd=test_dir, env={}, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"]
        cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
    )
    assert not result.stderr
    assert result.returncode == 0

@@ -181,7 +181,7 @@ from codeflash.verification.codeflash_capture import get_test_info_from_stack
class MyClass:
    def __init__(self):
        self.x = 2
        print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir!s}')}}|TEST_INFO_END")
        print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir.as_posix()}')}}|TEST_INFO_END")
"""
    test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
    test_file_name = "test_stack_info_temp.py"

@@ -194,7 +194,7 @@ class MyClass:
    with sample_code_path.open("w") as f:
        f.write(sample_code)
    result = execute_test_subprocess(
        cwd=test_dir, env={}, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"]
        cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
    )
    assert not result.stderr
    assert result.returncode == 0

@@ -261,7 +261,7 @@ class MyClass:
    def __init__(self):
        self.x = 2
        # Print out the detected test info each time we instantiate MyClass
        print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir!s}')}}|TEST_INFO_END")
        print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir.as_posix()}')}}|TEST_INFO_END")
"""

    test_file_name = "test_stack_info_recursive_temp.py"

@@ -279,7 +279,7 @@ class MyClass:

    # Run pytest as a subprocess
    result = execute_test_subprocess(
        cwd=test_dir, env={}, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"]
        cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
    )

    # Check for errors

@@ -343,7 +343,7 @@ from codeflash.verification.codeflash_capture import get_test_info_from_stack
class MyClass:
    def __init__(self):
        self.x = 2
        print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir!s}')}}|TEST_INFO_END")
        print(f"TEST_INFO_START|{{get_test_info_from_stack('{test_dir.as_posix()}')}}|TEST_INFO_END")
"""
    test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
    test_file_name = "test_stack_info_temp.py"

@@ -356,7 +356,7 @@ class MyClass:
    with sample_code_path.open("w") as f:
        f.write(sample_code)
    result = execute_test_subprocess(
        cwd=test_dir, env={}, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"]
        cwd=test_dir, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"], env=os.environ.copy()
    )
    assert not result.stderr
    assert result.returncode == 0

@@ -410,10 +410,11 @@ class TestUnittestExample(unittest.TestCase):
        self.assertTrue(True)
"""
    test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
    tmp_dir_path = get_run_tmp_file(Path("test_return_values"))
    sample_code = f"""
from codeflash.verification.codeflash_capture import codeflash_capture
class MyClass:
    @codeflash_capture(function_name="some_function", tmp_dir_path="{get_run_tmp_file(Path("test_return_values"))}", tests_root="{test_dir!s}")
    @codeflash_capture(function_name="some_function", tmp_dir_path="{tmp_dir_path.as_posix()}", tests_root="{test_dir.as_posix()}")
    def __init__(self, x=2):
        self.x = x
"""

@@ -528,6 +529,7 @@ class TestUnittestExample(unittest.TestCase):
        self.assertTrue(True)
"""
    test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
    tmp_dir_path = get_run_tmp_file(Path("test_return_values"))
    # MyClass did not have an init function, we created the init function with the codeflash_capture decorator using instrumentation
    sample_code = f"""
from codeflash.verification.codeflash_capture import codeflash_capture

@@ -536,7 +538,7 @@ class ParentClass:
        self.x = 2

class MyClass(ParentClass):
    @codeflash_capture(function_name="some_function", tmp_dir_path="{get_run_tmp_file(Path("test_return_values"))}", tests_root="{test_dir!s}")
    @codeflash_capture(function_name="some_function", tmp_dir_path="{tmp_dir_path.as_posix()}", tests_root="{test_dir.as_posix()}")
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
"""

@@ -648,14 +650,15 @@ def test_example_test():

"""
    test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
    tmp_dir_path = get_run_tmp_file(Path("test_return_values"))
    sample_code = f"""
from codeflash.verification.codeflash_capture import codeflash_capture

class MyClass:
    @codeflash_capture(
        function_name="some_function",
        tmp_dir_path="{get_run_tmp_file(Path("test_return_values"))}",
        tests_root="{test_dir!s}"
        tmp_dir_path="{tmp_dir_path.as_posix()}",
        tests_root="{test_dir.as_posix()}"
    )
    def __init__(self, x=2):
        self.x = x

@@ -765,13 +768,14 @@ def test_helper_classes():
    assert MyClass().target_function() == 6
"""
    test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
    tmp_dir_path = get_run_tmp_file(Path("test_return_values"))
    original_code = f"""
from codeflash.verification.codeflash_capture import codeflash_capture
from code_to_optimize.tests.pytest.helper_file_1 import HelperClass1
from code_to_optimize.tests.pytest.helper_file_2 import HelperClass2, AnotherHelperClass

class MyClass:
    @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))}', tests_root="{test_dir!s}" , is_fto=True)
    @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{tmp_dir_path.as_posix()}', tests_root="{test_dir.as_posix()}" , is_fto=True)
    def __init__(self):
        self.x = 1


@@ -785,7 +789,7 @@ class MyClass:
from codeflash.verification.codeflash_capture import codeflash_capture

class HelperClass1:
    @codeflash_capture(function_name='HelperClass1.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))}', tests_root="{test_dir!s}", is_fto=False)
    @codeflash_capture(function_name='HelperClass1.__init__', tmp_dir_path='{tmp_dir_path.as_posix()}', tests_root="{test_dir.as_posix()}", is_fto=False)
    def __init__(self):
        self.y = 1


@@ -797,7 +801,7 @@ class HelperClass1:
from codeflash.verification.codeflash_capture import codeflash_capture

class HelperClass2:
    @codeflash_capture(function_name='HelperClass2.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))}', tests_root="{test_dir!s}", is_fto=False)
    @codeflash_capture(function_name='HelperClass2.__init__', tmp_dir_path='{tmp_dir_path.as_posix()}', tests_root="{test_dir.as_posix()}", is_fto=False)
    def __init__(self):
        self.z = 2


@@ -805,7 +809,7 @@ class HelperClass2:
        return 2

class AnotherHelperClass:
    @codeflash_capture(function_name='AnotherHelperClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))}', tests_root="{test_dir!s}", is_fto=False)
    @codeflash_capture(function_name='AnotherHelperClass.__init__', tmp_dir_path='{tmp_dir_path.as_posix()}', tests_root="{test_dir.as_posix()}", is_fto=False)
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
@@ -787,6 +787,126 @@ def test_jax():
    assert not comparator(aa, cc)


def test_xarray():
    try:
        import xarray as xr
        import numpy as np
    except ImportError:
        pytest.skip()

    # Test basic DataArray
    a = xr.DataArray([1, 2, 3], dims=['x'])
    b = xr.DataArray([1, 2, 3], dims=['x'])
    c = xr.DataArray([1, 2, 4], dims=['x'])
    assert comparator(a, b)
    assert not comparator(a, c)

    # Test DataArray with coordinates
    d = xr.DataArray([1, 2, 3], coords={'x': [0, 1, 2]}, dims=['x'])
    e = xr.DataArray([1, 2, 3], coords={'x': [0, 1, 2]}, dims=['x'])
    f = xr.DataArray([1, 2, 3], coords={'x': [0, 1, 3]}, dims=['x'])
    assert comparator(d, e)
    assert not comparator(d, f)

    # Test DataArray with attributes
    g = xr.DataArray([1, 2, 3], dims=['x'], attrs={'units': 'meters'})
    h = xr.DataArray([1, 2, 3], dims=['x'], attrs={'units': 'meters'})
    i = xr.DataArray([1, 2, 3], dims=['x'], attrs={'units': 'feet'})
    assert comparator(g, h)
    assert not comparator(g, i)

    # Test 2D DataArray
    j = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=['x', 'y'])
    k = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=['x', 'y'])
    l = xr.DataArray([[1, 2, 3], [4, 5, 7]], dims=['x', 'y'])
    assert comparator(j, k)
    assert not comparator(j, l)

    # Test DataArray with different dimensions
    m = xr.DataArray([1, 2, 3], dims=['x'])
    n = xr.DataArray([1, 2, 3], dims=['y'])
    assert not comparator(m, n)

    # Test DataArray with NaN values
    o = xr.DataArray([1.0, np.nan, 3.0], dims=['x'])
    p = xr.DataArray([1.0, np.nan, 3.0], dims=['x'])
    q = xr.DataArray([1.0, 2.0, 3.0], dims=['x'])
    assert comparator(o, p)
    assert not comparator(o, q)

    # Test Dataset
    r = xr.Dataset({
        'temp': (['x', 'y'], [[1, 2], [3, 4]]),
        'pressure': (['x', 'y'], [[5, 6], [7, 8]])
    })
    s = xr.Dataset({
        'temp': (['x', 'y'], [[1, 2], [3, 4]]),
        'pressure': (['x', 'y'], [[5, 6], [7, 8]])
    })
    t = xr.Dataset({
        'temp': (['x', 'y'], [[1, 2], [3, 4]]),
        'pressure': (['x', 'y'], [[5, 6], [7, 9]])
    })
    assert comparator(r, s)
    assert not comparator(r, t)

    # Test Dataset with coordinates
    u = xr.Dataset({
        'temp': (['x', 'y'], [[1, 2], [3, 4]])
    }, coords={'x': [0, 1], 'y': [0, 1]})
    v = xr.Dataset({
        'temp': (['x', 'y'], [[1, 2], [3, 4]])
    }, coords={'x': [0, 1], 'y': [0, 1]})
    w = xr.Dataset({
        'temp': (['x', 'y'], [[1, 2], [3, 4]])
    }, coords={'x': [0, 2], 'y': [0, 1]})
    assert comparator(u, v)
    assert not comparator(u, w)

    # Test Dataset with attributes
    x = xr.Dataset({'temp': (['x'], [1, 2, 3])}, attrs={'source': 'sensor'})
    y = xr.Dataset({'temp': (['x'], [1, 2, 3])}, attrs={'source': 'sensor'})
    z = xr.Dataset({'temp': (['x'], [1, 2, 3])}, attrs={'source': 'model'})
    assert comparator(x, y)
    assert not comparator(x, z)

    # Test Dataset with different variables
    aa = xr.Dataset({'temp': (['x'], [1, 2, 3])})
    bb = xr.Dataset({'temp': (['x'], [1, 2, 3])})
    cc = xr.Dataset({'pressure': (['x'], [1, 2, 3])})
    assert comparator(aa, bb)
    assert not comparator(aa, cc)

    # Test empty Dataset
    dd = xr.Dataset()
    ee = xr.Dataset()
    assert comparator(dd, ee)

    # Test DataArray with different shapes
    ff = xr.DataArray([1, 2, 3], dims=['x'])
    gg = xr.DataArray([[1, 2, 3]], dims=['x', 'y'])
    assert not comparator(ff, gg)

    # Test DataArray with different data types
    # Note: xarray.identical() considers int and float arrays with same values as identical
    hh = xr.DataArray(np.array([1, 2, 3], dtype='int32'), dims=['x'])
    ii = xr.DataArray(np.array([1, 2, 3], dtype='int64'), dims=['x'])
    # xarray is permissive with dtype comparisons, treats these as identical
    assert comparator(hh, ii)

    # Test DataArray with infinity
    jj = xr.DataArray([1.0, np.inf, 3.0], dims=['x'])
    kk = xr.DataArray([1.0, np.inf, 3.0], dims=['x'])
    ll = xr.DataArray([1.0, -np.inf, 3.0], dims=['x'])
    assert comparator(jj, kk)
    assert not comparator(jj, ll)

    # Test Dataset vs DataArray (different types)
    mm = xr.DataArray([1, 2, 3], dims=['x'])
    nn = xr.Dataset({'data': (['x'], [1, 2, 3])})
    assert not comparator(mm, nn)
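
# Every xarray assertion above reduces to structural identity: values, dims, coords, and
# attrs must match, with NaN treated as equal to NaN. A plausible comparator branch for
# xarray objects (a minimal sketch assuming identical() semantics; not the project's
# actual code):
import xarray as xr

def _compare_xarray_sketch(a, b) -> bool:
    # A DataArray never equals a Dataset, even over the same underlying values
    if type(a) is not type(b):
        return False
    # identical() compares values, dims, coords, and attrs, and treats NaNs as equal
    return a.identical(b)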


def test_returns():
    a = Success(5)
    b = Success(5)
@@ -1502,4 +1622,147 @@ def test_collections() -> None:
    d = "hello"
    assert comparator(a, b)
    assert not comparator(a, c)
    assert not comparator(a, d)


def test_attrs():
    try:
        import attrs  # type: ignore
    except ImportError:
        pytest.skip()

    @attrs.define
    class Person:
        name: str
        age: int = 10

    a = Person("Alice", 25)
    b = Person("Alice", 25)
    c = Person("Bob", 25)
    d = Person("Alice", 30)
    assert comparator(a, b)
    assert not comparator(a, c)
    assert not comparator(a, d)

    @attrs.frozen
    class Point:
        x: int
        y: int

    p1 = Point(1, 2)
    p2 = Point(1, 2)
    p3 = Point(2, 3)
    assert comparator(p1, p2)
    assert not comparator(p1, p3)

    @attrs.define(slots=True)
    class Vehicle:
        brand: str
        model: str
        year: int = 2020

    v1 = Vehicle("Toyota", "Camry", 2021)
    v2 = Vehicle("Toyota", "Camry", 2021)
    v3 = Vehicle("Honda", "Civic", 2021)
    assert comparator(v1, v2)
    assert not comparator(v1, v3)

    @attrs.define
    class ComplexClass:
        public_field: str
        private_field: str = attrs.field(repr=False)
        non_eq_field: int = attrs.field(eq=False, default=0)
        computed: str = attrs.field(init=False, eq=True)

        def __attrs_post_init__(self):
            self.computed = f"{self.public_field}_{self.private_field}"

    c1 = ComplexClass("test", "secret")
    c2 = ComplexClass("test", "secret")
    c3 = ComplexClass("different", "secret")

    c1.non_eq_field = 100
    c2.non_eq_field = 200

    assert comparator(c1, c2)
    assert not comparator(c1, c3)

    @attrs.define
    class Address:
        street: str
        city: str

    @attrs.define
    class PersonWithAddress:
        name: str
        address: Address

    addr1 = Address("123 Main St", "Anytown")
    addr2 = Address("123 Main St", "Anytown")
    addr3 = Address("456 Oak Ave", "Anytown")

    person1 = PersonWithAddress("John", addr1)
    person2 = PersonWithAddress("John", addr2)
    person3 = PersonWithAddress("John", addr3)

    assert comparator(person1, person2)
    assert not comparator(person1, person3)

    @attrs.define
    class Container:
        items: list
        metadata: dict

    cont1 = Container([1, 2, 3], {"type": "numbers"})
    cont2 = Container([1, 2, 3], {"type": "numbers"})
    cont3 = Container([1, 2, 4], {"type": "numbers"})

    assert comparator(cont1, cont2)
    assert not comparator(cont1, cont3)

    @attrs.define
    class BaseClass:
        name: str
        value: int

    @attrs.define
    class ExtendedClass:
        name: str
        value: int
        extra_field: str = "default"

    base = BaseClass("test", 42)
    extended = ExtendedClass("test", 42, "extra")

    assert not comparator(base, extended)

    @attrs.define
    class WithNonEqFields:
        name: str
        timestamp: float = attrs.field(eq=False)  # Should be ignored
        debug_info: str = attrs.field(eq=False, default="debug")

    obj1 = WithNonEqFields("test", 1000.0, "info1")
    obj2 = WithNonEqFields("test", 9999.0, "info2")  # Different non-eq fields
    obj3 = WithNonEqFields("different", 1000.0, "info1")

    assert comparator(obj1, obj2)  # Should be equal despite different timestamp/debug_info
    assert not comparator(obj1, obj3)  # Should be different due to name

    @attrs.define
    class MinimalClass:
        name: str
        value: int

    @attrs.define
    class ExtendedClass:
        name: str
        value: int
        extra_field: str = "default"
        metadata: dict = attrs.field(factory=dict)
        timestamp: float = attrs.field(eq=False, default=0.0)  # This should be ignored

    minimal = MinimalClass("test", 42)
    extended = ExtendedClass("test", 42, "extra", {"key": "value"}, 1000.0)

    assert not comparator(minimal, extended)
@@ -14,7 +14,13 @@ from codeflash.models.models import (
    TestResults,
    TestType,
)
from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
from codeflash.result.critic import (
    coverage_critic,
    performance_gain,
    quantity_of_tests_critic,
    speedup_critic,
    throughput_gain,
)


def test_performance_gain() -> None:
@@ -429,3 +435,159 @@ def test_coverage_critic() -> None:
    )

    assert coverage_critic(unittest_coverage, "unittest") is True


def test_throughput_gain() -> None:
    """Test throughput_gain calculation."""
    # Test basic throughput improvement
    assert throughput_gain(original_throughput=100, optimized_throughput=150) == 0.5  # 50% improvement

    # Test no improvement
    assert throughput_gain(original_throughput=100, optimized_throughput=100) == 0.0

    # Test regression
    assert throughput_gain(original_throughput=100, optimized_throughput=80) == -0.2  # 20% regression

    # Test zero original throughput (edge case)
    assert throughput_gain(original_throughput=0, optimized_throughput=50) == 0.0

    # Test large improvement
    assert throughput_gain(original_throughput=50, optimized_throughput=200) == 3.0  # 300% improvement
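
# The expected values above imply plain relative change with a zero-baseline guard. A
# minimal sketch consistent with these assertions (assuming throughput_gain does nothing
# more than this):
def throughput_gain_sketch(*, original_throughput: int, optimized_throughput: int) -> float:
    if original_throughput == 0:
        # No baseline throughput means no measurable gain
        return 0.0
    # 100 -> 150 yields 0.5; 100 -> 80 yields -0.2
    return (optimized_throughput - original_throughput) / original_throughput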


def test_speedup_critic_with_async_throughput() -> None:
    """Test speedup_critic with async throughput evaluation."""
    original_code_runtime = 10000  # 10 microseconds
    original_async_throughput = 100

    # Test case 1: Both runtime and throughput improve significantly
    candidate_result = OptimizedCandidateResult(
        max_loop_count=5,
        best_test_runtime=8000,  # 20% runtime improvement
        behavior_test_results=TestResults(),
        benchmarking_test_results=TestResults(),
        optimization_candidate_index=0,
        total_candidate_timing=8000,
        async_throughput=120,  # 20% throughput improvement
    )

    assert speedup_critic(
        candidate_result=candidate_result,
        original_code_runtime=original_code_runtime,
        best_runtime_until_now=None,
        original_async_throughput=original_async_throughput,
        best_throughput_until_now=None,
        disable_gh_action_noise=True
    )

    # Test case 2: Runtime improves significantly, throughput doesn't meet threshold (should pass)
    candidate_result = OptimizedCandidateResult(
        max_loop_count=5,
        best_test_runtime=8000,  # 20% runtime improvement
        behavior_test_results=TestResults(),
        benchmarking_test_results=TestResults(),
        optimization_candidate_index=0,
        total_candidate_timing=8000,
        async_throughput=105,  # Only 5% throughput improvement (below 10% threshold)
    )

    assert speedup_critic(
        candidate_result=candidate_result,
        original_code_runtime=original_code_runtime,
        best_runtime_until_now=None,
        original_async_throughput=original_async_throughput,
        best_throughput_until_now=None,
        disable_gh_action_noise=True
    )

    # Test case 3: Throughput improves significantly, runtime doesn't meet threshold (should pass)
    candidate_result = OptimizedCandidateResult(
        max_loop_count=5,
        best_test_runtime=9800,  # Only 2% runtime improvement (below 5% threshold)
        behavior_test_results=TestResults(),
        benchmarking_test_results=TestResults(),
        optimization_candidate_index=0,
        total_candidate_timing=9800,
        async_throughput=120,  # 20% throughput improvement
    )

    assert speedup_critic(
        candidate_result=candidate_result,
        original_code_runtime=original_code_runtime,
        best_runtime_until_now=None,
        original_async_throughput=original_async_throughput,
        best_throughput_until_now=None,
        disable_gh_action_noise=True
    )

    # Test case 4: No throughput data - should fall back to runtime-only evaluation
    candidate_result = OptimizedCandidateResult(
        max_loop_count=5,
        best_test_runtime=8000,  # 20% runtime improvement
        behavior_test_results=TestResults(),
        benchmarking_test_results=TestResults(),
        optimization_candidate_index=0,
        total_candidate_timing=8000,
        async_throughput=None,  # No throughput data
    )

    assert speedup_critic(
        candidate_result=candidate_result,
        original_code_runtime=original_code_runtime,
        best_runtime_until_now=None,
        original_async_throughput=None,  # No original throughput data
        best_throughput_until_now=None,
        disable_gh_action_noise=True
    )

    # Test case 5: Test best_throughput_until_now comparison
    candidate_result = OptimizedCandidateResult(
        max_loop_count=5,
        best_test_runtime=8000,  # 20% runtime improvement
        behavior_test_results=TestResults(),
        benchmarking_test_results=TestResults(),
        optimization_candidate_index=0,
        total_candidate_timing=8000,
        async_throughput=115,  # 15% throughput improvement
    )

    # Should pass when no best throughput yet
    assert speedup_critic(
        candidate_result=candidate_result,
        original_code_runtime=original_code_runtime,
        best_runtime_until_now=None,
        original_async_throughput=original_async_throughput,
        best_throughput_until_now=None,
        disable_gh_action_noise=True
    )

    # Should fail when there's a better throughput already
    assert not speedup_critic(
        candidate_result=candidate_result,
        original_code_runtime=original_code_runtime,
        best_runtime_until_now=7000,  # Better runtime already exists
        original_async_throughput=original_async_throughput,
        best_throughput_until_now=120,  # Better throughput already exists
        disable_gh_action_noise=True
    )

    # Test case 6: Zero original throughput (edge case)
    candidate_result = OptimizedCandidateResult(
        max_loop_count=5,
        best_test_runtime=8000,  # 20% runtime improvement
        behavior_test_results=TestResults(),
        benchmarking_test_results=TestResults(),
        optimization_candidate_index=0,
        total_candidate_timing=8000,
        async_throughput=50,
    )

    # Should pass when original throughput is 0 (throughput evaluation skipped)
    assert speedup_critic(
        candidate_result=candidate_result,
        original_code_runtime=original_code_runtime,
        best_runtime_until_now=None,
        original_async_throughput=0,  # Zero original throughput
        best_throughput_until_now=None,
        disable_gh_action_noise=True
    )
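
# Taken together, the six cases above suggest an OR of two noise gates: a candidate is
# accepted if either its runtime or its async throughput clears a minimum improvement
# threshold, and it must still beat the best candidate seen so far on that axis. A sketch
# with hypothetical 5%/10% thresholds inferred from the inline comments (the real logic
# lives in codeflash.result.critic and may differ):
def speedup_critic_sketch(
    candidate_result,
    original_code_runtime: int,
    best_runtime_until_now: int | None,
    original_async_throughput: int | None = None,
    best_throughput_until_now: int | None = None,
    disable_gh_action_noise: bool = False,
) -> bool:
    min_runtime_gain = 0.05  # assumed runtime noise floor
    min_throughput_gain = 0.10  # assumed throughput noise floor

    runtime = candidate_result.best_test_runtime
    runtime_ok = (
        (original_code_runtime - runtime) / runtime > min_runtime_gain
        and (best_runtime_until_now is None or runtime < best_runtime_until_now)
    )

    throughput_ok = False
    if original_async_throughput and candidate_result.async_throughput is not None:
        gain = (candidate_result.async_throughput - original_async_throughput) / original_async_throughput
        throughput_ok = gain > min_throughput_gain and (
            best_throughput_until_now is None or candidate_result.async_throughput > best_throughput_until_now
        )

    return runtime_ok or throughput_ok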
@@ -14,6 +14,11 @@ from codeflash.models.models import CodeString, CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig

@pytest.fixture
def temp_dir():
    with tempfile.TemporaryDirectory() as tmpdirname:
        yield Path(tmpdirname)

def test_remove_duplicate_imports():
    """Test that duplicate imports are removed when should_sort_imports is True."""
    original_code = "import os\nimport os\n"

@@ -37,17 +42,15 @@ def test_sorting_imports():
    assert new_code == "import os\nimport sys\nimport unittest\n"


def test_sort_imports_without_formatting():
def test_sort_imports_without_formatting(temp_dir):
    """Test that imports are sorted when formatting is disabled and should_sort_imports is True."""
    with tempfile.NamedTemporaryFile() as tmp:
        tmp.write(b"import sys\nimport unittest\nimport os\n")
        tmp.flush()
        tmp_path = Path(tmp.name)
    temp_file = temp_dir / "test_file.py"
    temp_file.write_text("import sys\nimport unittest\nimport os\n")

        new_code = format_code(formatter_cmds=["disabled"], path=tmp_path)
        assert new_code is not None
        new_code = sort_imports(new_code)
        assert new_code == "import os\nimport sys\nimport unittest\n"
    new_code = format_code(formatter_cmds=["disabled"], path=temp_file)
    assert new_code is not None
    new_code = sort_imports(new_code)
    assert new_code == "import os\nimport sys\nimport unittest\n"


def test_dedup_and_sort_imports_deduplicates():
@@ -101,7 +104,7 @@ def foo():
    assert actual == expected


def test_formatter_cmds_non_existent():
def test_formatter_cmds_non_existent(temp_dir):
    """Test that default formatter-cmds is used when it doesn't exist in the toml."""
    config_data = """
[tool.codeflash]
@@ -110,113 +113,99 @@ tests-root = "tests"
test-framework = "pytest"
ignore-paths = []
"""
    config_file = temp_dir / "config.toml"
    config_file.write_text(config_data)

    with tempfile.NamedTemporaryFile(suffix=".toml", delete=False) as tmp:
        tmp.write(config_data.encode())
        tmp.flush()
        tmp_path = Path(tmp.name)

    try:
        config, _ = parse_config_file(tmp_path)
        assert config["formatter_cmds"] == ["black $file"]
    finally:
        os.remove(tmp_path)
    config, _ = parse_config_file(config_file)
    assert config["formatter_cmds"] == ["black $file"]

    try:
        import black
    except ImportError:
        pytest.skip("black is not installed")

    original_code = b"""
import os
import sys
def foo():
    return os.path.join(sys.path[0], 'bar')"""
    expected = """import os
import sys


def foo():
    return os.path.join(sys.path[0], "bar")
"""
    with tempfile.NamedTemporaryFile() as tmp:
        tmp.write(original_code)
        tmp.flush()
        tmp_path = tmp.name

        actual = format_code(formatter_cmds=["black $file"], path=Path(tmp_path))
        assert actual == expected


def test_formatter_black():
    try:
        import black
    except ImportError:
        pytest.skip("black is not installed")
    original_code = b"""
import os
import sys
def foo():
    return os.path.join(sys.path[0], 'bar')"""
    expected = """import os
import sys


def foo():
    return os.path.join(sys.path[0], "bar")
"""
    with tempfile.NamedTemporaryFile() as tmp:
        tmp.write(original_code)
        tmp.flush()
        tmp_path = tmp.name

        actual = format_code(formatter_cmds=["black $file"], path=Path(tmp_path))
        assert actual == expected


def test_formatter_ruff():
    try:
        import ruff  # type: ignore
    except ImportError:
        pytest.skip("ruff is not installed")
    original_code = b"""
import os
import sys
def foo():
    return os.path.join(sys.path[0], 'bar')"""
    expected = """import os
import sys


def foo():
    return os.path.join(sys.path[0], "bar")
"""
    with tempfile.NamedTemporaryFile(suffix=".py") as tmp:
        tmp.write(original_code)
        tmp.flush()
        tmp_path = tmp.name

        actual = format_code(
            formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"], path=Path(tmp_path)
        )
        assert actual == expected


def test_formatter_error():
    original_code = """
import os
import sys
def foo():
    return os.path.join(sys.path[0], 'bar')"""
    with tempfile.NamedTemporaryFile("w") as tmp:
        tmp.write(original_code)
        tmp.flush()
        tmp_path = tmp.name
        try:
            new_code = format_code(formatter_cmds=["exit 1"], path=Path(tmp_path), exit_on_failure=False)
            assert new_code == original_code
        except Exception as e:
            assert False, f"Shouldn't throw an exception even if the formatter is not found: {e}"
    expected = """import os
import sys


def foo():
    return os.path.join(sys.path[0], \"bar\")
"""
    temp_file = temp_dir / "test_file.py"
    temp_file.write_text(original_code)

    actual = format_code(formatter_cmds=["black $file"], path=temp_file)
    assert actual == expected


def test_formatter_black(temp_dir):
    try:
        import black
    except ImportError:
        pytest.skip("black is not installed")
    original_code = """
import os
import sys
def foo():
    return os.path.join(sys.path[0], 'bar')"""
    expected = """import os
import sys


def foo():
    return os.path.join(sys.path[0], \"bar\")
"""
    temp_file = temp_dir / "test_file.py"
    temp_file.write_text(original_code)

    actual = format_code(formatter_cmds=["black $file"], path=temp_file)
    assert actual == expected


def test_formatter_ruff(temp_dir):
    try:
        import ruff  # type: ignore
    except ImportError:
        pytest.skip("ruff is not installed")
    original_code = """
import os
import sys
def foo():
    return os.path.join(sys.path[0], 'bar')"""
    expected = """import os
import sys


def foo():
    return os.path.join(sys.path[0], \"bar\")
"""
    temp_file = temp_dir / "test_file.py"
    temp_file.write_text(original_code)

    actual = format_code(
        formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"], path=temp_file
    )
    assert actual == expected


def test_formatter_error(tmp_path: Path):
    original_code = """
import os
import sys
def foo():
    return os.path.join(sys.path[0], 'bar')"""
    temp_file = tmp_path / "test_formatter_error.py"
    temp_file.write_text(original_code, encoding="utf-8")
    try:
        new_code = format_code(formatter_cmds=["exit 1"], path=temp_file, exit_on_failure=False)
        assert new_code == original_code
    except Exception as e:
        assert False, f"Shouldn't throw an exception even if the formatter is not found: {e}"


def _run_formatting_test(source_code: str, should_content_change: bool, expected = None, optimized_function: str = ""):
@@ -21,11 +21,15 @@ def test_function_eligible_for_optimization() -> None:
        return a**2
    """
    functions_found = {}
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
        f.write(function)
        f.flush()
        functions_found = find_all_functions_in_file(Path(f.name))
    assert functions_found[Path(f.name)][0].function_name == "test_function_eligible_for_optimization"
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir_path = Path(temp_dir)
        file_path = temp_dir_path / "test_function.py"

        with file_path.open("w") as f:
            f.write(function)

        functions_found = find_all_functions_in_file(file_path)
        assert functions_found[file_path][0].function_name == "test_function_eligible_for_optimization"

    # Has no return statement
    function = """def test_function_not_eligible_for_optimization():

@@ -33,28 +37,40 @@ def test_function_eligible_for_optimization() -> None:
        print(a)
    """
    functions_found = {}
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
        f.write(function)
        f.flush()
        functions_found = find_all_functions_in_file(Path(f.name))
    assert len(functions_found[Path(f.name)]) == 0
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir_path = Path(temp_dir)
        file_path = temp_dir_path / "test_function.py"

        with file_path.open("w") as f:
            f.write(function)

        functions_found = find_all_functions_in_file(file_path)
        assert len(functions_found[file_path]) == 0


    # we want to trigger an error in the function discovery
    function = """def test_invalid_code():"""
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
        f.write(function)
        f.flush()
        functions_found = find_all_functions_in_file(Path(f.name))
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir_path = Path(temp_dir)
        file_path = temp_dir_path / "test_function.py"

        with file_path.open("w") as f:
            f.write(function)

        functions_found = find_all_functions_in_file(file_path)
        assert functions_found == {}



def test_find_top_level_function_or_method():
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
        f.write(
            """def functionA():
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir_path = Path(temp_dir)
        file_path = temp_dir_path / "test_function.py"

        with file_path.open("w") as f:
            f.write(
                """def functionA():
    def functionB():
        return 5
class E:

@@ -76,42 +92,48 @@ class AirbyteEntrypoint(object):
    def non_classmethod_function(cls, name):
        return cls.name
"""
        )
        f.flush()
        path_obj_name = Path(f.name)
        assert inspect_top_level_functions_or_methods(path_obj_name, "functionA").is_top_level
        assert not inspect_top_level_functions_or_methods(path_obj_name, "functionB").is_top_level
        assert inspect_top_level_functions_or_methods(path_obj_name, "functionC", class_name="A").is_top_level
        assert not inspect_top_level_functions_or_methods(path_obj_name, "functionD", class_name="A").is_top_level
        assert not inspect_top_level_functions_or_methods(path_obj_name, "functionF", class_name="E").is_top_level
        assert not inspect_top_level_functions_or_methods(path_obj_name, "functionA").has_args
            )

        assert inspect_top_level_functions_or_methods(file_path, "functionA").is_top_level
        assert not inspect_top_level_functions_or_methods(file_path, "functionB").is_top_level
        assert inspect_top_level_functions_or_methods(file_path, "functionC", class_name="A").is_top_level
        assert not inspect_top_level_functions_or_methods(file_path, "functionD", class_name="A").is_top_level
        assert not inspect_top_level_functions_or_methods(file_path, "functionF", class_name="E").is_top_level
        assert not inspect_top_level_functions_or_methods(file_path, "functionA").has_args
        staticmethod_func = inspect_top_level_functions_or_methods(
            path_obj_name, "handle_record_counts", class_name=None, line_no=15
            file_path, "handle_record_counts", class_name=None, line_no=15
        )
        assert staticmethod_func.is_staticmethod
        assert staticmethod_func.staticmethod_class_name == "AirbyteEntrypoint"
        assert inspect_top_level_functions_or_methods(
            path_obj_name, "functionE", class_name="AirbyteEntrypoint"
            file_path, "functionE", class_name="AirbyteEntrypoint"
        ).is_classmethod
        assert not inspect_top_level_functions_or_methods(
            path_obj_name, "non_classmethod_function", class_name="AirbyteEntrypoint"
            file_path, "non_classmethod_function", class_name="AirbyteEntrypoint"
        ).is_top_level
        # needed because this will be traced with a class_name being passed

    # we want to write invalid code to ensure that the function discovery does not crash
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
        f.write(
            """def functionA():
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir_path = Path(temp_dir)
        file_path = temp_dir_path / "test_function.py"

        with file_path.open("w") as f:
            f.write(
                """def functionA():
    """
            )
        f.flush()
        path_obj_name = Path(f.name)
        assert not inspect_top_level_functions_or_methods(path_obj_name, "functionA")
            )

        assert not inspect_top_level_functions_or_methods(file_path, "functionA")

def test_class_method_discovery():
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
        f.write(
            """class A:
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir_path = Path(temp_dir)
        file_path = temp_dir_path / "test_function.py"

        with file_path.open("w") as f:
            f.write(
                """class A:
    def functionA():
        return True
    def functionB():
@@ -123,21 +145,20 @@ class X:
        return False
    def functionA():
        return True"""
        )
        f.flush()
            )

    test_config = TestConfig(
        tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
    )
    path_obj_name = Path(f.name)
    functions, functions_count, _ = get_functions_to_optimize(
        optimize_all=None,
        replay_test=None,
        file=path_obj_name,
        file=file_path,
        only_get_this_function="A.functionA",
        test_cfg=test_config,
        ignore_paths=[Path("/bruh/")],
        project_root=path_obj_name.parent,
        module_root=path_obj_name.parent,
        project_root=file_path.parent,
        module_root=file_path.parent,
    )
    assert len(functions) == 1
    for file in functions:

@@ -148,12 +169,12 @@ def functionA():
    functions, functions_count, _ = get_functions_to_optimize(
        optimize_all=None,
        replay_test=None,
        file=path_obj_name,
        file=file_path,
        only_get_this_function="X.functionA",
        test_cfg=test_config,
        ignore_paths=[Path("/bruh/")],
        project_root=path_obj_name.parent,
        module_root=path_obj_name.parent,
        project_root=file_path.parent,
        module_root=file_path.parent,
    )
    assert len(functions) == 1
    for file in functions:

@@ -164,12 +185,12 @@ def functionA():
    functions, functions_count, _ = get_functions_to_optimize(
        optimize_all=None,
        replay_test=None,
        file=path_obj_name,
        file=file_path,
        only_get_this_function="functionA",
        test_cfg=test_config,
        ignore_paths=[Path("/bruh/")],
        project_root=path_obj_name.parent,
        module_root=path_obj_name.parent,
        project_root=file_path.parent,
        module_root=file_path.parent,
    )
    assert len(functions) == 1
    for file in functions:

@@ -178,8 +199,12 @@ def functionA():


def test_nested_function():
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
        f.write(
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir_path = Path(temp_dir)
        file_path = temp_dir_path / "test_function.py"

        with file_path.open("w") as f:
            f.write(
                """
import copy


@@ -223,28 +248,31 @@ def propagate_attributes(
    traverse(source_node_id)
    return modified_nodes
"""
        )
        f.flush()
            )

    test_config = TestConfig(
        tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
    )
    path_obj_name = Path(f.name)
    functions, functions_count, _ = get_functions_to_optimize(
        optimize_all=None,
        replay_test=None,
        file=path_obj_name,
        file=file_path,
        test_cfg=test_config,
        only_get_this_function=None,
        ignore_paths=[Path("/bruh/")],
        project_root=path_obj_name.parent,
        module_root=path_obj_name.parent,
        project_root=file_path.parent,
        module_root=file_path.parent,
    )

    assert len(functions) == 1
    assert functions_count == 1

    with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
        f.write(
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir_path = Path(temp_dir)
        file_path = temp_dir_path / "test_function.py"

        with file_path.open("w") as f:
            f.write(
                """
def outer_function():
    def inner_function():

@@ -252,28 +280,31 @@ def outer_function():

    return inner_function
"""
        )
        f.flush()
            )

    test_config = TestConfig(
        tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
    )
    path_obj_name = Path(f.name)
    functions, functions_count, _ = get_functions_to_optimize(
        optimize_all=None,
        replay_test=None,
        file=path_obj_name,
        file=file_path,
        test_cfg=test_config,
        only_get_this_function=None,
        ignore_paths=[Path("/bruh/")],
        project_root=path_obj_name.parent,
        module_root=path_obj_name.parent,
        project_root=file_path.parent,
        module_root=file_path.parent,
    )

    assert len(functions) == 1
    assert functions_count == 1

    with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
        f.write(
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir_path = Path(temp_dir)
        file_path = temp_dir_path / "test_function.py"

        with file_path.open("w") as f:
            f.write(
                """
def outer_function():
    def inner_function():

@@ -283,21 +314,20 @@ def outer_function():
        pass
    return inner_function, another_inner_function
"""
        )
        f.flush()
            )

    test_config = TestConfig(
        tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
    )
    path_obj_name = Path(f.name)
    functions, functions_count, _ = get_functions_to_optimize(
        optimize_all=None,
        replay_test=None,
        file=path_obj_name,
        file=file_path,
        test_cfg=test_config,
        only_get_this_function=None,
        ignore_paths=[Path("/bruh/")],
        project_root=path_obj_name.parent,
        module_root=path_obj_name.parent,
        project_root=file_path.parent,
        module_root=file_path.parent,
    )

    assert len(functions) == 1
@@ -3,13 +3,20 @@ import tempfile
from codeflash.code_utils.code_extractor import get_code
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
import pytest
from pathlib import Path

@pytest.fixture
def temp_dir():
    with tempfile.TemporaryDirectory() as tmpdirname:
        yield Path(tmpdirname)


def test_get_code_function() -> None:
def test_get_code_function(temp_dir: Path) -> None:
    code = """def test(self):
        return self._test"""

    with tempfile.NamedTemporaryFile("w") as f:
    with (temp_dir / "temp_file.py").open(mode="w") as f:
        f.write(code)
        f.flush()

@@ -18,14 +25,14 @@ def test_get_code_function() -> None:
    assert contextual_dunder_methods == set()


def test_get_code_property() -> None:
def test_get_code_property(temp_dir: Path) -> None:
    code = """class TestClass:
    def __init__(self):
        self._test = 5
    @property
    def test(self):
        return self._test"""
    with tempfile.NamedTemporaryFile("w") as f:
    with (temp_dir / "temp_file.py").open(mode="w") as f:
        f.write(code)
        f.flush()

@@ -36,7 +43,7 @@ def test_get_code_property() -> None:
    assert contextual_dunder_methods == {("TestClass", "__init__")}


def test_get_code_class() -> None:
def test_get_code_class(temp_dir: Path) -> None:
    code = """
class TestClass:
    def __init__(self):

@@ -54,7 +61,7 @@ class TestClass:
    @property
    def test(self):
        return self._test"""
    with tempfile.NamedTemporaryFile("w") as f:
    with (temp_dir / "temp_file.py").open(mode="w") as f:
        f.write(code)
        f.flush()

@@ -65,7 +72,7 @@ class TestClass:
    assert contextual_dunder_methods == {("TestClass", "__init__")}


def test_get_code_bubble_sort_class() -> None:
def test_get_code_bubble_sort_class(temp_dir: Path) -> None:
    code = """
def hi():
    pass

@@ -105,7 +112,7 @@ class BubbleSortClass:
            arr[j + 1] = temp
        return arr
"""
    with tempfile.NamedTemporaryFile("w") as f:
    with (temp_dir / "temp_file.py").open(mode="w") as f:
        f.write(code)
        f.flush()

@@ -116,7 +123,7 @@ class BubbleSortClass:
    assert contextual_dunder_methods == {("BubbleSortClass", "__init__"), ("BubbleSortClass", "__call__")}


def test_get_code_indent() -> None:
def test_get_code_indent(temp_dir: Path) -> None:
    code = """def hi():
    pass

@@ -168,7 +175,7 @@ def non():
    def helper(self, arr, j):
        return arr[j] > arr[j + 1]
"""
    with tempfile.NamedTemporaryFile("w") as f:
    with (temp_dir / "temp_file.py").open(mode="w") as f:
        f.write(code)
        f.flush()
    new_code, contextual_dunder_methods = get_code(

@@ -198,7 +205,7 @@ def non():
    def unsorter(self, arr):
        return shuffle(arr)
"""
    with tempfile.NamedTemporaryFile("w") as f:
    with (temp_dir / "temp_file.py").open(mode="w") as f:
        f.write(code)
        f.flush()
    new_code, contextual_dunder_methods = get_code(

@@ -212,7 +219,7 @@ def non():
    assert contextual_dunder_methods == {("BubbleSortClass", "__init__"), ("BubbleSortClass", "__call__")}


def test_get_code_multiline_class_def() -> None:
def test_get_code_multiline_class_def(temp_dir: Path) -> None:
    code = """class StatementAssignmentVariableConstantMutable(
    StatementAssignmentVariableMixin, StatementAssignmentVariableConstantMutableBase
):

@@ -235,7 +242,7 @@ def test_get_code_multiline_class_def() -> None:
    def computeStatement(self, trace_collection):
        return self, None, None
"""
    with tempfile.NamedTemporaryFile("w") as f:
    with (temp_dir / "temp_file.py").open(mode="w") as f:
        f.write(code)
        f.flush()

@@ -252,13 +259,13 @@ def test_get_code_multiline_class_def() -> None:
    assert contextual_dunder_methods == set()


def test_get_code_dataclass_attribute():
def test_get_code_dataclass_attribute(temp_dir: Path) -> None:
    code = """@dataclass
class CustomDataClass:
    name: str = ""
    data: List[int] = field(default_factory=list)"""

    with tempfile.NamedTemporaryFile("w") as f:
    with (temp_dir / "temp_file.py").open(mode="w") as f:
        f.write(code)
        f.flush()

@@ -269,4 +276,4 @@ class CustomDataClass:
        [FunctionToOptimize("name", f.name, [FunctionParent("CustomDataClass", "ClassDef")])]
    )
    assert new_code is None
    assert contextual_dunder_methods == set()
@@ -213,11 +213,12 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
        lifespan=self.__duration__,
    )
'''
    with tempfile.NamedTemporaryFile(mode="w") as f:
        f.write(code)
        f.flush()
        file_path = Path(f.name).resolve()
        project_root_path = file_path.parent.resolve()
    with tempfile.TemporaryDirectory() as tempdir:
        tempdir_path = Path(tempdir)
        file_path = (tempdir_path / "typed_code_helper.py").resolve()
        file_path.write_text(code, encoding="utf-8")
        project_root_path = tempdir_path.resolve()
        function_to_optimize = FunctionToOptimize(
            function_name="__call__",
            file_path=file_path,

@@ -440,4 +441,4 @@ def sorter_deps(arr):
        code_context.helper_functions[0].fully_qualified_name
        == "code_to_optimize.bubble_sort_dep1_helper.dep1_comparer"
    )
    assert code_context.helper_functions[1].fully_qualified_name == "code_to_optimize.bubble_sort_dep2_swap.dep2_swap"
@@ -123,7 +123,7 @@ def test_sort():
    assert new_test is not None
    assert new_test.replace('"', "'") == expected.format(
        module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp",
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
    ).replace('"', "'")

    with test_path.open("w") as f:

@@ -276,16 +276,16 @@ def test_sort():
    fto = FunctionToOptimize(
        function_name="sorter", parents=[FunctionParent(name="BubbleSorter", type="ClassDef")], file_path=Path(fto_path)
    )
    with tempfile.NamedTemporaryFile(mode="w") as f:
        f.write(code)
        f.flush()
    with tempfile.TemporaryDirectory() as tmpdirname:
        tmp_test_path = Path(tmpdirname) / "test_class_method_behavior_results_temp.py"
        tmp_test_path.write_text(code, encoding="utf-8")

        success, new_test = inject_profiling_into_existing_test(
            Path(f.name), [CodePosition(7, 13), CodePosition(12, 13)], fto, Path(f.name).parent, "pytest"
            tmp_test_path, [CodePosition(7, 13), CodePosition(12, 13)], fto, tmp_test_path.parent, "pytest"
        )
        assert success
        assert new_test.replace('"', "'") == expected.format(
            module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
            module_path=tmp_test_path.stem, tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
        ).replace('"', "'")
    tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve()
    test_path = tests_root / "test_class_method_behavior_results_temp.py"

@@ -295,7 +295,7 @@ def test_sort():
    try:
        new_test = expected.format(
            module_path="code_to_optimize.tests.pytest.test_class_method_behavior_results_temp",
            tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
            tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
        )

        with test_path.open("w") as f:

@@ -486,4 +486,4 @@ class BubbleSorter:
    finally:
        fto_path.write_text(original_code, "utf-8")
        test_path.unlink(missing_ok=True)
        test_path_perf.unlink(missing_ok=True)
807 tests/test_instrument_async_tests.py Normal file
@@ -0,0 +1,807 @@
import tempfile
from pathlib import Path
import uuid
import os
import sys

import pytest

from codeflash.code_utils.instrument_existing_tests import (
    add_async_decorator_to_function,
    inject_profiling_into_existing_test,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, TestingMode


@pytest.fixture
def temp_dir():
    """Create a temporary directory for test files."""
    with tempfile.TemporaryDirectory() as temp:
        yield Path(temp)


# @pytest.fixture
# def unique_test_iteration():
#     """Provide a unique test iteration ID and clean up database after test."""
#     # Generate unique iteration ID
#     iteration_id = str(uuid.uuid4())[:8]

#     # Store original environment variable
#     original_iteration = os.environ.get("CODEFLASH_TEST_ITERATION")

#     # Set unique iteration for this test
#     os.environ["CODEFLASH_TEST_ITERATION"] = iteration_id

#     try:
#         yield iteration_id
#     finally:
#         # Cleanup: restore original environment and delete database file
#         if original_iteration is not None:
#             os.environ["CODEFLASH_TEST_ITERATION"] = original_iteration
#         elif "CODEFLASH_TEST_ITERATION" in os.environ:
#             del os.environ["CODEFLASH_TEST_ITERATION"]

#         # Clean up database file
#         try:
#             from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file

#             db_path = get_run_tmp_file(Path(f"test_return_values_{iteration_id}.sqlite"))
#             if db_path.exists():
#                 db_path.unlink()
#         except Exception:
#             pass  # Ignore cleanup errors


@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_decorator_application_behavior_mode():
    async_function_code = '''
import asyncio

async def async_function(x: int, y: int) -> int:
    """Simple async function for testing."""
    await asyncio.sleep(0.01)
    return x * y
'''

    expected_decorated_code = '''
import asyncio

from codeflash.code_utils.codeflash_wrap_decorator import \\
    codeflash_behavior_async


@codeflash_behavior_async
async def async_function(x: int, y: int) -> int:
    """Simple async function for testing."""
    await asyncio.sleep(0.01)
    return x * y
'''

    func = FunctionToOptimize(
        function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
    )

    modified_code, decorator_added = add_async_decorator_to_function(async_function_code, func, TestingMode.BEHAVIOR)

    assert decorator_added
    assert modified_code.strip() == expected_decorated_code.strip()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
def test_async_decorator_application_performance_mode():
|
||||
async_function_code = '''
|
||||
import asyncio
|
||||
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
expected_decorated_code = '''
|
||||
import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_performance_async
|
||||
|
||||
|
||||
@codeflash_performance_async
|
||||
async def async_function(x: int, y: int) -> int:
|
||||
"""Simple async function for testing."""
|
||||
await asyncio.sleep(0.01)
|
||||
return x * y
|
||||
'''
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
|
||||
)
|
||||
|
||||
modified_code, decorator_added = add_async_decorator_to_function(async_function_code, func, TestingMode.PERFORMANCE)
|
||||
|
||||
assert decorator_added
|
||||
assert modified_code.strip() == expected_decorated_code.strip()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
def test_async_class_method_decorator_application():
|
||||
async_class_code = '''
|
||||
import asyncio
|
||||
|
||||
class Calculator:
|
||||
"""Test class with async methods."""
|
||||
|
||||
async def async_method(self, a: int, b: int) -> int:
|
||||
"""Async method in class."""
|
||||
await asyncio.sleep(0.005)
|
||||
return a ** b
|
||||
|
||||
def sync_method(self, a: int, b: int) -> int:
|
||||
"""Sync method in class."""
|
||||
return a - b
|
||||
'''
|
||||
|
||||
expected_decorated_code = '''
|
||||
import asyncio
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import \\
|
||||
codeflash_behavior_async
|
||||
|
||||
|
||||
class Calculator:
|
||||
"""Test class with async methods."""
|
||||
|
||||
@codeflash_behavior_async
|
||||
async def async_method(self, a: int, b: int) -> int:
|
||||
"""Async method in class."""
|
||||
await asyncio.sleep(0.005)
|
||||
return a ** b
|
||||
|
||||
def sync_method(self, a: int, b: int) -> int:
|
||||
"""Sync method in class."""
|
||||
return a - b
|
||||
'''
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="async_method",
|
||||
file_path=Path("test_async.py"),
|
||||
parents=[{"name": "Calculator", "type": "ClassDef"}],
|
||||
is_async=True,
|
||||
)
|
||||
|
||||
modified_code, decorator_added = add_async_decorator_to_function(async_class_code, func, TestingMode.BEHAVIOR)
|
||||
|
||||
assert decorator_added
|
||||
assert modified_code.strip() == expected_decorated_code.strip()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_decorator_no_duplicate_application():
    already_decorated_code = '''
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async
import asyncio

@codeflash_behavior_async
async def async_function(x: int, y: int) -> int:
    """Already decorated async function."""
    await asyncio.sleep(0.01)
    return x * y
'''

    expected_reformatted_code = '''
import asyncio

from codeflash.code_utils.codeflash_wrap_decorator import \\
    codeflash_behavior_async


@codeflash_behavior_async
async def async_function(x: int, y: int) -> int:
    """Already decorated async function."""
    await asyncio.sleep(0.01)
    return x * y
'''

    func = FunctionToOptimize(
        function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
    )

    modified_code, decorator_added = add_async_decorator_to_function(already_decorated_code, func, TestingMode.BEHAVIOR)

    assert not decorator_added
    assert modified_code.strip() == expected_reformatted_code.strip()


@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_inject_profiling_async_function_behavior_mode(temp_dir):
    source_module_code = '''
import asyncio

async def async_function(x: int, y: int) -> int:
    """Simple async function for testing."""
    await asyncio.sleep(0.01)
    return x * y
'''

    source_file = temp_dir / "my_module.py"
    source_file.write_text(source_module_code)

    async_test_code = '''
import asyncio
import pytest
from my_module import async_function

@pytest.mark.asyncio
async def test_async_function():
    """Test async function behavior."""
    result = await async_function(5, 3)
    assert result == 15

    result2 = await async_function(2, 4)
    assert result2 == 8
'''

    test_file = temp_dir / "test_async.py"
    test_file.write_text(async_test_code)

    func = FunctionToOptimize(function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True)

    # First instrument the source module
    from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators

    source_success, instrumented_source = instrument_source_module_with_async_decorators(
        source_file, func, TestingMode.BEHAVIOR
    )

    assert source_success is True
    assert instrumented_source is not None
    assert "@codeflash_behavior_async" in instrumented_source
    assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
    assert "codeflash_behavior_async" in instrumented_source

    source_file.write_text(instrumented_source)

    success, instrumented_test_code = inject_profiling_into_existing_test(
        test_file, [CodePosition(8, 18), CodePosition(11, 19)], func, temp_dir, "pytest", mode=TestingMode.BEHAVIOR
    )

    # For async functions, once source is decorated, test injection should fail
    # This is expected behavior - async instrumentation happens at the decorator level
    assert success is False
    assert instrumented_test_code is None
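# Note: for async targets the test file is intentionally left untouched here; the
# behavior/performance decorators applied to the source module do the capturing, so
# inject_profiling_into_existing_test reports failure rather than rewriting the test.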
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_inject_profiling_async_function_performance_mode(temp_dir):
    source_module_code = '''
import asyncio

async def async_function(x: int, y: int) -> int:
    """Simple async function for testing."""
    await asyncio.sleep(0.01)
    return x * y
'''

    source_file = temp_dir / "my_module.py"
    source_file.write_text(source_module_code)

    # Create the test file
    async_test_code = '''
import asyncio
import pytest
from my_module import async_function

@pytest.mark.asyncio
async def test_async_function():
    """Test async function performance."""
    result = await async_function(5, 3)
    assert result == 15
'''

    test_file = temp_dir / "test_async.py"
    test_file.write_text(async_test_code)

    func = FunctionToOptimize(function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True)

    # First instrument the source module
    from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators

    source_success, instrumented_source = instrument_source_module_with_async_decorators(
        source_file, func, TestingMode.PERFORMANCE
    )

    assert source_success is True
    assert instrumented_source is not None
    assert "@codeflash_performance_async" in instrumented_source
    # Check for the import with line continuation formatting
    assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
    assert "codeflash_performance_async" in instrumented_source

    # Write the instrumented source back
    source_file.write_text(instrumented_source)

    # Now test the full pipeline with source module path
    success, instrumented_test_code = inject_profiling_into_existing_test(
        test_file, [CodePosition(8, 18)], func, temp_dir, "pytest", mode=TestingMode.PERFORMANCE
    )

    # For async functions, once source is decorated, test injection should fail
    # This is expected behavior - async instrumentation happens at the decorator level
    assert success is False
    assert instrumented_test_code is None


@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_mixed_sync_async_instrumentation(temp_dir):
    source_module_code = '''
import asyncio

def sync_function(x: int, y: int) -> int:
    """Regular sync function."""
    return x * y

async def async_function(x: int, y: int) -> int:
    """Simple async function."""
    await asyncio.sleep(0.01)
    return x * y
'''

    source_file = temp_dir / "my_module.py"
    source_file.write_text(source_module_code)

    mixed_test_code = '''
import asyncio
import pytest
from my_module import sync_function, async_function

@pytest.mark.asyncio
async def test_mixed_functions():
    """Test both sync and async functions."""
    sync_result = sync_function(10, 5)
    assert sync_result == 50

    async_result = await async_function(3, 4)
    assert async_result == 12
'''

    test_file = temp_dir / "test_mixed.py"
    test_file.write_text(mixed_test_code)

    async_func = FunctionToOptimize(
        function_name="async_function", parents=[], file_path=Path("my_module.py"), is_async=True
    )

    from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators

    source_success, instrumented_source = instrument_source_module_with_async_decorators(
        source_file, async_func, TestingMode.BEHAVIOR
    )

    assert source_success
    assert instrumented_source is not None
    assert "@codeflash_behavior_async" in instrumented_source
    assert "from codeflash.code_utils.codeflash_wrap_decorator import" in instrumented_source
    assert "codeflash_behavior_async" in instrumented_source
    # Sync function should remain unchanged
    assert "def sync_function(x: int, y: int) -> int:" in instrumented_source

    # Write instrumented source back
    source_file.write_text(instrumented_source)

    success, instrumented_test_code = inject_profiling_into_existing_test(
        test_file,
        [CodePosition(8, 18), CodePosition(11, 19)],
        async_func,
        temp_dir,
        "pytest",
        mode=TestingMode.BEHAVIOR,
    )

    # Async functions should not be instrumented at the test level
    assert not success
    assert instrumented_test_code is None
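# instrument_source_module_with_async_decorators only decorates the requested
# FunctionToOptimize; sibling sync functions in the same module pass through verbatim,
# so mixed sync/async modules are safe to instrument in place.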
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_function_qualified_name_handling():
    nested_async_code = '''
import asyncio

class OuterClass:
    class InnerClass:
        async def nested_async_method(self, x: int) -> int:
            """Nested async method."""
            await asyncio.sleep(0.001)
            return x * 2
'''

    func = FunctionToOptimize(
        function_name="nested_async_method",
        file_path=Path("test_nested.py"),
        parents=[{"name": "OuterClass", "type": "ClassDef"}, {"name": "InnerClass", "type": "ClassDef"}],
        is_async=True,
    )

    modified_code, decorator_added = add_async_decorator_to_function(nested_async_code, func, TestingMode.BEHAVIOR)

    expected_output = (
        """import asyncio

from codeflash.code_utils.codeflash_wrap_decorator import \\
    codeflash_behavior_async


class OuterClass:
    class InnerClass:

        @codeflash_behavior_async
        async def nested_async_method(self, x: int) -> int:
            \"\"\"Nested async method.\"\"\"
            await asyncio.sleep(0.001)
            return x * 2
"""
    )

    assert modified_code.strip() == expected_output.strip()


@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_decorator_with_existing_decorators():
    """Test async decorator application when function already has other decorators."""
    decorated_async_code = '''
import asyncio
from functools import wraps

def my_decorator(func):
    @wraps(func)
    async def wrapper(*args, **kwargs):
        return await func(*args, **kwargs)
    return wrapper

@my_decorator
async def async_function(x: int, y: int) -> int:
    """Async function with existing decorator."""
    await asyncio.sleep(0.01)
    return x * y
'''

    func = FunctionToOptimize(
        function_name="async_function", file_path=Path("test_async.py"), parents=[], is_async=True
    )

    modified_code, decorator_added = add_async_decorator_to_function(decorated_async_code, func, TestingMode.BEHAVIOR)

    assert decorator_added
    # Should add codeflash decorator above existing decorators
    assert "@codeflash_behavior_async" in modified_code
    assert "@my_decorator" in modified_code
    # Codeflash decorator should come first
    codeflash_pos = modified_code.find("@codeflash_behavior_async")
    my_decorator_pos = modified_code.find("@my_decorator")
    assert codeflash_pos < my_decorator_pos


@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_sync_function_not_affected_by_async_logic():
    sync_function_code = '''
def sync_function(x: int, y: int) -> int:
    """Regular sync function."""
    return x + y
'''

    sync_func = FunctionToOptimize(
        function_name="sync_function",
        file_path=Path("test_sync.py"),
        parents=[],
        is_async=False,
    )

    modified_code, decorator_added = add_async_decorator_to_function(
        sync_function_code, sync_func, TestingMode.BEHAVIOR
    )

    assert not decorator_added
    assert modified_code == sync_function_code
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_inject_profiling_async_multiple_calls_same_test(temp_dir):
    """Test that multiple async function calls within the same test function get correctly numbered 0, 1, 2, etc."""
    source_module_code = '''
import asyncio

async def async_sorter(items):
    """Simple async sorter for testing."""
    await asyncio.sleep(0.001)
    return sorted(items)
'''

    source_file = temp_dir / "async_sorter.py"
    source_file.write_text(source_module_code)

    test_code_multiple_calls = """
import asyncio
import pytest
from async_sorter import async_sorter

@pytest.mark.asyncio
async def test_single_call():
    result = await async_sorter([42])
    assert result == [42]

@pytest.mark.asyncio
async def test_multiple_calls():
    result1 = await async_sorter([3, 1, 2])
    result2 = await async_sorter([5, 4])
    result3 = await async_sorter([9, 8, 7, 6])
    assert result1 == [1, 2, 3]
    assert result2 == [4, 5]
    assert result3 == [6, 7, 8, 9]
"""

    test_file = temp_dir / "test_async_sorter.py"
    test_file.write_text(test_code_multiple_calls)

    func = FunctionToOptimize(
        function_name="async_sorter", parents=[], file_path=Path("async_sorter.py"), is_async=True
    )

    # First instrument the source module with async decorators
    from codeflash.code_utils.instrument_existing_tests import instrument_source_module_with_async_decorators

    source_success, instrumented_source = instrument_source_module_with_async_decorators(
        source_file, func, TestingMode.BEHAVIOR
    )

    assert source_success
    assert instrumented_source is not None
    assert "@codeflash_behavior_async" in instrumented_source

    source_file.write_text(instrumented_source)

    import ast

    tree = ast.parse(test_code_multiple_calls)
    call_positions = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Await) and isinstance(node.value, ast.Call):
            if (hasattr(node.value.func, "id") and node.value.func.id == "async_sorter") or (
                hasattr(node.value.func, "attr") and node.value.func.attr == "async_sorter"
            ):
                call_positions.append(CodePosition(node.lineno, node.col_offset))

    assert len(call_positions) == 4

    success, instrumented_test_code = inject_profiling_into_existing_test(
        test_file, call_positions, func, temp_dir, "pytest", mode=TestingMode.BEHAVIOR
    )

    assert success
    assert instrumented_test_code is not None

    assert "os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'" in instrumented_test_code

    # Count occurrences of each line_id to verify numbering
    line_id_0_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '0'")
    line_id_1_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '1'")
    line_id_2_count = instrumented_test_code.count("os.environ['CODEFLASH_CURRENT_LINE_ID'] = '2'")

    assert line_id_0_count == 2, f"Expected 2 occurrences of line_id '0', got {line_id_0_count}"
    assert line_id_1_count == 1, f"Expected 1 occurrence of line_id '1', got {line_id_1_count}"
    assert line_id_2_count == 1, f"Expected 1 occurrence of line_id '2', got {line_id_2_count}"
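# Line IDs are assigned per call site and restart in each test function: the single
# call in test_single_call and the first call in test_multiple_calls both use line_id
# '0' (hence the expected count of 2), while the later call sites get '1' and '2'.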
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_behavior_decorator_return_values_and_test_ids():
    """Test that async behavior decorator correctly captures return values, test IDs, and stores data in database."""
    import asyncio
    import os
    import sqlite3
    from pathlib import Path

    import dill as pickle

    from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async

    @codeflash_behavior_async
    async def test_async_multiply(x: int, y: int) -> int:
        """Simple async function for testing."""
        await asyncio.sleep(0.001)  # Small delay to simulate async work
        return x * y

    test_env = {
        "CODEFLASH_TEST_MODULE": "test_module",
        "CODEFLASH_TEST_CLASS": None,
        "CODEFLASH_TEST_FUNCTION": "test_async_multiply_function",
        "CODEFLASH_CURRENT_LINE_ID": "0",
        "CODEFLASH_LOOP_INDEX": "1",
        "CODEFLASH_TEST_ITERATION": "2",
    }

    original_env = {k: os.environ.get(k) for k in test_env}
    for k, v in test_env.items():
        if v is not None:
            os.environ[k] = v
        elif k in os.environ:
            del os.environ[k]

    try:
        result = asyncio.run(test_async_multiply(6, 7))

        assert result == 42, f"Expected return value 42, got {result}"

        from codeflash.code_utils.codeflash_wrap_decorator import get_run_tmp_file

        db_path = get_run_tmp_file(Path("test_return_values_2.sqlite"))

        # Verify database exists and has data
        assert db_path.exists(), f"Database file not created at {db_path}"

        # Read and verify database contents
        con = sqlite3.connect(db_path)
        cur = con.cursor()

        cur.execute("SELECT * FROM test_results")
        rows = cur.fetchall()

        assert len(rows) == 1, f"Expected 1 database row, got {len(rows)}"

        row = rows[0]
        (
            test_module,
            test_class,
            test_function,
            function_name,
            loop_index,
            iteration_id,
            runtime,
            return_value_blob,
            verification_type,
        ) = row

        assert test_module == "test_module", f"Expected test_module 'test_module', got '{test_module}'"
        assert test_class is None, f"Expected test_class None, got '{test_class}'"
        assert test_function == "test_async_multiply_function", (
            f"Expected test_function 'test_async_multiply_function', got '{test_function}'"
        )
        assert function_name == "test_async_multiply", (
            f"Expected function_name 'test_async_multiply', got '{function_name}'"
        )
        assert loop_index == 1, f"Expected loop_index 1, got {loop_index}"
        assert iteration_id == "0_0", f"Expected iteration_id '0_0', got '{iteration_id}'"
        assert verification_type == "function_call", (
            f"Expected verification_type 'function_call', got '{verification_type}'"
        )
        unpickled_data = pickle.loads(return_value_blob)
        args, kwargs, actual_return_value = unpickled_data

        assert args == (6, 7), f"Expected args (6, 7), got {args}"
        assert kwargs == {}, f"Expected empty kwargs, got {kwargs}"

        assert actual_return_value == 42, f"Expected stored return value 42, got {actual_return_value}"

        con.close()

    finally:
        for k, v in original_env.items():
            if v is not None:
                os.environ[k] = v
            elif k in os.environ:
                del os.environ[k]
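# The stored iteration_id combines the call-site line id with a per-site call counter,
# '<line_id>_<call_index>': CODEFLASH_CURRENT_LINE_ID was '0' and this was the first
# invocation at that site, giving the '0_0' asserted above.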
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_decorator_comprehensive_return_values_and_test_ids():
    import asyncio
    import os
    import sqlite3
    from pathlib import Path

    import dill as pickle

    from codeflash.code_utils.codeflash_wrap_decorator import codeflash_behavior_async, get_run_tmp_file

    @codeflash_behavior_async
    async def async_multiply_add(x: int, y: int, z: int = 1) -> int:
        """Async function that multiplies x*y then adds z."""
        await asyncio.sleep(0.001)
        result = (x * y) + z
        return result

    test_env = {
        "CODEFLASH_TEST_MODULE": "test_comprehensive_module",
        "CODEFLASH_TEST_CLASS": "AsyncTestClass",
        "CODEFLASH_TEST_FUNCTION": "test_comprehensive_async_function",
        "CODEFLASH_CURRENT_LINE_ID": "3",
        "CODEFLASH_LOOP_INDEX": "2",
        "CODEFLASH_TEST_ITERATION": "3",
    }

    original_env = {k: os.environ.get(k) for k in test_env}
    for k, v in test_env.items():
        if v is not None:
            os.environ[k] = v
        elif k in os.environ:
            del os.environ[k]

    try:
        test_cases = [
            {"args": (5, 3), "kwargs": {}, "expected": 16},  # (5 * 3) + 1 = 16
            {"args": (2, 4), "kwargs": {"z": 10}, "expected": 18},  # (2 * 4) + 10 = 18
            {"args": (7, 6), "kwargs": {}, "expected": 43},  # (7 * 6) + 1 = 43
        ]

        results = []
        for test_case in test_cases:
            result = asyncio.run(async_multiply_add(*test_case["args"], **test_case["kwargs"]))
            results.append(result)

            # Verify each return value is exactly correct
            assert result == test_case["expected"], (
                f"Expected {test_case['expected']}, got {result} for args {test_case['args']}, kwargs {test_case['kwargs']}"
            )

        db_path = get_run_tmp_file(Path("test_return_values_3.sqlite"))
        assert db_path.exists(), f"Database not created at {db_path}"

        con = sqlite3.connect(db_path)
        cur = con.cursor()

        cur.execute(
            "SELECT test_module_path, test_class_name, test_function_name, function_getting_tested, loop_index, iteration_id, runtime, return_value, verification_type FROM test_results ORDER BY rowid"
        )
        rows = cur.fetchall()

        assert len(rows) == 3, f"Expected 3 database rows, got {len(rows)}"

        for i, (
            test_module,
            test_class,
            test_function,
            function_name,
            loop_index,
            iteration_id,
            runtime,
            return_value_blob,
            verification_type,
        ) in enumerate(rows):
            assert test_module == "test_comprehensive_module", (
                f"Row {i}: Expected test_module 'test_comprehensive_module', got '{test_module}'"
            )
            assert test_class == "AsyncTestClass", f"Row {i}: Expected test_class 'AsyncTestClass', got '{test_class}'"
            assert test_function == "test_comprehensive_async_function", (
                f"Row {i}: Expected test_function 'test_comprehensive_async_function', got '{test_function}'"
            )
            assert function_name == "async_multiply_add", (
                f"Row {i}: Expected function_name 'async_multiply_add', got '{function_name}'"
            )
            assert loop_index == 2, f"Row {i}: Expected loop_index 2, got {loop_index}"
            assert verification_type == "function_call", (
                f"Row {i}: Expected verification_type 'function_call', got '{verification_type}'"
            )

            expected_iteration_id = f"3_{i}"
            assert iteration_id == expected_iteration_id, (
                f"Row {i}: Expected iteration_id '{expected_iteration_id}', got '{iteration_id}'"
            )

            args, kwargs, actual_return_value = pickle.loads(return_value_blob)
            expected_args = test_cases[i]["args"]
            expected_kwargs = test_cases[i]["kwargs"]
            expected_return = test_cases[i]["expected"]

            assert args == expected_args, f"Row {i}: Expected args {expected_args}, got {args}"
            assert kwargs == expected_kwargs, f"Row {i}: Expected kwargs {expected_kwargs}, got {kwargs}"
            assert actual_return_value == expected_return, (
                f"Row {i}: Expected return value {expected_return}, got {actual_return_value}"
            )

        con.close()

    finally:
        for k, v in original_env.items():
            if v is not None:
                os.environ[k] = v
            elif k in os.environ:
                del os.environ[k]
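Taken together, these tests pin down the async behavior decorator's observable contract: it reads the CODEFLASH_* environment variables for test identity, times each awaited call, and appends one pickled (args, kwargs, return_value) row per invocation to the test_results table. As a rough sketch only (not the codeflash implementation, which resolves its database path via get_run_tmp_file plus CODEFLASH_TEST_ITERATION and serializes with dill), a decorator producing that row shape could look like:

import functools
import os
import pickle
import sqlite3
import time


def behavior_async_sketch(db_path):
    """Illustrative only: capture (args, kwargs, return) per await into SQLite."""
    def decorate(fn):
        @functools.wraps(fn)
        async def wrapper(*args, **kwargs):
            line_id = os.environ.get("CODEFLASH_CURRENT_LINE_ID", "0")
            # One counter per call site, yielding iteration_id values '0_0', '0_1', ...
            index = wrapper.calls.setdefault(line_id, -1) + 1
            wrapper.calls[line_id] = index
            start = time.perf_counter_ns()
            result = await fn(*args, **kwargs)
            runtime = time.perf_counter_ns() - start
            con = sqlite3.connect(db_path)
            con.execute(
                "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, "
                "test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, "
                "runtime INTEGER, return_value BLOB, verification_type TEXT)"
            )
            con.execute(
                "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
                (
                    os.environ.get("CODEFLASH_TEST_MODULE"),
                    os.environ.get("CODEFLASH_TEST_CLASS"),
                    os.environ.get("CODEFLASH_TEST_FUNCTION"),
                    fn.__name__,
                    int(os.environ.get("CODEFLASH_LOOP_INDEX", "1")),
                    f"{line_id}_{index}",
                    runtime,
                    pickle.dumps((args, kwargs, result)),
                    "function_call",
                ),
            )
            con.commit()
            con.close()
            return result
        wrapper.calls = {}
        return wrapper
    return decorate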
@ -22,7 +22,7 @@ from codeflash.verification.codeflash_capture import codeflash_capture

class MyClass:

    @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=True)
    @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)
    def __init__(self):
        self.x = 1

@ -86,7 +86,7 @@ from codeflash.verification.codeflash_capture import codeflash_capture

class MyClass(ParentClass):

    @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=True)
    @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

@ -128,7 +128,7 @@ from codeflash.verification.codeflash_capture import codeflash_capture

class MyClass:

    @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=True)
    @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)
    def __init__(self):
        self.x = 1

@ -184,7 +184,7 @@ from codeflash.verification.codeflash_capture import codeflash_capture

class MyClass:

    @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=True)
    @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)
    def __init__(self):
        self.x = 1

@ -197,7 +197,7 @@ from codeflash.verification.codeflash_capture import codeflash_capture

class HelperClass:

    @codeflash_capture(function_name='HelperClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=False)
    @codeflash_capture(function_name='HelperClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=False)
    def __init__(self):
        self.y = 1

@ -271,7 +271,7 @@ from codeflash.verification.codeflash_capture import codeflash_capture

class MyClass:

    @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=True)
    @codeflash_capture(function_name='MyClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)
    def __init__(self):
        self.x = 1

@ -289,7 +289,7 @@ from codeflash.verification.codeflash_capture import codeflash_capture

class HelperClass1:

    @codeflash_capture(function_name='HelperClass1.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=False)
    @codeflash_capture(function_name='HelperClass1.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=False)
    def __init__(self):
        self.y = 1

@ -304,7 +304,7 @@ from codeflash.verification.codeflash_capture import codeflash_capture

class HelperClass2:

    @codeflash_capture(function_name='HelperClass2.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=False)
    @codeflash_capture(function_name='HelperClass2.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=False)
    def __init__(self):
        self.z = 2

@ -313,7 +313,7 @@ class HelperClass2:

class AnotherHelperClass:

    @codeflash_capture(function_name='AnotherHelperClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values"))!s}', tests_root='{test_path.parent!s}', is_fto=False)
    @codeflash_capture(function_name='AnotherHelperClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=False)
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
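The recurring change in these hunks swaps str()-style Path interpolation (!s) for Path.as_posix(). On Windows, str(Path(...)) yields backslash-separated paths, which makes the expected decorator strings platform-dependent (and backslashes inside the generated single-quoted arguments can be mangled as escape sequences); as_posix() yields forward slashes on every platform. A quick illustration:

from pathlib import PureWindowsPath

p = PureWindowsPath("run_tmp") / "test_return_values"
print(str(p))        # run_tmp\test_return_values -- what '!s' interpolates on Windows
print(p.as_posix())  # run_tmp/test_return_values -- stable across platforms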
@ -37,7 +37,7 @@ def test_add_decorator_imports_helper_in_class():
    line_profiler_output_file = add_decorator_imports(
        func_optimizer.function_to_optimize, code_context)
    expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}')
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')

from code_to_optimize.bubble_sort_in_class import BubbleSortClass

@ -106,7 +106,7 @@ def test_add_decorator_imports_helper_in_nested_class():
    line_profiler_output_file = add_decorator_imports(
        func_optimizer.function_to_optimize, code_context)
    expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}')
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')

from code_to_optimize.bubble_sort_in_nested_class import WrapperClass

@ -151,7 +151,7 @@ def test_add_decorator_imports_nodeps():
    line_profiler_output_file = add_decorator_imports(
        func_optimizer.function_to_optimize, code_context)
    expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}')
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')


@codeflash_line_profile

@ -200,7 +200,7 @@ def test_add_decorator_imports_helper_outside():
    line_profiler_output_file = add_decorator_imports(
        func_optimizer.function_to_optimize, code_context)
    expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}')
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')

from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap

@ -275,7 +275,7 @@ class helper:
    line_profiler_output_file = add_decorator_imports(
        func_optimizer.function_to_optimize, code_context)
    expected_code_main = f"""from line_profiler import profile as codeflash_line_profile
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}')
codeflash_line_profile.enable(output_prefix='{line_profiler_output_file.as_posix()}')


@codeflash_line_profile
@ -6,7 +6,7 @@ import os
import sys
import tempfile
from pathlib import Path

import pytest
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.instrument_existing_tests import (
    FunctionImportedAsVisitor,

@ -24,6 +24,8 @@ from codeflash.models.models import (
    TestsInFile,
    TestType,
)
import platform

from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@ -87,7 +89,38 @@ codeflash_wrap_perfonly_string = """def codeflash_wrap(codeflash_wrapped, codefl
"""


def test_perfinjector_bubble_sort() -> None:
def build_expected_unittest_imports(extra_imports: str = "") -> str:
    """Helper to build platform-aware imports for unittest tests."""
    imports = """import gc
import os
import sqlite3
import time
import unittest

import dill as pickle"""
    if platform.system() != "Windows":
        imports += "\nimport timeout_decorator"
    if extra_imports:
        imports += "\n" + extra_imports
    return imports


def build_expected_pytest_imports(extra_imports: str = "") -> str:
    """Helper to build platform-aware imports for pytest tests."""
    imports = """import gc
import os
import time

import pytest"""
    if extra_imports:
        imports += "\n" + extra_imports
    return imports


# create a temporary directory for the test results
@pytest.fixture
def tmp_dir():
    with tempfile.TemporaryDirectory() as tmpdirname:
        yield Path(tmpdirname)


def test_perfinjector_bubble_sort(tmp_dir) -> None:
    code = """import unittest

from code_to_optimize.bubble_sort import sorter
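These builders centralize the platform split: timeout_decorator is skipped on Windows, so every expected string now starts from the same platform-aware import block. The later hunks use them like this (names as they appear below):

imports_behavior = build_expected_unittest_imports("from parameterized import parameterized")
imports_behavior += "\n\nfrom code_to_optimize.bubble_sort import sorter"
expected_behavior = imports_behavior + "\n\n\n" + codeflash_wrap_string + "\n" + test_class_behavior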
@ -106,54 +139,27 @@ class TestPigLatin(unittest.TestCase):
        input = list(reversed(range(5000)))
        self.assertEqual(sorter(input), list(range(5000)))
"""
    expected = """import gc
    imports = """import gc
import os
import sqlite3
import time
import unittest

import dill as pickle
import timeout_decorator

from code_to_optimize.bubble_sort import sorter


def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
    test_id = f'{{codeflash_test_module_name}}:{{codeflash_test_class_name}}:{{codeflash_test_name}}:{{codeflash_line_id}}:{{codeflash_loop_index}}'
    if not hasattr(codeflash_wrap, 'index'):
        codeflash_wrap.index = {{}}
    if test_id in codeflash_wrap.index:
        codeflash_wrap.index[test_id] += 1
    else:
        codeflash_wrap.index[test_id] = 0
    codeflash_test_index = codeflash_wrap.index[test_id]
    invocation_id = f'{{codeflash_line_id}}_{{codeflash_test_index}}'
"""
    expected += """test_stdout_tag = f'{{codeflash_test_module_name}}:{{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}}{{codeflash_test_name}}:{{codeflash_function_name}}:{{codeflash_loop_index}}:{{invocation_id}}'
"""
    expected += """print(f'!$######{{test_stdout_tag}}######$!')
    exception = None
    gc.disable()
    try:
        counter = time.perf_counter_ns()
        return_value = codeflash_wrapped(*args, **kwargs)
        codeflash_duration = time.perf_counter_ns() - counter
    except Exception as e:
        codeflash_duration = time.perf_counter_ns() - counter
        exception = e
    gc.enable()
    print(f'!######{{test_stdout_tag}}######!')
    pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
    codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
    codeflash_con.commit()
    if exception:
        raise exception
    return return_value

class TestPigLatin(unittest.TestCase):

    @timeout_decorator.timeout(15)
    def test_sort(self):
import dill as pickle"""
    if platform.system() != "Windows":
        imports += "\nimport timeout_decorator"

    imports += "\n\nfrom code_to_optimize.bubble_sort import sorter"

    wrapper_func = codeflash_wrap_string

    test_class_header = "class TestPigLatin(unittest.TestCase):"
    test_decorator = "    @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""

    expected = imports + "\n\n\n" + wrapper_func + "\n" + test_class_header + "\n\n"
    if test_decorator:
        expected += test_decorator + "\n"
    expected += """    def test_sort(self):
        codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
        codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
        codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
@ -169,7 +175,8 @@ class TestPigLatin(unittest.TestCase):
        self.assertEqual(codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, codeflash_cur, codeflash_con, input), list(range(5000)))
        codeflash_con.close()
"""
    with tempfile.NamedTemporaryFile(mode="w") as f:

    with (tmp_dir / "test_sort.py").open("w") as f:
        f.write(code)
        f.flush()
    func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path(f.name))
@ -186,11 +193,11 @@ class TestPigLatin(unittest.TestCase):
    os.chdir(original_cwd)
    assert success
    assert new_test.replace('"', "'") == expected.format(
        module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
        module_path=Path(f.name).stem, tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
    ).replace('"', "'")


def test_perfinjector_only_replay_test() -> None:
def test_perfinjector_only_replay_test(tmp_dir) -> None:
    code = """import dill as pickle
import pytest
from codeflash.tracing.replay_test import get_next_arg_and_return

@ -269,7 +276,7 @@ def test_prepare_image_for_yolo():
        assert compare_results(return_val_1, ret)
        codeflash_con.close()
"""
    with tempfile.NamedTemporaryFile(mode="w") as f:
    with (tmp_dir / "test_return_values.py").open("w") as f:
        f.write(code)
        f.flush()
    func = FunctionToOptimize(function_name="prepare_image_for_yolo", parents=[], file_path=Path("module.py"))

@ -282,7 +289,7 @@ def test_prepare_image_for_yolo():
    os.chdir(original_cwd)
    assert success
    assert new_test.replace('"', "'") == expected.format(
        module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
        module_path=Path(f.name).stem, tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
    ).replace('"', "'")
@ -389,7 +396,7 @@ def test_sort():
    assert new_test is not None
    assert new_test.replace('"', "'") == expected.format(
        module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp",
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
    ).replace('"', "'")

    success, new_perf_test = inject_profiling_into_existing_test(

@ -404,7 +411,7 @@ def test_sort():
    assert new_perf_test is not None
    assert new_perf_test.replace('"', "'") == expected_perfonly.format(
        module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp",
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
    ).replace('"', "'")

    with test_path.open("w") as f:
@ -532,7 +539,8 @@ result: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
            line_profiler_output_file=line_profiler_output_file,
        )
        tmp_lpr = list(line_profile_results["timings"].keys())
        assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1] == 2
        if sys.platform != "win32":
            assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1] == 2
    finally:
        if computed_fn_opt:
            func_optimizer.write_code_and_helpers(
@ -643,11 +651,11 @@ def test_sort_parametrized(input, expected_output):
    assert new_test is not None
    assert new_test.replace('"', "'") == expected.format(
        module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_results_temp",
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
    ).replace('"', "'")
    assert new_test_perf.replace('"', "'") == expected_perfonly.format(
        module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_results_temp",
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
    ).replace('"', "'")
    #
    # Overwrite old test with new instrumented test
@ -799,7 +807,8 @@ result: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
            line_profiler_output_file=line_profiler_output_file,
        )
        tmp_lpr = list(line_profile_results["timings"].keys())
        assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1] == 3
        if sys.platform != "win32":
            assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1] == 3
    finally:
        if computed_fn_opt:
            func_optimizer.write_code_and_helpers(
@ -916,7 +925,7 @@ def test_sort_parametrized_loop(input, expected_output):
    assert new_test is not None
    assert new_test.replace('"', "'") == expected.format(
        module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_loop_results_temp",
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
    ).replace('"', "'")

    # Overwrite old test with new instrumented test

@ -925,7 +934,7 @@ def test_sort_parametrized_loop(input, expected_output):

    assert new_test_perf.replace('"', "'") == expected_perf.format(
        module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_loop_results_temp",
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
    ).replace('"', "'")

    # Overwrite old test with new instrumented test
@ -1154,7 +1163,8 @@ result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2
            line_profiler_output_file=line_profiler_output_file,
        )
        tmp_lpr = list(line_profile_results["timings"].keys())
        assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1] == 6
        if sys.platform != "win32":
            assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1] == 6
    finally:
        if computed_fn_opt:
            func_optimizer.write_code_and_helpers(
@ -1271,12 +1281,12 @@ def test_sort():
    assert new_test_behavior is not None
    assert new_test_behavior.replace('"', "'") == expected.format(
        module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_loop_results_temp",
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
    ).replace('"', "'")

    assert new_test_perf.replace('"', "'") == expected_perf.format(
        module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_loop_results_temp",
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
    ).replace('"', "'")

    # Overwrite old test with new instrumented test
@ -1432,7 +1442,8 @@ result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2
            line_profiler_output_file=line_profiler_output_file,
        )
        tmp_lpr = list(line_profile_results["timings"].keys())
        assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1] == 3
        if sys.platform != "win32":
            assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1] == 3
    finally:
        if computed_fn_opt is True:
            func_optimizer.write_code_and_helpers(
@ -1446,6 +1457,7 @@ result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 2


def test_perfinjector_bubble_sort_unittest_results() -> None:

    code = """import unittest

from code_to_optimize.bubble_sort import sorter
@ -1466,8 +1478,74 @@ class TestPigLatin(unittest.TestCase):
        self.assertEqual(output, list(range(50)))
"""

    expected = (
        """import gc
    is_windows = platform.system() == "Windows"

    if is_windows:
        expected = (
            """import gc
import os
import sqlite3
import time
import unittest

import dill as pickle

from code_to_optimize.bubble_sort import sorter


"""
            + codeflash_wrap_string
            + """
class TestPigLatin(unittest.TestCase):

    def test_sort(self):
        codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
        codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
        codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
        codeflash_cur = codeflash_con.cursor()
        codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)')
        input = [5, 4, 3, 2, 1, 0]
        output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, input)
        self.assertEqual(output, [0, 1, 2, 3, 4, 5])
        input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
        output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, input)
        self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
        input = list(reversed(range(50)))
        output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, codeflash_cur, codeflash_con, input)
        self.assertEqual(output, list(range(50)))
        codeflash_con.close()
"""
        )
        expected_perf = (
            """import gc
import os
import time
import unittest

from code_to_optimize.bubble_sort import sorter


"""
            + codeflash_wrap_perfonly_string
            + """
class TestPigLatin(unittest.TestCase):

    def test_sort(self):
        codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
        input = [5, 4, 3, 2, 1, 0]
        output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, input)
        self.assertEqual(output, [0, 1, 2, 3, 4, 5])
        input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
        output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, input)
        self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
        input = list(reversed(range(50)))
        output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, input)
        self.assertEqual(output, list(range(50)))
"""
        )
    else:
        expected = (
            """import gc
import os
import sqlite3
import time
@ -1480,8 +1558,8 @@ from code_to_optimize.bubble_sort import sorter


"""
    + codeflash_wrap_string
    + """
            + codeflash_wrap_string
            + """
class TestPigLatin(unittest.TestCase):

    @timeout_decorator.timeout(15)
@ -1502,9 +1580,9 @@ class TestPigLatin(unittest.TestCase):
        self.assertEqual(output, list(range(50)))
        codeflash_con.close()
"""
    )
    expected_perf = (
        """import gc
        )
        expected_perf = (
            """import gc
import os
import time
import unittest
@ -1515,8 +1593,8 @@ from code_to_optimize.bubble_sort import sorter


"""
    + codeflash_wrap_perfonly_string
    + """
            + codeflash_wrap_perfonly_string
            + """
class TestPigLatin(unittest.TestCase):

    @timeout_decorator.timeout(15)
@ -1532,7 +1610,7 @@ class TestPigLatin(unittest.TestCase):
        output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, input)
        self.assertEqual(output, list(range(50)))
"""
    )
        )
    code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
    test_path = (
        Path(__file__).parent.resolve()
@ -1580,11 +1658,11 @@ class TestPigLatin(unittest.TestCase):
    assert new_test_behavior is not None
    assert new_test_behavior.replace('"', "'") == expected.format(
        module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_results_temp",
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
    ).replace('"', "'")
    assert new_test_perf.replace('"', "'") == expected_perf.format(
        module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_results_temp",
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
        tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
    ).replace('"', "'")
    #
    # Overwrite old test with new instrumented test
@ -1744,28 +1822,18 @@ class TestPigLatin(unittest.TestCase):
        self.assertEqual(output, expected_output)
"""

    expected_behavior = (
        """import gc
import os
import sqlite3
import time
import unittest

import dill as pickle
import timeout_decorator
from parameterized import parameterized

from code_to_optimize.bubble_sort import sorter


"""
        + codeflash_wrap_string
        + """
class TestPigLatin(unittest.TestCase):
    # Build expected behavior output with platform-aware imports
    imports_behavior = build_expected_unittest_imports("from parameterized import parameterized")
    imports_behavior += "\n\nfrom code_to_optimize.bubble_sort import sorter"

    test_decorator_behavior = "    @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
    test_class_behavior = """class TestPigLatin(unittest.TestCase):

    @parameterized.expand([([5, 4, 3, 2, 1, 0], [0, 1, 2, 3, 4, 5]), ([5.0, 4.0, 3.0, 2.0, 1.0, 0.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), (list(reversed(range(50))), list(range(50)))])
    @timeout_decorator.timeout(15)
    def test_sort(self, input, expected_output):
"""
    if test_decorator_behavior:
        test_class_behavior += test_decorator_behavior + "\n"
    test_class_behavior += """    def test_sort(self, input, expected_output):
        codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
        codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
        codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
@ -1775,32 +1843,32 @@ class TestPigLatin(unittest.TestCase):
        self.assertEqual(output, expected_output)
        codeflash_con.close()
"""
    )
    expected_perf = (
        """import gc

    expected_behavior = imports_behavior + "\n\n\n" + codeflash_wrap_string + "\n" + test_class_behavior
    # Build expected perf output with platform-aware imports
    imports_perf = """import gc
import os
import time
import unittest

import timeout_decorator
from parameterized import parameterized

from code_to_optimize.bubble_sort import sorter


"""
        + codeflash_wrap_perfonly_string
        + """
class TestPigLatin(unittest.TestCase):
    if platform.system() != "Windows":
        imports_perf += "\nimport timeout_decorator"
    imports_perf += "\nfrom parameterized import parameterized\n\nfrom code_to_optimize.bubble_sort import sorter"

    test_decorator_perf = "    @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
    test_class_perf = """class TestPigLatin(unittest.TestCase):

    @parameterized.expand([([5, 4, 3, 2, 1, 0], [0, 1, 2, 3, 4, 5]), ([5.0, 4.0, 3.0, 2.0, 1.0, 0.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), (list(reversed(range(50))), list(range(50)))])
    @timeout_decorator.timeout(15)
    def test_sort(self, input, expected_output):
"""
    if test_decorator_perf:
        test_class_perf += test_decorator_perf + "\n"
    test_class_perf += """    def test_sort(self, input, expected_output):
        codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
        output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0', codeflash_loop_index, input)
        self.assertEqual(output, expected_output)
"""
    )

    expected_perf = imports_perf + "\n\n\n" + codeflash_wrap_perfonly_string + "\n" + test_class_perf
    code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
    test_path = (
        Path(__file__).parent.resolve()
@ -1837,13 +1905,13 @@ class TestPigLatin(unittest.TestCase):
|
|||
assert new_test_behavior is not None
|
||||
assert new_test_behavior.replace('"', "'") == expected_behavior.format(
|
||||
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_results_temp",
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
|
||||
).replace('"', "'")
|
||||
|
||||
assert new_test_perf is not None
|
||||
assert new_test_perf.replace('"', "'") == expected_perf.format(
|
||||
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_results_temp",
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
|
||||
).replace('"', "'")
|
||||
|
||||
#
|
||||
|
|
@ -2003,26 +2071,17 @@ class TestPigLatin(unittest.TestCase):
|
|||
output = sorter(input)
|
||||
self.assertEqual(output, expected_output)"""
|
||||
|
||||
expected_behavior = (
|
||||
"""import gc
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import dill as pickle
|
||||
import timeout_decorator
|
||||
|
||||
from code_to_optimize.bubble_sort import sorter
|
||||
|
||||
# Build expected behavior output with platform-aware imports
|
||||
imports_behavior = build_expected_unittest_imports()
|
||||
imports_behavior += "\n\nfrom code_to_optimize.bubble_sort import sorter"
|
||||
|
||||
test_decorator_behavior = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
|
||||
test_class_behavior = """class TestPigLatin(unittest.TestCase):
|
||||
|
||||
"""
|
||||
+ codeflash_wrap_string
|
||||
+ """
|
||||
class TestPigLatin(unittest.TestCase):
|
||||
|
||||
@timeout_decorator.timeout(15)
|
||||
def test_sort(self):
|
||||
if test_decorator_behavior:
|
||||
test_class_behavior += test_decorator_behavior + "\n"
|
||||
test_class_behavior += """ def test_sort(self):
|
||||
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
|
||||
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
|
||||
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
|
||||
|
|
@ -2037,26 +2096,28 @@ class TestPigLatin(unittest.TestCase):
|
|||
self.assertEqual(output, expected_output)
|
||||
codeflash_con.close()
|
||||
"""
|
||||
)
|
||||
|
||||
expected_behavior = imports_behavior + "\n\n\n" + codeflash_wrap_string + "\n" + test_class_behavior
|
||||
|
||||
expected_perf = (
|
||||
"""import gc
|
||||
# Build expected perf output with platform-aware imports
|
||||
imports_perf = """import gc
|
||||
import os
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import timeout_decorator
|
||||
|
||||
from code_to_optimize.bubble_sort import sorter
|
||||
|
||||
"""
|
||||
if platform.system() != "Windows":
|
||||
imports_perf += "\nimport timeout_decorator"
|
||||
imports_perf += "\n\nfrom code_to_optimize.bubble_sort import sorter"
|
||||
else:
|
||||
imports_perf += "\nfrom code_to_optimize.bubble_sort import sorter"
|
||||
|
||||
test_decorator_perf = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
|
||||
test_class_perf = """class TestPigLatin(unittest.TestCase):
|
||||
|
||||
"""
|
||||
+ codeflash_wrap_perfonly_string
|
||||
+ """
|
||||
class TestPigLatin(unittest.TestCase):
|
||||
|
||||
@timeout_decorator.timeout(15)
|
||||
def test_sort(self):
|
||||
if test_decorator_perf:
|
||||
test_class_perf += test_decorator_perf + "\n"
|
||||
test_class_perf += """ def test_sort(self):
|
||||
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
|
||||
inputs = [[5, 4, 3, 2, 1, 0], [5.0, 4.0, 3.0, 2.0, 1.0, 0.0], list(reversed(range(50)))]
|
||||
expected_outputs = [[0, 1, 2, 3, 4, 5], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0], list(range(50))]
|
||||
|
|
@ -2066,7 +2127,8 @@ class TestPigLatin(unittest.TestCase):
|
|||
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '2_2', codeflash_loop_index, input)
|
||||
self.assertEqual(output, expected_output)
|
||||
"""
|
||||
)
|
||||
|
||||
expected_perf = imports_perf + "\n\n\n" + codeflash_wrap_perfonly_string + "\n" + test_class_perf
|
||||
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
|
||||
test_path = (
|
||||
Path(__file__).parent.resolve()
|
||||
|
|
@ -2103,11 +2165,11 @@ class TestPigLatin(unittest.TestCase):
|
|||
assert new_test_behavior is not None
|
||||
assert new_test_behavior.replace('"', "'") == expected_behavior.format(
|
||||
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_loop_results_temp",
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
|
||||
).replace('"', "'")
|
||||
assert new_test_perf.replace('"', "'") == expected_perf.format(
|
||||
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_loop_results_temp",
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
|
||||
).replace('"', "'")
|
||||
#
|
||||
# # Overwrite old test with new instrumented test
|
||||
|
|
@ -2269,28 +2331,18 @@ class TestPigLatin(unittest.TestCase):
|
|||
self.assertEqual(output, expected_output)
|
||||
"""
|
||||
|
||||
expected_behavior = (
|
||||
"""import gc
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import dill as pickle
|
||||
import timeout_decorator
|
||||
from parameterized import parameterized
|
||||
|
||||
from code_to_optimize.bubble_sort import sorter
|
||||
|
||||
|
||||
"""
|
||||
+ codeflash_wrap_string
|
||||
+ """
|
||||
class TestPigLatin(unittest.TestCase):
|
||||
# Build expected behavior output with platform-aware imports
|
||||
imports_behavior = build_expected_unittest_imports("from parameterized import parameterized")
|
||||
imports_behavior += "\n\nfrom code_to_optimize.bubble_sort import sorter"
|
||||
|
||||
test_decorator_behavior = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
|
||||
test_class_behavior = """class TestPigLatin(unittest.TestCase):
|
||||
|
||||
@parameterized.expand([([5, 4, 3, 2, 1, 0], [0, 1, 2, 3, 4, 5]), ([5.0, 4.0, 3.0, 2.0, 1.0, 0.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), (list(reversed(range(50))), list(range(50)))])
|
||||
@timeout_decorator.timeout(15)
|
||||
def test_sort(self, input, expected_output):
|
||||
"""
|
||||
if test_decorator_behavior:
|
||||
test_class_behavior += test_decorator_behavior + "\n"
|
||||
test_class_behavior += """ def test_sort(self, input, expected_output):
|
||||
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
|
||||
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
|
||||
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
|
||||
|
|
@ -2301,33 +2353,33 @@ class TestPigLatin(unittest.TestCase):
|
|||
self.assertEqual(output, expected_output)
|
||||
codeflash_con.close()
|
||||
"""
|
||||
)
|
||||
expected_perf = (
|
||||
"""import gc
|
||||
|
||||
expected_behavior = imports_behavior + "\n\n\n" + codeflash_wrap_string + "\n" + test_class_behavior
|
||||
# Build expected perf output with platform-aware imports
|
||||
imports_perf = """import gc
|
||||
import os
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import timeout_decorator
|
||||
from parameterized import parameterized
|
||||
|
||||
from code_to_optimize.bubble_sort import sorter
|
||||
|
||||
|
||||
"""
|
||||
+ codeflash_wrap_perfonly_string
|
||||
+ """
|
||||
class TestPigLatin(unittest.TestCase):
|
||||
if platform.system() != "Windows":
|
||||
imports_perf += "\nimport timeout_decorator"
|
||||
imports_perf += "\nfrom parameterized import parameterized\n\nfrom code_to_optimize.bubble_sort import sorter"
|
||||
|
||||
test_decorator_perf = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
|
||||
test_class_perf = """class TestPigLatin(unittest.TestCase):
|
||||
|
||||
@parameterized.expand([([5, 4, 3, 2, 1, 0], [0, 1, 2, 3, 4, 5]), ([5.0, 4.0, 3.0, 2.0, 1.0, 0.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), (list(reversed(range(50))), list(range(50)))])
|
||||
@timeout_decorator.timeout(15)
|
||||
def test_sort(self, input, expected_output):
|
||||
"""
|
||||
if test_decorator_perf:
|
||||
test_class_perf += test_decorator_perf + "\n"
|
||||
test_class_perf += """ def test_sort(self, input, expected_output):
|
||||
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
|
||||
for i in range(2):
|
||||
output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0_0', codeflash_loop_index, input)
|
||||
self.assertEqual(output, expected_output)
|
||||
"""
|
||||
)
|
||||
|
||||
expected_perf = imports_perf + "\n\n\n" + codeflash_wrap_perfonly_string + "\n" + test_class_perf
|
||||
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve()
|
||||
test_path = (
|
||||
Path(__file__).parent.resolve()
|
||||
|
|
@ -2362,11 +2414,11 @@ class TestPigLatin(unittest.TestCase):
|
|||
assert new_test_behavior is not None
|
||||
assert new_test_behavior.replace('"', "'") == expected_behavior.format(
|
||||
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_loop_results_temp",
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
|
||||
).replace('"', "'")
|
||||
assert new_test_perf.replace('"', "'") == expected_perf.format(
|
||||
module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_loop_results_temp",
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
|
||||
).replace('"', "'")
|
||||
#
|
||||
# Overwrite old test with new instrumented test
|
||||
|
|
@ -2663,7 +2715,7 @@ def test_class_name_A_function_name():
|
|||
assert success
|
||||
assert new_test is not None
|
||||
assert new_test.replace('"', "'") == expected.format(
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
|
||||
module_path="tests.pytest.test_class_function_instrumentation_temp",
|
||||
).replace('"', "'")
|
||||
|
||||
|
|
@ -2734,7 +2786,7 @@ def test_common_tags_1():
|
|||
assert new_test is not None
|
||||
assert new_test.replace('"', "'") == expected.format(
|
||||
module_path="tests.pytest.test_wrong_function_instrumentation_temp",
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
|
||||
).replace('"', "'")
|
||||
finally:
|
||||
test_path.unlink(missing_ok=True)
|
||||
|
|
@ -2797,7 +2849,7 @@ def test_sort():
|
|||
assert new_test is not None
|
||||
assert new_test.replace('"', "'") == expected.format(
|
||||
module_path="tests.pytest.test_conditional_instrumentation_temp",
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
|
||||
).replace('"', "'")
|
||||
finally:
|
||||
test_path.unlink(missing_ok=True)
|
||||
|
|
@ -2874,7 +2926,7 @@ def test_sort():
|
|||
assert success
|
||||
formatted_expected = expected.format(
|
||||
module_path="tests.pytest.test_perfinjector_bubble_sort_results_temp",
|
||||
tmp_dir_path=str(get_run_tmp_file(Path("test_return_values"))),
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
|
||||
)
|
||||
assert new_test is not None
|
||||
assert new_test.replace('"', "'") == formatted_expected.replace('"', "'")
|
||||
|
|
@ -2882,7 +2934,7 @@ def test_sort():
|
|||
test_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_class_method_instrumentation() -> None:
|
||||
def test_class_method_instrumentation(tmp_path: Path) -> None:
|
||||
code = """from codeflash.optimization.optimizer import Optimizer
|
||||
def test_code_replacement10() -> None:
|
||||
get_code_output = '''random code'''
|
||||
|
|
@ -2952,24 +3004,24 @@ def test_code_replacement10() -> None:
|
|||
"""
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w") as f:
|
||||
f.write(code)
|
||||
f.flush()
|
||||
func = FunctionToOptimize(
|
||||
function_name="get_code_optimization_context",
|
||||
parents=[FunctionParent("Optimizer", "ClassDef")],
|
||||
file_path=Path(f.name),
|
||||
)
|
||||
original_cwd = Path.cwd()
|
||||
run_cwd = Path(__file__).parent.parent.resolve()
|
||||
os.chdir(run_cwd)
|
||||
success, new_test = inject_profiling_into_existing_test(
|
||||
Path(f.name), [CodePosition(22, 28), CodePosition(28, 28)], func, Path(f.name).parent, "pytest"
|
||||
)
|
||||
os.chdir(original_cwd)
|
||||
test_file_path = tmp_path / "test_class_method_instrumentation.py"
|
||||
test_file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="get_code_optimization_context",
|
||||
parents=[FunctionParent("Optimizer", "ClassDef")],
|
||||
file_path=test_file_path,
|
||||
)
|
||||
original_cwd = Path.cwd()
|
||||
run_cwd = Path(__file__).parent.parent.resolve()
|
||||
os.chdir(run_cwd)
|
||||
success, new_test = inject_profiling_into_existing_test(
|
||||
test_file_path, [CodePosition(22, 28), CodePosition(28, 28)], func, test_file_path.parent, "pytest"
|
||||
)
|
||||
os.chdir(original_cwd)
|
||||
assert success
|
||||
assert new_test.replace('"', "'") == expected.replace('"', "'").format(
|
||||
module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values"))
|
||||
module_path=test_file_path.stem, tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix()
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -3034,7 +3086,7 @@ def test_sleepfunc_sequence_short(n, expected_total_sleep_time):
|
|||
assert new_test is not None
|
||||
assert new_test.replace('"', "'") == expected.format(
|
||||
module_path="code_to_optimize.tests.pytest.test_time_correction_instrumentation_temp",
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
|
||||
).replace('"', "'")
|
||||
# Overwrite old test with new instrumented test
|
||||
with test_path.open("w") as f:
|
||||
|
|
@ -3101,30 +3153,29 @@ class TestPigLatin(unittest.TestCase):
|
|||
output = accurate_sleepfunc(n)
|
||||
"""
|
||||
|
||||
expected = (
|
||||
"""import gc
|
||||
# Build expected output with platform-aware imports
|
||||
imports = """import gc
|
||||
import os
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import timeout_decorator
|
||||
from parameterized import parameterized
|
||||
|
||||
from code_to_optimize.sleeptime import accurate_sleepfunc
|
||||
|
||||
|
||||
"""
|
||||
+ codeflash_wrap_perfonly_string
|
||||
+ """
|
||||
class TestPigLatin(unittest.TestCase):
|
||||
if platform.system() != "Windows":
|
||||
imports += "\nimport timeout_decorator"
|
||||
imports += "\nfrom parameterized import parameterized\n\nfrom code_to_optimize.sleeptime import accurate_sleepfunc"
|
||||
|
||||
test_decorator = " @timeout_decorator.timeout(15)" if platform.system() != "Windows" else ""
|
||||
test_class = """class TestPigLatin(unittest.TestCase):
|
||||
|
||||
@parameterized.expand([(0.01, 0.01), (0.02, 0.02)])
|
||||
@timeout_decorator.timeout(15)
|
||||
def test_sleepfunc_sequence_short(self, n, expected_total_sleep_time):
|
||||
"""
|
||||
if test_decorator:
|
||||
test_class += test_decorator + "\n"
|
||||
test_class += """ def test_sleepfunc_sequence_short(self, n, expected_total_sleep_time):
|
||||
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
|
||||
output = codeflash_wrap(accurate_sleepfunc, '{module_path}', 'TestPigLatin', 'test_sleepfunc_sequence_short', 'accurate_sleepfunc', '0', codeflash_loop_index, n)
|
||||
"""
|
||||
)
|
||||
|
||||
expected = imports + "\n\n\n" + codeflash_wrap_perfonly_string + "\n" + test_class
|
||||
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/sleeptime.py").resolve()
|
||||
test_path = (
|
||||
Path(__file__).parent.resolve()
|
||||
|
|
@ -3153,7 +3204,7 @@ class TestPigLatin(unittest.TestCase):
|
|||
assert new_test is not None
|
||||
assert new_test.replace('"', "'") == expected.format(
|
||||
module_path="code_to_optimize.tests.unittest.test_time_correction_instrumentation_unittest_temp",
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")),
|
||||
tmp_dir_path=get_run_tmp_file(Path("test_return_values")).as_posix(),
|
||||
).replace('"', "'")
|
||||
# Overwrite old test with new instrumented test
|
||||
with test_path.open("w") as f:
|
||||
|
|
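The hunks above replace hard-coded expected-import strings with a `build_expected_unittest_imports` helper plus a `platform.system()` check. The helper itself is defined elsewhere in the test module and does not appear in this diff; the sketch below is a hypothetical reconstruction of its shape, assuming its only platform-dependent behavior is toggling the `timeout_decorator` import off on Windows.

```python
# Hypothetical sketch of the platform-aware import builder these
# expectations rely on; the real build_expected_unittest_imports lives in
# the test module and may differ in detail.
import platform


def build_expected_unittest_imports(extra_import: str = "") -> str:
    """Assemble the import header an instrumented unittest file is expected to have."""
    lines = [
        "import gc",
        "import os",
        "import sqlite3",
        "import time",
        "import unittest",
        "",
        "import dill as pickle",
    ]
    if platform.system() != "Windows":
        # timeout_decorator is only expected on non-Windows platforms
        lines.append("import timeout_decorator")
    if extra_import:
        lines.append(extra_import)
    return "\n".join(lines)
```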

@@ -18,6 +18,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, TestFile, TestFiles, TestingMode, TestsInFile, TestType
from codeflash.optimization.optimizer import Optimizer
from codeflash.verification.equivalence import compare_test_results
import time

try:
import sqlalchemy

@@ -156,6 +157,9 @@ def test_picklepatch_with_database_connection():
with pytest.raises(PicklePlaceholderAccessError):
reloaded["connection"].execute("SELECT 1")

cursor.close()
conn.close()


def test_picklepatch_with_generator():
"""Test that a data structure containing a generator is replaced by

@@ -287,17 +291,26 @@ def test_run_and_parse_picklepatch() -> None:
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert "code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket" in function_to_results

# Close the connection to allow file cleanup on Windows
conn.close()
time.sleep(1)

test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"][0]
assert total_time > 0.0
assert function_time > 0.0
assert percent > 0.0

test_name, total_time, function_time, percent = \
function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"][0]
assert total_time > 0.0
assert function_time > 0.0
assert percent > 0.0
# Handle the case where the function runs too fast to be measured
unused_socket_results = function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"]
if unused_socket_results:
test_name, total_time, function_time, percent = unused_socket_results[0]
assert total_time >= 0.0
# Function might be too fast, so we allow 0.0 function_time
assert function_time >= 0.0
assert percent >= 0.0
used_socket_results = function_to_results["code_to_optimize.bubble_sort_picklepatch_test_used_socket.bubble_sort_with_used_socket"]
# On Windows, if the socket is not used we might not have results
if used_socket_results:
test_name, total_time, function_time, percent = used_socket_results[0]
assert total_time >= 0.0
assert function_time >= 0.0
assert percent >= 0.0

bubble_sort_unused_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_unused_socket.py").as_posix()
bubble_sort_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").as_posix()

@@ -318,7 +331,9 @@ def test_run_and_parse_picklepatch() -> None:
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
conn.close()
conn.close()

time.sleep(1)

# Generate replay test
generate_replay_test(output_file, replay_tests_dir)

@@ -510,4 +525,3 @@ def bubble_sort_with_used_socket(data_container):
shutil.rmtree(replay_tests_dir, ignore_errors=True)
fto_unused_socket_path.write_text(original_fto_unused_socket_code)
fto_used_socket_path.write_text(original_fto_used_socket_code)
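Several of these hunks add an explicit `conn.close()` followed by a short sleep before temporary sqlite files are deleted. A minimal sketch of that ordering, assuming an illustrative file name (the real tests use the benchmark plugin's output file):

```python
# On Windows an sqlite file cannot be unlinked while a connection still
# holds it open, so the handle must be released before cleanup.
import sqlite3
import time
from pathlib import Path

trace_file = Path("benchmark_timings.sqlite")  # hypothetical trace output
conn = sqlite3.connect(trace_file)
try:
    conn.execute("CREATE TABLE IF NOT EXISTS t (x INTEGER)")
finally:
    conn.close()  # release the file handle first
time.sleep(1)  # give Windows a moment to drop the lock
trace_file.unlink(missing_ok=True)  # now deletion succeeds on all platforms
```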
@@ -38,6 +38,12 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
self.test_rc_path = "test_shell_rc"
self.api_key = "cf-1234567890abcdef"
os.environ["SHELL"] = "/bin/bash" # Set a default shell for testing

# Set up platform-specific export syntax
if os.name == "nt": # Windows
self.api_key_export = f'set CODEFLASH_API_KEY={self.api_key}'
else: # Unix-like systems
self.api_key_export = f'export CODEFLASH_API_KEY="{self.api_key}"'

def tearDown(self):
"""Cleanup the temporary shell configuration file after testing."""

@@ -50,25 +56,37 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
with patch("codeflash.code_utils.shell_utils.get_shell_rc_path") as mock_get_shell_rc_path:
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch(
"builtins.open", mock_open(read_data=f'export CODEFLASH_API_KEY="{self.api_key}"\n')
"builtins.open", mock_open(read_data=f'{self.api_key_export}\n')
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")
with patch(
"builtins.open", mock_open(read_data=f'export CODEFLASH_API_KEY=\'{self.api_key}\'\n')
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")
with patch(
"builtins.open", mock_open(read_data=f'#export CODEFLASH_API_KEY=\'{self.api_key}\'\n')
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), None)
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")
with patch(
"builtins.open", mock_open(read_data=f'export CODEFLASH_API_KEY={self.api_key}\n')
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")

if os.name != "nt":
with patch(
"builtins.open", mock_open(read_data=f'export CODEFLASH_API_KEY=\'{self.api_key}\'\n')
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")

with patch(
"builtins.open", mock_open(read_data=f'#export CODEFLASH_API_KEY=\'{self.api_key}\'\n')
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), None)
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")

with patch(
"builtins.open", mock_open(read_data=f'export CODEFLASH_API_KEY={self.api_key}\n')
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")

elif os.name == "nt":
with patch(
"builtins.open", mock_open(read_data=f'REM set CODEFLASH_API_KEY={self.api_key}\n')
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), None)
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")


@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")

@@ -83,23 +101,41 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
def test_malformed_api_key_export(self, mock_get_shell_rc_path):
"""Test with a malformed API key export."""
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch("builtins.open", mock_open(read_data=f"export API_KEY={self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
with patch("builtins.open", mock_open(read_data=f"CODEFLASH_API_KEY={self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
with patch("builtins.open", mock_open(read_data=f"export CODEFLASH_API_KEY=sk-{self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)

if os.name == "nt":
with patch("builtins.open", mock_open(read_data=f"set API_KEY={self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
with patch("builtins.open", mock_open(read_data=f"CODEFLASH_API_KEY={self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
with patch("builtins.open", mock_open(read_data=f"set CODEFLASH_API_KEY=sk-{self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
else:
with patch("builtins.open", mock_open(read_data=f"export API_KEY={self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
with patch("builtins.open", mock_open(read_data=f"CODEFLASH_API_KEY={self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
with patch("builtins.open", mock_open(read_data=f"export CODEFLASH_API_KEY=sk-{self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)

@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
def test_multiple_api_key_exports(self, mock_get_shell_rc_path):
"""Test with multiple API key exports."""
mock_get_shell_rc_path.return_value = self.test_rc_path
if os.name == "nt": # Windows
first_export = 'set CODEFLASH_API_KEY=cf-firstkey'
second_export = f'set CODEFLASH_API_KEY={self.api_key}'
else:
first_export = 'export CODEFLASH_API_KEY="cf-firstkey"'
second_export = f'export CODEFLASH_API_KEY="{self.api_key}"'
with patch(
"builtins.open",
mock_open(read_data=f'export CODEFLASH_API_KEY="cf-firstkey"\nexport CODEFLASH_API_KEY="{self.api_key}"\n'),
mock_open(read_data=f'{first_export}\n{second_export}\n'),
):
self.assertEqual(read_api_key_from_shell_config(), self.api_key)

@@ -109,7 +145,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch(
"builtins.open",
mock_open(read_data=f'# Setting API Key\nexport CODEFLASH_API_KEY="{self.api_key}"\n# Done\n'),
mock_open(read_data=f'# Setting API Key\n{self.api_key_export}\n# Done\n'),
):
self.assertEqual(read_api_key_from_shell_config(), self.api_key)

@@ -117,7 +153,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
def test_api_key_in_comment(self, mock_get_shell_rc_path):
"""Test with API key export in a comment."""
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch("builtins.open", mock_open(read_data=f'# export CODEFLASH_API_KEY="{self.api_key}"\n')):
with patch("builtins.open", mock_open(read_data=f'# {self.api_key_export}\n')):
self.assertIsNone(read_api_key_from_shell_config())

@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
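These tests now exercise both the Unix `export` syntax and the Windows `set` syntax. The sketch below is an illustrative parser, not the real `read_api_key_from_shell_config`, showing the matching rules the assertions encode: double, single, or no quotes are accepted; the key must look like `cf-...`; and commented-out lines never match.

```python
# Illustrative sketch of the matching rules these tests encode; the real
# implementation lives in codeflash.code_utils.shell_utils.
import re

_UNIX = re.compile(r'^\s*export\s+CODEFLASH_API_KEY=["\']?(cf-[A-Za-z0-9]+)["\']?\s*$')
_WIN = re.compile(r'^\s*set\s+CODEFLASH_API_KEY=(cf-[A-Za-z0-9]+)\s*$')


def parse_api_key(line: str) -> str | None:
    if line.lstrip().startswith(("#", "REM")):  # commented out, never a match
        return None
    m = _UNIX.match(line) or _WIN.match(line)
    return m.group(1) if m else None


assert parse_api_key('export CODEFLASH_API_KEY="cf-abc123"') == "cf-abc123"
assert parse_api_key("set CODEFLASH_API_KEY=cf-abc123") == "cf-abc123"
assert parse_api_key('# export CODEFLASH_API_KEY="cf-abc123"') is None
assert parse_api_key("export CODEFLASH_API_KEY=sk-abc123") is None  # wrong prefix
```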
@@ -34,12 +34,12 @@ class TestUnittestRunnerSorter(unittest.TestCase):
tests_project_rootdir=cur_dir_path.parent,
)

with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp:
with tempfile.TemporaryDirectory(dir=cur_dir_path) as temp_dir:
test_file_path = Path(temp_dir) / "test_xx.py"
test_files = TestFiles(
test_files=[TestFile(instrumented_behavior_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST)]
test_files=[TestFile(instrumented_behavior_file_path=test_file_path, test_type=TestType.EXISTING_UNIT_TEST)]
)
fp.write(code.encode("utf-8"))
fp.flush()
test_file_path.write_text(code, encoding="utf-8")
result_file, process, _, _ = run_behavioral_tests(
test_files,
test_framework=config.test_framework,

@@ -78,12 +78,12 @@ def test_sort():
else:
test_env["PYTHONPATH"] += os.pathsep + str(config.project_root_path)

with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp:
with tempfile.TemporaryDirectory(dir=cur_dir_path) as temp_dir:
test_file_path = Path(temp_dir) / "test_xx.py"
test_files = TestFiles(
test_files=[TestFile(instrumented_behavior_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST)]
test_files=[TestFile(instrumented_behavior_file_path=test_file_path, test_type=TestType.EXISTING_UNIT_TEST)]
)
fp.write(code.encode("utf-8"))
fp.flush()
test_file_path.write_text(code, encoding="utf-8")
result_file, process, _, _ = run_behavioral_tests(
test_files,
test_framework=config.test_framework,

@@ -125,12 +125,12 @@ def test_sort():
else:
test_env["PYTHONPATH"] += os.pathsep + str(config.project_root_path)

with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp:
with tempfile.TemporaryDirectory(dir=cur_dir_path) as temp_dir:
test_file_path = Path(temp_dir) / "test_xx.py"
test_files = TestFiles(
test_files=[TestFile(instrumented_behavior_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST)]
test_files=[TestFile(instrumented_behavior_file_path=test_file_path, test_type=TestType.EXISTING_UNIT_TEST)]
)
fp.write(code.encode("utf-8"))
fp.flush()
test_file_path.write_text(code, encoding="utf-8")
result_file, process, _, _ = run_behavioral_tests(
test_files,
test_framework=config.test_framework,
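All three hunks swap `tempfile.NamedTemporaryFile` for a `TemporaryDirectory` plus `write_text`. The motivation is a Windows limitation: a `NamedTemporaryFile` generally cannot be reopened by name while the creating handle is still open. A minimal sketch of the replacement pattern:

```python
# Write the test file through a short-lived handle inside a temporary
# directory so other processes (e.g. a test runner) can open it on Windows.
import tempfile
from pathlib import Path

code = "def test_dummy():\n    assert True\n"

with tempfile.TemporaryDirectory() as temp_dir:
    test_file_path = Path(temp_dir) / "test_xx.py"
    test_file_path.write_text(code, encoding="utf-8")  # handle closed after writing
    # Any runner can now open test_file_path; the directory and its
    # contents are removed automatically when the block exits.
    assert test_file_path.read_text(encoding="utf-8") == code
```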
@@ -9,6 +9,7 @@ from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
from codeflash.benchmarking.replay_test import generate_replay_test
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
from codeflash.benchmarking.utils import validate_and_format_benchmark_table
import time


def test_trace_benchmarks() -> None:

@@ -154,7 +155,7 @@ from codeflash.benchmarking.replay_test import get_next_arg_and_return
from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle

functions = ['compute_and_sort', 'sorter']
trace_file_path = r"{output_file}"
trace_file_path = r"{output_file.as_posix()}"

def test_code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort_test_compute_and_sort():
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_compute_and_sort", function_name="compute_and_sort", file_path=r"{process_and_bubble_sort_path}", num_to_get=100):

@@ -196,6 +197,8 @@ def test_trace_multithreaded_benchmark() -> None:
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
function_calls = cursor.fetchall()

conn.close()

# Assert the length of function calls
assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)

@@ -204,9 +207,9 @@ def test_trace_multithreaded_benchmark() -> None:
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results

test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_codeflash_trace.sorter"][0]
assert total_time > 0.0
assert function_time > 0.0
assert percent > 0.0
assert total_time >= 0.0
assert function_time >= 0.0
assert percent >= 0.0

bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
# Expected function calls

@@ -279,8 +282,10 @@ def test_trace_benchmark_decorator() -> None:
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
# Close connection
cursor.close()
conn.close()

time.sleep(2)
finally:
# cleanup
output_file.unlink(missing_ok=True)
time.sleep(1)
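The replay-test template now interpolates `output_file.as_posix()` instead of the path object itself. A short demonstration of why, using `PureWindowsPath` so it runs on any platform:

```python
# Interpolating a Windows path into a raw-string literal embeds backslashes;
# sequences like \t, or a trailing backslash, would corrupt the generated
# source. Forward slashes are safe everywhere Python runs.
from pathlib import PureWindowsPath

output_file = PureWindowsPath(r"C:\Users\ci\trace.sqlite")  # illustrative path

naive = f'trace_file_path = r"{output_file}"'                 # embeds backslashes
portable = f'trace_file_path = r"{output_file.as_posix()}"'   # forward slashes

print(naive)     # trace_file_path = r"C:\Users\ci\trace.sqlite"
print(portable)  # trace_file_path = r"C:/Users/ci/trace.sqlite"
```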
@@ -3,7 +3,6 @@ import dataclasses
import pickle
import sqlite3
import sys
import tempfile
import threading
import time
from collections.abc import Generator

@@ -59,29 +58,26 @@ class TraceConfig:

class TestTracer:
@pytest.fixture
def trace_config(self) -> Generator[Path, None, None]:
def trace_config(self, tmp_path: Path) -> Generator[TraceConfig, None, None]:
"""Create a temporary pyproject.toml config file."""
# Create a temporary directory structure
temp_dir = Path(tempfile.mkdtemp())
tests_dir = temp_dir / "tests"
tests_dir = tmp_path / "tests"
tests_dir.mkdir(exist_ok=True)

# Use the current working directory as module root so test files are included
current_dir = Path.cwd()

with tempfile.NamedTemporaryFile(mode="w", suffix=".toml", delete=False, dir=temp_dir) as f:
f.write(f"""
config_path = tmp_path / "pyproject.toml"
config_path.write_text(f"""
[tool.codeflash]
module-root = "{current_dir}"
tests-root = "{tests_dir}"
module-root = "{current_dir.as_posix()}"
tests-root = "{tests_dir.as_posix()}"
test-framework = "pytest"
ignore-paths = []
""")
config_path = Path(f.name)
with tempfile.NamedTemporaryFile(suffix=".trace", delete=False) as f:
trace_path = Path(f.name)
trace_path.unlink(missing_ok=True) # Remove the file, we just want the path
replay_test_pkl_path = temp_dir / "replay_test.pkl"
""", encoding="utf-8")

trace_path = tmp_path / "trace_file.trace"
replay_test_pkl_path = tmp_path / "replay_test.pkl"
config, found_config_path = parse_config_file(config_path)
trace_config = TraceConfig(
trace_file=trace_path,

@@ -92,11 +88,6 @@ ignore-paths = []
)

yield trace_config
import shutil

shutil.rmtree(temp_dir, ignore_errors=True)
trace_path.unlink(missing_ok=True)
replay_test_pkl_path.unlink(missing_ok=True)

@pytest.fixture(autouse=True)
def reset_tracer_state(self) -> Generator[None, None, None]:
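The fixture rewrite above drops the manual `mkdtemp`/`shutil.rmtree` bookkeeping in favor of pytest's built-in `tmp_path`. A minimal sketch of the resulting fixture shape, with a hypothetical fixture name:

```python
# Minimal sketch of a tmp_path-based fixture: pytest creates the directory
# and deletes it after the test, so no explicit teardown code is needed.
from collections.abc import Generator
from pathlib import Path

import pytest


@pytest.fixture
def config_file(tmp_path: Path) -> Generator[Path, None, None]:
    config_path = tmp_path / "pyproject.toml"
    config_path.write_text(
        f'[tool.codeflash]\nmodule-root = "{tmp_path.as_posix()}"\n',
        encoding="utf-8",
    )
    yield config_path
    # no cleanup: pytest removes tmp_path automatically
```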
@@ -133,7 +133,7 @@ def test_discover_tests_pytest_with_temp_dir_root():
assert len(discovered_tests) == 1
assert len(discovered_tests["dummy_code.dummy_function"]) == 2
dummy_tests = discovered_tests["dummy_code.dummy_function"]
assert all(test.tests_in_file.test_file == test_file_path for test in dummy_tests)
assert all(test.tests_in_file.test_file.resolve() == test_file_path.resolve() for test in dummy_tests)
assert {test.tests_in_file.test_function for test in dummy_tests} == {
"test_dummy_parametrized_function[True]",
"test_dummy_function",

@@ -204,16 +204,13 @@ def test_discover_tests_pytest_with_multi_level_dirs():

# Check if the test files at all levels are discovered
assert len(discovered_tests) == 3
assert next(iter(discovered_tests["root_code.root_function"])).tests_in_file.test_file == root_test_file_path
assert (
next(iter(discovered_tests["level1.level1_code.level1_function"])).tests_in_file.test_file
== level1_test_file_path
)
discovered_root_test = next(iter(discovered_tests["root_code.root_function"])).tests_in_file.test_file
assert discovered_root_test.resolve() == root_test_file_path.resolve()
discovered_level1_test = next(iter(discovered_tests["level1.level1_code.level1_function"])).tests_in_file.test_file
assert discovered_level1_test.resolve() == level1_test_file_path.resolve()

assert (
next(iter(discovered_tests["level1.level2.level2_code.level2_function"])).tests_in_file.test_file
== level2_test_file_path
)
discovered_level2_test = next(iter(discovered_tests["level1.level2.level2_code.level2_function"])).tests_in_file.test_file
assert discovered_level2_test.resolve() == level2_test_file_path.resolve()


def test_discover_tests_pytest_dirs():

@@ -295,20 +292,15 @@ def test_discover_tests_pytest_dirs():

# Check if the test files at all levels are discovered
assert len(discovered_tests) == 4
assert next(iter(discovered_tests["root_code.root_function"])).tests_in_file.test_file == root_test_file_path
assert (
next(iter(discovered_tests["level1.level1_code.level1_function"])).tests_in_file.test_file
== level1_test_file_path
)
assert (
next(iter(discovered_tests["level1.level2.level2_code.level2_function"])).tests_in_file.test_file
== level2_test_file_path
)
discovered_root_test = next(iter(discovered_tests["root_code.root_function"])).tests_in_file.test_file
assert discovered_root_test.resolve() == root_test_file_path.resolve()
discovered_level1_test = next(iter(discovered_tests["level1.level1_code.level1_function"])).tests_in_file.test_file
assert discovered_level1_test.resolve() == level1_test_file_path.resolve()
discovered_level2_test = next(iter(discovered_tests["level1.level2.level2_code.level2_function"])).tests_in_file.test_file
assert discovered_level2_test.resolve() == level2_test_file_path.resolve()

assert (
next(iter(discovered_tests["level1.level3.level3_code.level3_function"])).tests_in_file.test_file
== level3_test_file_path
)
discovered_level3_test = next(iter(discovered_tests["level1.level3.level3_code.level3_function"])).tests_in_file.test_file
assert discovered_level3_test.resolve() == level3_test_file_path.resolve()


def test_discover_tests_pytest_with_class():

@@ -342,10 +334,8 @@ def test_discover_tests_pytest_with_class():

# Check if the test class and method are discovered
assert len(discovered_tests) == 1
assert (
next(iter(discovered_tests["some_class_code.SomeClass.some_method"])).tests_in_file.test_file
== test_file_path
)
discovered_class_test = next(iter(discovered_tests["some_class_code.SomeClass.some_method"])).tests_in_file.test_file
assert discovered_class_test.resolve() == test_file_path.resolve()


def test_discover_tests_pytest_with_double_nested_directories():

@@ -383,12 +373,10 @@ def test_discover_tests_pytest_with_double_nested_directories():

# Check if the test class and method are discovered
assert len(discovered_tests) == 1
assert (
next(
iter(discovered_tests["nested.more_nested.nested_class_code.NestedClass.nested_method"])
).tests_in_file.test_file
== test_file_path
)
discovered_nested_test = next(
iter(discovered_tests["nested.more_nested.nested_class_code.NestedClass.nested_method"])
).tests_in_file.test_file
assert discovered_nested_test.resolve() == test_file_path.resolve()


def test_discover_tests_with_code_in_dir_and_test_in_subdir():

@@ -433,7 +421,8 @@ def test_discover_tests_with_code_in_dir_and_test_in_subdir():

# Check if the test file is discovered and associated with the code file
assert len(discovered_tests) == 1
assert next(iter(discovered_tests["code.some_code.some_function"])).tests_in_file.test_file == test_file_path
discovered_test_file = next(iter(discovered_tests["code.some_code.some_function"])).tests_in_file.test_file
assert discovered_test_file.resolve() == test_file_path.resolve()


def test_discover_tests_pytest_with_nested_class():

@@ -469,10 +458,8 @@ def test_discover_tests_pytest_with_nested_class():

# Check if the test for the nested class method is discovered
assert len(discovered_tests) == 1
assert (
next(iter(discovered_tests["nested_class_code.OuterClass.InnerClass.inner_method"])).tests_in_file.test_file
== test_file_path
)
discovered_inner_test = next(iter(discovered_tests["nested_class_code.OuterClass.InnerClass.inner_method"])).tests_in_file.test_file
assert discovered_inner_test.resolve() == test_file_path.resolve()


def test_discover_tests_pytest_separate_moduledir():

@@ -509,7 +496,8 @@ def test_discover_tests_pytest_separate_moduledir():

# Check if the test for the nested class method is discovered
assert len(discovered_tests) == 1
assert next(iter(discovered_tests["mypackage.code.find_common_tags"])).tests_in_file.test_file == test_file_path
discovered_test_file = next(iter(discovered_tests["mypackage.code.find_common_tags"])).tests_in_file.test_file
assert discovered_test_file.resolve() == test_file_path.resolve()


def test_unittest_discovery_with_pytest():

@@ -554,7 +542,7 @@ class TestCalculator(unittest.TestCase):
assert "calculator.Calculator.add" in discovered_tests
assert len(discovered_tests["calculator.Calculator.add"]) == 1
calculator_test = next(iter(discovered_tests["calculator.Calculator.add"]))
assert calculator_test.tests_in_file.test_file == test_file_path
assert calculator_test.tests_in_file.test_file.resolve() == test_file_path.resolve()
assert calculator_test.tests_in_file.test_function == "test_add"

@@ -622,7 +610,7 @@ class TestCalculator(ExtendedTestCase):
assert "calculator.Calculator.add" in discovered_tests
assert len(discovered_tests["calculator.Calculator.add"]) == 1
calculator_test = next(iter(discovered_tests["calculator.Calculator.add"]))
assert calculator_test.tests_in_file.test_file == test_file_path
assert calculator_test.tests_in_file.test_file.resolve() == test_file_path.resolve()
assert calculator_test.tests_in_file.test_function == "test_add"

@@ -720,7 +708,7 @@ class TestCalculator(unittest.TestCase):
assert "calculator.Calculator.add" in discovered_tests
assert len(discovered_tests["calculator.Calculator.add"]) == 1
calculator_test = next(iter(discovered_tests["calculator.Calculator.add"]))
assert calculator_test.tests_in_file.test_file == test_file_path
assert calculator_test.tests_in_file.test_file.resolve() == test_file_path.resolve()
assert calculator_test.tests_in_file.test_function == "test_add_with_parameters"
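Every discovery assertion above now compares `.resolve()`d paths. This guards against symlinked temp directories: on macOS, `/tmp` is a symlink to `/private/tmp`, so the discovered path and the constructed path can differ textually while naming the same file. A small illustration:

```python
# Compare canonical paths instead of raw strings so the assertion holds
# even when the temp directory sits behind a symlink.
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as d:
    constructed = Path(d) / "test_file.py"   # what the test built
    constructed.write_text("x = 1")
    discovered = constructed.resolve()        # what a scanner of real paths reports
    # On macOS str(constructed) may start with /tmp while discovered starts
    # with /private/tmp; resolving both sides makes the comparison robust.
    assert discovered == constructed.resolve()
```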
@ -9,6 +9,8 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
|||
from codeflash.models.models import CodeStringsMarkdown
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
from codeflash.context.unused_definition_remover import revert_unused_helper_functions
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -225,6 +227,15 @@ def helper_function_2(x):
|
|||
def test_no_unused_helpers_no_revert(temp_project):
|
||||
"""Test that when all helpers are still used, nothing is reverted."""
|
||||
temp_dir, main_file, test_cfg = temp_project
|
||||
|
||||
|
||||
# Store original content to verify nothing changes
|
||||
original_content = main_file.read_text()
|
||||
|
||||
revert_unused_helper_functions(temp_dir, [], {})
|
||||
|
||||
# Verify the file content remains unchanged
|
||||
assert main_file.read_text() == original_content, "File should remain unchanged when no helpers to revert"
|
||||
|
||||
# Optimized version that still calls both helpers
|
||||
optimized_code = """
|
||||
|
|
@ -308,17 +319,23 @@ def helper_function_1(x):
|
|||
def helper_function_2(x):
|
||||
\"\"\"Second helper function.\"\"\"
|
||||
return x * 3
|
||||
|
||||
def helper_function_1(y): # Duplicate name to test line 575
|
||||
\"\"\"Overloaded helper function.\"\"\"
|
||||
return y + 10
|
||||
""")
|
||||
|
||||
# Optimized version that only calls one helper
|
||||
# Optimized version that only calls one helper with aliased import
|
||||
optimized_code = """
|
||||
```python:main.py
|
||||
from helpers import helper_function_1
|
||||
from helpers import helper_function_1 as h1
|
||||
import helpers as h_module
|
||||
|
||||
def entrypoint_function(n):
|
||||
\"\"\"Optimized function that only calls one helper.\"\"\"
|
||||
result1 = helper_function_1(n)
|
||||
return result1 + n * 3 # Inlined helper_function_2
|
||||
\"\"\"Optimized function that only calls one helper with aliasing.\"\"\"
|
||||
result1 = h1(n) # Using aliased import
|
||||
# Inlined helper_function_2 functionality: n * 3
|
||||
return result1 + n * 3 # Fully inlined helper_function_2
|
||||
```
|
||||
"""
|
||||
|
||||
|
|
@ -1460,3 +1477,595 @@ class MathUtils:
|
|||
import shutil
|
||||
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def test_async_entrypoint_with_async_helpers():
|
||||
"""Test that unused async helper functions are correctly detected when entrypoint is async."""
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
|
||||
try:
|
||||
# Main file with async entrypoint and async helpers
|
||||
main_file = temp_dir / "main.py"
|
||||
main_file.write_text("""
|
||||
async def async_helper_1(x):
|
||||
\"\"\"First async helper function.\"\"\"
|
||||
return x * 2
|
||||
|
||||
async def async_helper_2(x):
|
||||
\"\"\"Second async helper function.\"\"\"
|
||||
return x * 3
|
||||
|
||||
async def async_entrypoint(n):
|
||||
\"\"\"Async entrypoint function that calls async helpers.\"\"\"
|
||||
result1 = await async_helper_1(n)
|
||||
result2 = await async_helper_2(n)
|
||||
return result1 + result2
|
||||
""")
|
||||
|
||||
# Optimized version that only calls one async helper
|
||||
optimized_code = """
|
||||
```python:main.py
|
||||
async def async_helper_1(x):
|
||||
\"\"\"First async helper function.\"\"\"
|
||||
return x * 2
|
||||
|
||||
async def async_helper_2(x):
|
||||
\"\"\"Second async helper function - should be unused.\"\"\"
|
||||
return x * 3
|
||||
|
||||
async def async_entrypoint(n):
|
||||
\"\"\"Optimized async entrypoint that only calls one helper.\"\"\"
|
||||
result1 = await async_helper_1(n)
|
||||
return result1 + n * 3 # Inlined async_helper_2
|
||||
```
|
||||
"""
|
||||
|
||||
# Create test config
|
||||
test_cfg = TestConfig(
|
||||
tests_root=temp_dir / "tests",
|
||||
tests_project_rootdir=temp_dir,
|
||||
project_root_path=temp_dir,
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
|
||||
# Create FunctionToOptimize instance for async function
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
file_path=main_file,
|
||||
function_name="async_entrypoint",
|
||||
parents=[],
|
||||
is_async=True
|
||||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
)
|
||||
|
||||
# Get original code context
|
||||
ctx_result = optimizer.get_code_optimization_context()
|
||||
assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"
|
||||
|
||||
code_context = ctx_result.unwrap()
|
||||
|
||||
# Test unused helper detection
|
||||
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))
|
||||
|
||||
# Should detect async_helper_2 as unused
|
||||
unused_names = {uh.qualified_name for uh in unused_helpers}
|
||||
expected_unused = {"async_helper_2"}
|
||||
|
||||
assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
import shutil
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def test_sync_entrypoint_with_async_helpers():
|
||||
"""Test that unused async helper functions are detected when entrypoint is sync."""
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
|
||||
try:
|
||||
# Main file with sync entrypoint and async helpers
|
||||
main_file = temp_dir / "main.py"
|
||||
main_file.write_text("""
|
||||
import asyncio
|
||||
|
||||
async def async_helper_1(x):
|
||||
\"\"\"First async helper function.\"\"\"
|
||||
return x * 2
|
||||
|
||||
async def async_helper_2(x):
|
||||
\"\"\"Second async helper function.\"\"\"
|
||||
return x * 3
|
||||
|
||||
def sync_entrypoint(n):
|
||||
\"\"\"Sync entrypoint function that calls async helpers.\"\"\"
|
||||
result1 = asyncio.run(async_helper_1(n))
|
||||
result2 = asyncio.run(async_helper_2(n))
|
||||
return result1 + result2
|
||||
""")
|
||||
|
||||
# Optimized version that only calls one async helper
|
||||
optimized_code = """
|
||||
```python:main.py
|
||||
import asyncio
|
||||
|
||||
async def async_helper_1(x):
|
||||
\"\"\"First async helper function.\"\"\"
|
||||
return x * 2
|
||||
|
||||
async def async_helper_2(x):
|
||||
\"\"\"Second async helper function - should be unused.\"\"\"
|
||||
return x * 3
|
||||
|
||||
def sync_entrypoint(n):
|
||||
\"\"\"Optimized sync entrypoint that only calls one async helper.\"\"\"
|
||||
result1 = asyncio.run(async_helper_1(n))
|
||||
return result1 + n * 3 # Inlined async_helper_2
|
||||
```
|
||||
"""
|
||||
|
||||
# Create test config
|
||||
test_cfg = TestConfig(
|
||||
tests_root=temp_dir / "tests",
|
||||
tests_project_rootdir=temp_dir,
|
||||
project_root_path=temp_dir,
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
|
||||
# Create FunctionToOptimize instance for sync function
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
file_path=main_file,
|
||||
function_name="sync_entrypoint",
|
||||
parents=[]
|
||||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
)
|
||||
|
||||
# Get original code context
|
||||
ctx_result = optimizer.get_code_optimization_context()
|
||||
assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"
|
||||
|
||||
code_context = ctx_result.unwrap()
|
||||
|
||||
# Test unused helper detection
|
||||
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))
|
||||
|
||||
# Should detect async_helper_2 as unused
|
||||
unused_names = {uh.qualified_name for uh in unused_helpers}
|
||||
expected_unused = {"async_helper_2"}
|
||||
|
||||
assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
import shutil
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def test_mixed_sync_and_async_helpers():
|
||||
"""Test detection when both sync and async helpers are mixed."""
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
|
||||
try:
|
||||
# Main file with mixed sync and async helpers
|
||||
main_file = temp_dir / "main.py"
|
||||
main_file.write_text("""
|
||||
import asyncio
|
||||
|
||||
def sync_helper_1(x):
|
||||
\"\"\"Sync helper function.\"\"\"
|
||||
return x * 2
|
||||
|
||||
async def async_helper_1(x):
|
||||
\"\"\"Async helper function.\"\"\"
|
||||
return x * 3
|
||||
|
||||
def sync_helper_2(x):
|
||||
\"\"\"Another sync helper function.\"\"\"
|
||||
return x * 4
|
||||
|
||||
async def async_helper_2(x):
|
||||
\"\"\"Another async helper function.\"\"\"
|
||||
return x * 5
|
||||
|
||||
async def mixed_entrypoint(n):
|
||||
\"\"\"Async entrypoint function that calls both sync and async helpers.\"\"\"
|
||||
sync_result = sync_helper_1(n)
|
||||
async_result = await async_helper_1(n)
|
||||
sync_result2 = sync_helper_2(n)
|
||||
async_result2 = await async_helper_2(n)
|
||||
return sync_result + async_result + sync_result2 + async_result2
|
||||
""")
|
||||
|
||||
# Optimized version that only calls some helpers
|
||||
optimized_code = """
|
||||
```python:main.py
|
||||
import asyncio
|
||||
|
||||
def sync_helper_1(x):
|
||||
\"\"\"Sync helper function.\"\"\"
|
||||
return x * 2
|
||||
|
||||
async def async_helper_1(x):
|
||||
\"\"\"Async helper function.\"\"\"
|
||||
return x * 3
|
||||
|
||||
def sync_helper_2(x):
|
||||
\"\"\"Another sync helper function - should be unused.\"\"\"
|
||||
return x * 4
|
||||
|
||||
async def async_helper_2(x):
|
||||
\"\"\"Another async helper function - should be unused.\"\"\"
|
||||
return x * 5
|
||||
|
||||
async def mixed_entrypoint(n):
|
||||
\"\"\"Optimized async entrypoint that only calls some helpers.\"\"\"
|
||||
sync_result = sync_helper_1(n)
|
||||
async_result = await async_helper_1(n)
|
||||
return sync_result + async_result + n * 4 + n * 5 # Inlined both helper_2 functions
|
||||
```
|
||||
"""
        # Create test config
        test_cfg = TestConfig(
            tests_root=temp_dir / "tests",
            tests_project_rootdir=temp_dir,
            project_root_path=temp_dir,
            test_framework="pytest",
            pytest_cmd="pytest",
        )

        # Create FunctionToOptimize instance for async function
        function_to_optimize = FunctionToOptimize(
            file_path=main_file,
            function_name="mixed_entrypoint",
            parents=[],
            is_async=True,
        )

        # Create function optimizer
        optimizer = FunctionOptimizer(
            function_to_optimize=function_to_optimize,
            test_cfg=test_cfg,
            function_to_optimize_source_code=main_file.read_text(),
        )

        # Get original code context
        ctx_result = optimizer.get_code_optimization_context()
        assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"

        code_context = ctx_result.unwrap()

        # Test unused helper detection
        unused_helpers = detect_unused_helper_functions(
            optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)
        )

        # Should detect both sync_helper_2 and async_helper_2 as unused
        unused_names = {uh.qualified_name for uh in unused_helpers}
        expected_unused = {"sync_helper_2", "async_helper_2"}

        assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"

    finally:
        # Cleanup
        import shutil

        shutil.rmtree(temp_dir, ignore_errors=True)
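
# Illustrative sketch (not part of the test suite): the inlining asserted above is
# behavior-preserving. For n = 3 both entrypoints compute 6 + 9 + 12 + 15 == 42:
#
#   import asyncio
#
#   async def original(n):
#       return n * 2 + n * 3 + n * 4 + n * 5   # all four helpers applied
#
#   async def optimized(n):
#       return n * 2 + n * 3 + n * 4 + n * 5   # helper_2 calls inlined as n * 4 and n * 5
#
#   assert asyncio.run(original(3)) == asyncio.run(optimized(3)) == 42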


def test_async_class_methods():
    """Test unused async method detection in classes."""
    temp_dir = Path(tempfile.mkdtemp())

    try:
        # Main file with class containing async methods
        main_file = temp_dir / "main.py"
        main_file.write_text("""
class AsyncProcessor:
    async def entrypoint_method(self, n):
        \"\"\"Async main method that calls async helper methods.\"\"\"
        result1 = await self.async_helper_method_1(n)
        result2 = await self.async_helper_method_2(n)
        return result1 + result2

    async def async_helper_method_1(self, x):
        \"\"\"First async helper method.\"\"\"
        return x * 2

    async def async_helper_method_2(self, x):
        \"\"\"Second async helper method.\"\"\"
        return x * 3

    def sync_helper_method(self, x):
        \"\"\"Sync helper method.\"\"\"
        return x * 4
""")

        # Optimized version that only calls one async helper
        optimized_code = """
```python:main.py
class AsyncProcessor:
    async def entrypoint_method(self, n):
        \"\"\"Optimized async method that only calls one helper.\"\"\"
        result1 = await self.async_helper_method_1(n)
        return result1 + n * 3  # Inlined async_helper_method_2

    async def async_helper_method_1(self, x):
        \"\"\"First async helper method.\"\"\"
        return x * 2

    async def async_helper_method_2(self, x):
        \"\"\"Second async helper method - should be unused.\"\"\"
        return x * 3

    def sync_helper_method(self, x):
        \"\"\"Sync helper method - should be unused.\"\"\"
        return x * 4
```
"""

        # Create test config
        test_cfg = TestConfig(
            tests_root=temp_dir / "tests",
            tests_project_rootdir=temp_dir,
            project_root_path=temp_dir,
            test_framework="pytest",
            pytest_cmd="pytest",
        )

        # Create FunctionToOptimize instance for async class method
        from codeflash.models.models import FunctionParent

        function_to_optimize = FunctionToOptimize(
            file_path=main_file,
            function_name="entrypoint_method",
            parents=[FunctionParent(name="AsyncProcessor", type="ClassDef")],
            is_async=True,
        )

        # Create function optimizer
        optimizer = FunctionOptimizer(
            function_to_optimize=function_to_optimize,
            test_cfg=test_cfg,
            function_to_optimize_source_code=main_file.read_text(),
        )

        # Get original code context
        ctx_result = optimizer.get_code_optimization_context()
        assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"

        code_context = ctx_result.unwrap()

        # Test unused helper detection
        unused_helpers = detect_unused_helper_functions(
            optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)
        )

        # Should detect async_helper_method_2 as unused (sync_helper_method may not be discovered as helper)
        unused_names = {uh.qualified_name for uh in unused_helpers}
        expected_unused = {"AsyncProcessor.async_helper_method_2"}

        assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"

    finally:
        # Cleanup
        import shutil

        shutil.rmtree(temp_dir, ignore_errors=True)
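
# Note (illustrative): unused methods are reported by qualified name. For a method
# whose parent is a ClassDef that is "<ClassName>.<method_name>", hence
# "AsyncProcessor.async_helper_method_2" above; top-level functions use the bare name.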


def test_async_helper_revert_functionality():
    """Test that unused async helper functions are correctly reverted to original definitions."""
    temp_dir = Path(tempfile.mkdtemp())

    try:
        # Main file with async functions
        main_file = temp_dir / "main.py"
        main_file.write_text("""
async def async_helper_1(x):
    \"\"\"First async helper function.\"\"\"
    return x * 2

async def async_helper_2(x):
    \"\"\"Second async helper function.\"\"\"
    return x * 3

async def async_entrypoint(n):
    \"\"\"Async entrypoint function that calls async helpers.\"\"\"
    result1 = await async_helper_1(n)
    result2 = await async_helper_2(n)
    return result1 + result2
""")

        # Optimized version that only calls one helper and modifies the unused one
        optimized_code = """
```python:main.py
async def async_helper_1(x):
    \"\"\"First async helper function.\"\"\"
    return x * 2

async def async_helper_2(x):
    \"\"\"Modified async helper function - should be reverted.\"\"\"
    return x * 10  # This change should be reverted

async def async_entrypoint(n):
    \"\"\"Optimized async entrypoint that only calls one helper.\"\"\"
    result1 = await async_helper_1(n)
    return result1 + n * 3  # Inlined async_helper_2
```
"""

        # Create test config
        test_cfg = TestConfig(
            tests_root=temp_dir / "tests",
            tests_project_rootdir=temp_dir,
            project_root_path=temp_dir,
            test_framework="pytest",
            pytest_cmd="pytest",
        )

        # Create FunctionToOptimize instance for async function
        function_to_optimize = FunctionToOptimize(
            file_path=main_file,
            function_name="async_entrypoint",
            parents=[],
            is_async=True,
        )

        # Create function optimizer
        optimizer = FunctionOptimizer(
            function_to_optimize=function_to_optimize,
            test_cfg=test_cfg,
            function_to_optimize_source_code=main_file.read_text(),
        )

        # Get original code context
        ctx_result = optimizer.get_code_optimization_context()
        assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"

        code_context = ctx_result.unwrap()

        # Store original helper code
        original_helper_code = {main_file: main_file.read_text()}

        # Apply optimization and test reversion
        optimizer.replace_function_and_helpers_with_optimized_code(
            code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code
        )

        # Check final file content
        final_content = main_file.read_text()

        # The entrypoint should be optimized
        assert "result1 + n * 3" in final_content, "Async entrypoint function should be optimized"

        # async_helper_2 should be reverted to original (return x * 3, not x * 10)
        assert "return x * 3" in final_content, "async_helper_2 should be reverted to original"
        assert "return x * 10" not in final_content, "async_helper_2 should not contain the modified version"

        # async_helper_1 should remain (it's still called)
        assert "async def async_helper_1(x):" in final_content, "async_helper_1 should still exist"

    finally:
        # Cleanup
        import shutil

        shutil.rmtree(temp_dir, ignore_errors=True)
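
# Minimal sketch of the revert flow exercised above (hypothetical helper name; the
# real logic lives in FunctionOptimizer and detect_unused_helper_functions):
#
#   original_helper_code = {path: path.read_text()}   # snapshot before any edits
#   apply_optimized_code(...)                         # rewrite entrypoint and helpers
#   for helper in unused_helpers:                     # restore only the unused defs
#       restore_original_definition(helper, original_helper_code[helper.file_path])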


def test_async_generators_and_coroutines():
    """Test detection with async generators and coroutines."""
    temp_dir = Path(tempfile.mkdtemp())

    try:
        # Main file with async generators and coroutines
        main_file = temp_dir / "main.py"
        main_file.write_text("""
import asyncio

async def async_generator_helper(n):
    \"\"\"Async generator helper.\"\"\"
    for i in range(n):
        yield i * 2

async def coroutine_helper(x):
    \"\"\"Coroutine helper.\"\"\"
    await asyncio.sleep(0.1)
    return x * 3

async def another_coroutine_helper(x):
    \"\"\"Another coroutine helper.\"\"\"
    await asyncio.sleep(0.1)
    return x * 4

async def async_entrypoint_with_generators(n):
    \"\"\"Async entrypoint function that uses generators and coroutines.\"\"\"
    results = []
    async for value in async_generator_helper(n):
        results.append(value)

    final_result = await coroutine_helper(sum(results))
    another_result = await another_coroutine_helper(n)
    return final_result + another_result
""")

        # Optimized version that doesn't use one of the coroutines
        optimized_code = """
```python:main.py
import asyncio

async def async_generator_helper(n):
    \"\"\"Async generator helper.\"\"\"
    for i in range(n):
        yield i * 2

async def coroutine_helper(x):
    \"\"\"Coroutine helper.\"\"\"
    await asyncio.sleep(0.1)
    return x * 3

async def another_coroutine_helper(x):
    \"\"\"Another coroutine helper - should be unused.\"\"\"
    await asyncio.sleep(0.1)
    return x * 4

async def async_entrypoint_with_generators(n):
    \"\"\"Optimized async entrypoint that inlines one coroutine.\"\"\"
    results = []
    async for value in async_generator_helper(n):
        results.append(value)

    final_result = await coroutine_helper(sum(results))
    return final_result + n * 4  # Inlined another_coroutine_helper
```
"""

        # Create test config
        test_cfg = TestConfig(
            tests_root=temp_dir / "tests",
            tests_project_rootdir=temp_dir,
            project_root_path=temp_dir,
            test_framework="pytest",
            pytest_cmd="pytest",
        )

        # Create FunctionToOptimize instance for async function
        function_to_optimize = FunctionToOptimize(
            file_path=main_file,
            function_name="async_entrypoint_with_generators",
            parents=[],
            is_async=True,
        )

        # Create function optimizer
        optimizer = FunctionOptimizer(
            function_to_optimize=function_to_optimize,
            test_cfg=test_cfg,
            function_to_optimize_source_code=main_file.read_text(),
        )

        # Get original code context
        ctx_result = optimizer.get_code_optimization_context()
        assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"

        code_context = ctx_result.unwrap()

        # Test unused helper detection
        unused_helpers = detect_unused_helper_functions(
            optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)
        )

        # Should detect another_coroutine_helper as unused
        unused_names = {uh.qualified_name for uh in unused_helpers}
        expected_unused = {"another_coroutine_helper"}

        assert unused_names == expected_unused, f"Expected unused: {expected_unused}, got: {unused_names}"

    finally:
        # Cleanup
        import shutil

        shutil.rmtree(temp_dir, ignore_errors=True)
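
# Note (illustrative): "async for value in async_generator_helper(n)" still counts as
# a call site, so the generator is treated as used; only another_coroutine_helper,
# whose awaited call disappears from the optimized code, is expected to be flagged.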

65
tests/test_worktree.py
Normal file
@ -0,0 +1,65 @@
from argparse import Namespace
from pathlib import Path

import pytest
from codeflash.cli_cmds.cli import process_pyproject_config
from codeflash.optimization.optimizer import Optimizer


def test_mirror_paths_for_worktree_mode(monkeypatch: pytest.MonkeyPatch):
    repo_root = Path(__file__).resolve().parent.parent
    project_root = repo_root / "code_to_optimize" / "code_directories" / "nested_module_root"

    monkeypatch.setattr("codeflash.optimization.optimizer.git_root_dir", lambda: project_root)

    args = Namespace()
    args.benchmark = False
    args.benchmarks_root = None

    args.config_file = project_root / "pyproject.toml"
    args.file = project_root / "src" / "app" / "main.py"
    args.worktree = True

    new_args = process_pyproject_config(args)

    optimizer = Optimizer(new_args)

    worktree_dir = repo_root / "worktree"
    optimizer.mirror_paths_for_worktree_mode(worktree_dir)

    assert optimizer.args.project_root == worktree_dir / "src"
    assert optimizer.args.test_project_root == worktree_dir / "src"
    assert optimizer.args.module_root == worktree_dir / "src" / "app"
    assert optimizer.args.tests_root == worktree_dir / "src" / "tests"
    assert optimizer.args.file == worktree_dir / "src" / "app" / "main.py"

    assert optimizer.test_cfg.tests_root == worktree_dir / "src" / "tests"
    assert optimizer.test_cfg.project_root_path == worktree_dir / "src"  # same as project_root
    assert optimizer.test_cfg.tests_project_rootdir == worktree_dir / "src"  # same as test_project_root

    # test on our repo
    monkeypatch.setattr("codeflash.optimization.optimizer.git_root_dir", lambda: repo_root)
    args = Namespace()
    args.benchmark = False
    args.benchmarks_root = None

    args.config_file = repo_root / "pyproject.toml"
    args.file = repo_root / "codeflash/optimization/optimizer.py"
    args.worktree = True

    new_args = process_pyproject_config(args)

    optimizer = Optimizer(new_args)

    worktree_dir = repo_root / "worktree"
    optimizer.mirror_paths_for_worktree_mode(worktree_dir)

    assert optimizer.args.project_root == worktree_dir
    assert optimizer.args.test_project_root == worktree_dir
    assert optimizer.args.module_root == worktree_dir / "codeflash"
    assert optimizer.args.tests_root == worktree_dir / "tests"
    assert optimizer.args.file == worktree_dir / "codeflash/optimization/optimizer.py"

    assert optimizer.test_cfg.tests_root == worktree_dir / "tests"
    assert optimizer.test_cfg.project_root_path == worktree_dir  # same as project_root
    assert optimizer.test_cfg.tests_project_rootdir == worktree_dir  # same as test_project_root
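
# Illustrative sketch of the mirroring rule the assertions above imply (an assumption,
# not taken from the implementation): each configured path is re-rooted from the git
# root onto the worktree checkout, preserving its relative layout:
#
#   def mirror(path: Path, git_root: Path, worktree_dir: Path) -> Path:
#       return worktree_dir / path.relative_to(git_root)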