Merge branch 'main' of github.com:codeflash-ai/codeflash into multi-language

This commit is contained in:
ali 2026-01-23 14:46:50 +02:00
commit 3ca29563b7
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
18 changed files with 1445 additions and 46 deletions

View file

@ -531,6 +531,10 @@ class AiServiceClient:
optimized_throughput: str | None = None,
throughput_improvement: str | None = None,
function_references: str | None = None,
acceptance_reason: str | None = None,
original_concurrency_ratio: str | None = None,
optimized_concurrency_ratio: str | None = None,
concurrency_improvement: str | None = None,
codeflash_version: str = codeflash_version,
) -> str:
"""Optimize the given python code for performance by making a request to the Django endpoint.
@ -551,8 +555,12 @@ class AiServiceClient:
- original_throughput: str | None - throughput for the baseline code (operations per second)
- optimized_throughput: str | None - throughput for the optimized code (operations per second)
- throughput_improvement: str | None - throughput improvement percentage
- current codeflash version
- function_references: str | None - where the function is called in the codebase
- acceptance_reason: str | None - why the optimization was accepted (runtime, throughput, or concurrency)
- original_concurrency_ratio: str | None - concurrency ratio for the baseline code
- optimized_concurrency_ratio: str | None - concurrency ratio for the optimized code
- concurrency_improvement: str | None - concurrency improvement percentage
- codeflash_version: str - current codeflash version
Returns
-------
@ -576,6 +584,10 @@ class AiServiceClient:
"optimized_throughput": optimized_throughput,
"throughput_improvement": throughput_improvement,
"function_references": function_references,
"acceptance_reason": acceptance_reason,
"original_concurrency_ratio": original_concurrency_ratio,
"optimized_concurrency_ratio": optimized_concurrency_ratio,
"concurrency_improvement": concurrency_improvement,
"codeflash_version": codeflash_version,
"call_sequence": self.get_next_sequence(),
}

View file

@ -4,6 +4,7 @@ import asyncio
import gc
import os
import sqlite3
import time
from enum import Enum
from functools import wraps
from pathlib import Path
@ -165,3 +166,45 @@ def codeflash_performance_async(func: F) -> F:
return return_value
return async_wrapper
def codeflash_concurrency_async(func: F) -> F:
"""Measures concurrent vs sequential execution performance for async functions."""
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
function_name = func.__name__
concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10"))
test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "")
test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "")
loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0")
# Phase 1: Sequential execution timing
gc.disable()
try:
seq_start = time.perf_counter_ns()
for _ in range(concurrency_factor):
result = await func(*args, **kwargs)
sequential_time = time.perf_counter_ns() - seq_start
finally:
gc.enable()
# Phase 2: Concurrent execution timing
gc.disable()
try:
conc_start = time.perf_counter_ns()
tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
await asyncio.gather(*tasks)
concurrent_time = time.perf_counter_ns() - conc_start
finally:
gc.enable()
# Output parseable metrics
tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}"
print(f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!")
return result
return async_wrapper
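For reference, the decorator above runs the wrapped coroutine N times sequentially, then N times concurrently via asyncio.gather, and prints a single parseable marker line. A minimal sketch of its output (not part of this commit), assuming the CODEFLASH_* environment variables are left unset so the module/class/test fields are empty, the loop index is "0", and the default factor of 10 applies; the timings are illustrative only:
import asyncio

@codeflash_concurrency_async
async def fetch_data() -> str:
    await asyncio.sleep(0.01)  # non-blocking, so the concurrent runs overlap
    return "ok"

asyncio.run(fetch_data())
# Prints roughly (10 sequential sleeps ~100ms, 10 overlapped sleeps ~10ms):
# !@######CONC::::fetch_data:0:100000000:10000000:10######@!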

View file

@ -10,12 +10,20 @@ import sentry_sdk
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_temp_dir
# Known CrossHair limitations that produce invalid Python syntax in generated tests:
# - "<locals>" - higher-order functions returning nested functions
# - " object at 0x" - objects with default __repr__
# - "<list_iterator" - iterator objects
CROSSHAIR_KNOWN_LIMITATION_PATTERNS = ("<locals>", " object at 0x", "<list_iterator")
def is_valid_concolic_test(test_code: str, project_root: Optional[str] = None) -> bool:
try:
ast.parse(test_code)
except SyntaxError:
sentry_sdk.capture_message(f"CrossHair generated test with syntax error:\n{test_code}")
is_known_limitation = any(pattern in test_code for pattern in CROSSHAIR_KNOWN_LIMITATION_PATTERNS)
if not is_known_limitation:
sentry_sdk.capture_message(f"CrossHair generated test with syntax error:\n{test_code}")
return False
temp_path = (codeflash_temp_dir / f"concolic_test_{uuid.uuid4().hex}.py").resolve()

View file

@ -10,6 +10,8 @@ INDIVIDUAL_TESTCASE_TIMEOUT = 15
MAX_FUNCTION_TEST_SECONDS = 60
MIN_IMPROVEMENT_THRESHOLD = 0.05
MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10 # 10% minimum improvement for async throughput
MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD = 0.20 # 20% concurrency ratio improvement required
CONCURRENCY_FACTOR = 10 # Number of concurrent executions for concurrency benchmark
MAX_TEST_FUNCTION_RUNS = 50
MAX_CUMULATIVE_TEST_RUNTIME_NANOSECONDS = 100e6 # 100ms
TOTAL_LOOPING_TIME = 10.0 # 10 second candidate benchmarking budget

View file

@ -1439,9 +1439,12 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
self.added_decorator = False
# Choose decorator based on mode
self.decorator_name = (
"codeflash_behavior_async" if mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
)
if mode == TestingMode.BEHAVIOR:
self.decorator_name = "codeflash_behavior_async"
elif mode == TestingMode.CONCURRENCY:
self.decorator_name = "codeflash_concurrency_async"
else:
self.decorator_name = "codeflash_performance_async"
def visit_ClassDef(self, node: cst.ClassDef) -> None:
# Track when we enter a class
@ -1484,12 +1487,14 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
"codeflash_trace_async",
"codeflash_behavior_async",
"codeflash_performance_async",
"codeflash_concurrency_async",
}
if isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name):
return decorator_node.func.value in {
"codeflash_trace_async",
"codeflash_behavior_async",
"codeflash_performance_async",
"codeflash_concurrency_async",
}
return False
@ -1501,6 +1506,14 @@ class AsyncDecoratorImportAdder(cst.CSTTransformer):
self.mode = mode
self.has_import = False
def _get_decorator_name(self) -> str:
"""Get the decorator name based on the testing mode."""
if self.mode == TestingMode.BEHAVIOR:
return "codeflash_behavior_async"
if self.mode == TestingMode.CONCURRENCY:
return "codeflash_concurrency_async"
return "codeflash_performance_async"
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
# Check if the async decorator import is already present
if (
@ -1512,9 +1525,7 @@ class AsyncDecoratorImportAdder(cst.CSTTransformer):
and node.module.attr.value == "codeflash_wrap_decorator"
and not isinstance(node.names, cst.ImportStar)
):
decorator_name = (
"codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
)
decorator_name = self._get_decorator_name()
for import_alias in node.names:
if import_alias.name.value == decorator_name:
self.has_import = True
@ -1525,9 +1536,7 @@ class AsyncDecoratorImportAdder(cst.CSTTransformer):
return updated_node
# Choose import based on mode
decorator_name = (
"codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
)
decorator_name = self._get_decorator_name()
# Parse the import statement into a CST node
import_node = cst.parse_statement(f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}")

View file

@ -173,6 +173,7 @@ class BestOptimization(BaseModel):
winning_replay_benchmarking_test_results: Optional[TestResults] = None
line_profiler_test_results: dict
async_throughput: Optional[int] = None
concurrency_metrics: Optional[ConcurrencyMetrics] = None
@dataclass(frozen=True)
@ -184,6 +185,14 @@ class BenchmarkKey:
return f"{self.module_path}::{self.function_name}"
@dataclass
class ConcurrencyMetrics:
sequential_time_ns: int
concurrent_time_ns: int
concurrency_factor: int
concurrency_ratio: float # sequential_time / concurrent_time
@dataclass
class BenchmarkDetail:
benchmark_name: str
@ -381,6 +390,7 @@ class OptimizedCandidateResult(BaseModel):
optimization_candidate_index: int
total_candidate_timing: int
async_throughput: Optional[int] = None
concurrency_metrics: Optional[ConcurrencyMetrics] = None
class GeneratedTests(BaseModel):
@ -602,6 +612,7 @@ class OriginalCodeBaseline(BaseModel):
runtime: int
coverage_results: Optional[CoverageData]
async_throughput: Optional[int] = None
concurrency_metrics: Optional[ConcurrencyMetrics] = None
class CoverageStatus(Enum):
@ -693,6 +704,7 @@ class TestingMode(enum.Enum):
BEHAVIOR = "behavior"
PERFORMANCE = "performance"
LINE_PROFILE = "line_profile"
CONCURRENCY = "concurrency"
# TODO this class is duplicated in codeflash_capture

View file

@ -103,7 +103,9 @@ from codeflash.models.models import (
)
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
from codeflash.result.critic import (
concurrency_gain,
coverage_critic,
get_acceptance_reason,
performance_gain,
quantity_of_tests_critic,
speedup_critic,
@ -115,7 +117,11 @@ from codeflash.verification.concolic_testing import generate_concolic_tests
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
from codeflash.verification.parse_test_output import calculate_function_throughput_from_test_results, parse_test_results
from codeflash.verification.parse_test_output import (
calculate_function_throughput_from_test_results,
parse_concurrency_metrics,
parse_test_results,
)
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_line_profile_tests
from codeflash.verification.verification_utils import get_test_file_path
from codeflash.verification.verifier import generate_tests
@ -128,6 +134,7 @@ if TYPE_CHECKING:
from codeflash.models.models import (
BenchmarkKey,
CodeStringsMarkdown,
ConcurrencyMetrics,
CoverageData,
FunctionCalledInTest,
FunctionSource,
@ -833,6 +840,13 @@ class FunctionOptimizer:
tree.add(f"Optimized async throughput: {candidate_result.async_throughput} executions")
tree.add(f"Throughput improvement: {throughput_gain_value * 100:.1f}%")
tree.add(f"Throughput ratio: {throughput_gain_value + 1:.3f}X")
# Display concurrency metrics if available
if candidate_result.concurrency_metrics and original_code_baseline.concurrency_metrics:
orig_ratio = original_code_baseline.concurrency_metrics.concurrency_ratio
cand_ratio = candidate_result.concurrency_metrics.concurrency_ratio
conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
tree.add(f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)")
else:
tree.add("This candidate is faster than the original code. 🚀")
tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")
@ -851,6 +865,14 @@ class FunctionOptimizer:
)
tree.add(f"Async throughput: {candidate_result.async_throughput} executions")
tree.add(f"Throughput change: {throughput_gain_value * 100:.1f}%")
# Display concurrency metrics if available
if candidate_result.concurrency_metrics and original_code_baseline.concurrency_metrics:
orig_ratio = original_code_baseline.concurrency_metrics.concurrency_ratio
cand_ratio = candidate_result.concurrency_metrics.concurrency_ratio
conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
tree.add(f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)")
tree.add(
f"(Runtime for reference: {humanize_runtime(candidate_result.best_test_runtime)} over "
f"{candidate_result.max_loop_count} loop{'s' if candidate_result.max_loop_count > 1 else ''})"
@ -917,6 +939,7 @@ class FunctionOptimizer:
winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results,
async_throughput=candidate_result.async_throughput,
concurrency_metrics=candidate_result.concurrency_metrics,
)
return best_optimization, benchmark_tree
@ -961,6 +984,7 @@ class FunctionOptimizer:
winning_benchmarking_test_results=valid_opt.winning_benchmarking_test_results,
winning_replay_benchmarking_test_results=valid_opt.winning_replay_benchmarking_test_results,
async_throughput=valid_opt.async_throughput,
concurrency_metrics=valid_opt.concurrency_metrics,
)
valid_candidates_with_shorter_code.append(new_best_opt)
diff_lens_list.append(
@ -1113,6 +1137,8 @@ class FunctionOptimizer:
best_runtime_until_now=None,
original_async_throughput=original_code_baseline.async_throughput,
best_throughput_until_now=None,
original_concurrency_metrics=original_code_baseline.concurrency_metrics,
best_concurrency_ratio_until_now=None,
) and quantity_of_tests_critic(candidate_result)
tree = self.build_runtime_info_tree(
@ -1996,6 +2022,14 @@ class FunctionOptimizer:
fto_benchmark_timings=self.function_benchmark_timings,
total_benchmark_timings=self.total_benchmark_timings,
)
acceptance_reason = get_acceptance_reason(
original_runtime_ns=original_code_baseline.runtime,
optimized_runtime_ns=best_optimization.runtime,
original_async_throughput=original_code_baseline.async_throughput,
optimized_async_throughput=best_optimization.async_throughput,
original_concurrency_metrics=original_code_baseline.concurrency_metrics,
optimized_concurrency_metrics=best_optimization.concurrency_metrics,
)
explanation = Explanation(
raw_explanation_message=best_optimization.candidate.explanation,
winning_behavior_test_results=best_optimization.winning_behavior_test_results,
@ -2007,6 +2041,9 @@ class FunctionOptimizer:
benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None,
original_async_throughput=original_code_baseline.async_throughput,
best_async_throughput=best_optimization.async_throughput,
original_concurrency_metrics=original_code_baseline.concurrency_metrics,
best_concurrency_metrics=best_optimization.concurrency_metrics,
acceptance_reason=acceptance_reason,
)
self.replace_function_and_helpers_with_optimized_code(
@ -2106,6 +2143,9 @@ class FunctionOptimizer:
original_throughput_str = None
optimized_throughput_str = None
throughput_improvement_str = None
original_concurrency_ratio_str = None
optimized_concurrency_ratio_str = None
concurrency_improvement_str = None
if (
self.function_to_optimize.is_async
@ -2120,6 +2160,14 @@ class FunctionOptimizer:
)
throughput_improvement_str = f"{throughput_improvement_value * 100:.1f}%"
if original_code_baseline.concurrency_metrics is not None and best_optimization.concurrency_metrics is not None:
original_concurrency_ratio_str = f"{original_code_baseline.concurrency_metrics.concurrency_ratio:.2f}x"
optimized_concurrency_ratio_str = f"{best_optimization.concurrency_metrics.concurrency_ratio:.2f}x"
conc_improvement_value = concurrency_gain(
original_code_baseline.concurrency_metrics, best_optimization.concurrency_metrics
)
concurrency_improvement_str = f"{conc_improvement_value * 100:.1f}%"
new_explanation_raw_str = self.aiservice_client.get_new_explanation(
source_code=code_context.read_writable_code.flat,
dependency_code=code_context.read_only_context_code,
@ -2137,6 +2185,10 @@ class FunctionOptimizer:
optimized_throughput=optimized_throughput_str,
throughput_improvement=throughput_improvement_str,
function_references=function_references,
acceptance_reason=explanation.acceptance_reason.value,
original_concurrency_ratio=original_concurrency_ratio_str,
optimized_concurrency_ratio=optimized_concurrency_ratio_str,
concurrency_improvement=concurrency_improvement_str,
)
new_explanation = Explanation(
raw_explanation_message=new_explanation_raw_str or explanation.raw_explanation_message,
@ -2149,6 +2201,9 @@ class FunctionOptimizer:
benchmark_details=explanation.benchmark_details,
original_async_throughput=explanation.original_async_throughput,
best_async_throughput=explanation.best_async_throughput,
original_concurrency_metrics=explanation.original_concurrency_metrics,
best_concurrency_metrics=explanation.best_concurrency_metrics,
acceptance_reason=explanation.acceptance_reason,
)
self.log_successful_optimization(new_explanation, generated_tests, exp_type)
@ -2391,12 +2446,22 @@ class FunctionOptimizer:
logger.debug(f"Total original code runtime (ns): {total_timing}")
async_throughput = None
concurrency_metrics = None
if self.function_to_optimize.is_async:
async_throughput = calculate_function_throughput_from_test_results(
benchmarking_results, self.function_to_optimize.function_name
)
logger.debug(f"Original async function throughput: {async_throughput} calls/second")
concurrency_metrics = self.run_concurrency_benchmark(
code_context=code_context, original_helper_code=original_helper_code, test_env=test_env
)
if concurrency_metrics:
logger.debug(
f"Original concurrency metrics: ratio={concurrency_metrics.concurrency_ratio:.2f}, "
f"seq={concurrency_metrics.sequential_time_ns}ns, conc={concurrency_metrics.concurrent_time_ns}ns"
)
if self.args.benchmark:
replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
@ -2411,6 +2476,7 @@ class FunctionOptimizer:
coverage_results=coverage_results,
line_profile_results=line_profile_results,
async_throughput=async_throughput,
concurrency_metrics=concurrency_metrics,
),
functions_to_remove,
)
@ -2607,12 +2673,23 @@ class FunctionOptimizer:
logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
candidate_async_throughput = None
candidate_concurrency_metrics = None
if self.function_to_optimize.is_async:
candidate_async_throughput = calculate_function_throughput_from_test_results(
candidate_benchmarking_results, self.function_to_optimize.function_name
)
logger.debug(f"Candidate async function throughput: {candidate_async_throughput} calls/second")
# Run concurrency benchmark for candidate
candidate_concurrency_metrics = self.run_concurrency_benchmark(
code_context=code_context, original_helper_code=candidate_helper_code, test_env=test_env
)
if candidate_concurrency_metrics:
logger.debug(
f"Candidate concurrency metrics: ratio={candidate_concurrency_metrics.concurrency_ratio:.2f}, "
f"seq={candidate_concurrency_metrics.sequential_time_ns}ns, conc={candidate_concurrency_metrics.concurrent_time_ns}ns"
)
if self.args.benchmark:
candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(
self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
@ -2633,6 +2710,7 @@ class FunctionOptimizer:
optimization_candidate_index=optimization_candidate_index,
total_candidate_timing=total_candidate_timing,
async_throughput=candidate_async_throughput,
concurrency_metrics=candidate_concurrency_metrics,
)
)
@ -2912,3 +2990,57 @@ class FunctionOptimizer:
f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
)
return line_profile_results
def run_concurrency_benchmark(
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], test_env: dict[str, str]
) -> ConcurrencyMetrics | None:
"""Run concurrency benchmark to measure sequential vs concurrent execution for async functions.
This benchmark detects blocking vs non-blocking async code by comparing:
- Sequential execution time (running N iterations one after another)
- Concurrent execution time (running N iterations in parallel with asyncio.gather)
Blocking code (like time.sleep) will have similar sequential and concurrent times.
Non-blocking code (like asyncio.sleep) will be much faster when run concurrently.
Returns:
ConcurrencyMetrics if benchmark ran successfully, None otherwise.
"""
if not self.function_to_optimize.is_async:
return None
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
try:
# Add concurrency decorator to the source function
add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.CONCURRENCY
)
# Run the concurrency benchmark tests
concurrency_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE, # Use performance mode for running
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=5.0, # Short benchmark time
enable_coverage=False,
code_context=code_context,
pytest_min_loops=1,
pytest_max_loops=3,
)
except Exception as e:
logger.debug(f"Concurrency benchmark failed: {e}")
return None
finally:
# Restore original code
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
# Parse concurrency metrics from stdout
if concurrency_results and concurrency_results.perf_stdout:
return parse_concurrency_metrics(concurrency_results, self.function_to_optimize.function_name)
return None
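As a conceptual sketch of what run_concurrency_benchmark ends up measuring (an illustration, not the instrumented code path): time.sleep holds the event loop, so asyncio.gather cannot overlap the calls and the ratio stays near 1x, while asyncio.sleep yields and the ratio approaches the concurrency factor.
import asyncio
import time

async def blocking() -> None:
    time.sleep(0.01)           # blocks the event loop

async def non_blocking() -> None:
    await asyncio.sleep(0.01)  # yields to the event loop

async def measure(make_coro, n: int = 10) -> float:
    start = time.perf_counter()
    for _ in range(n):
        await make_coro()
    sequential = time.perf_counter() - start
    start = time.perf_counter()
    await asyncio.gather(*(make_coro() for _ in range(n)))
    concurrent = time.perf_counter() - start
    return sequential / concurrent  # the concurrency ratio

# asyncio.run(measure(blocking))      -> roughly 1.0
# asyncio.run(measure(non_blocking))  -> roughly 10.0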

View file

@ -1,10 +1,12 @@
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING
from codeflash.code_utils import env_utils
from codeflash.code_utils.config_consts import (
COVERAGE_THRESHOLD,
MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD,
MIN_IMPROVEMENT_THRESHOLD,
MIN_TESTCASE_PASSED_THRESHOLD,
MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD,
@ -12,7 +14,14 @@ from codeflash.code_utils.config_consts import (
from codeflash.models import models
if TYPE_CHECKING:
from codeflash.models.models import CoverageData, OptimizedCandidateResult, OriginalCodeBaseline
from codeflash.models.models import ConcurrencyMetrics, CoverageData, OptimizedCandidateResult, OriginalCodeBaseline
class AcceptanceReason(Enum):
RUNTIME = "runtime"
THROUGHPUT = "throughput"
CONCURRENCY = "concurrency"
NONE = "none"
def performance_gain(*, original_runtime_ns: int, optimized_runtime_ns: int) -> float:
@ -36,6 +45,22 @@ def throughput_gain(*, original_throughput: int, optimized_throughput: int) -> f
return (optimized_throughput - original_throughput) / original_throughput
def concurrency_gain(original_metrics: ConcurrencyMetrics, optimized_metrics: ConcurrencyMetrics) -> float:
"""Calculate concurrency ratio improvement.
Returns the relative improvement in concurrency ratio.
Higher is better: it means the optimized code scales better with concurrent execution.
concurrency_ratio = sequential_time / concurrent_time
A ratio of 10 means concurrent execution is 10x faster than sequential.
"""
if original_metrics.concurrency_ratio == 0:
return 0.0
return (
optimized_metrics.concurrency_ratio - original_metrics.concurrency_ratio
) / original_metrics.concurrency_ratio
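A quick worked example of how this gain relates to MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD (hypothetical numbers, assuming ConcurrencyMetrics is imported from codeflash.models.models):
original = ConcurrencyMetrics(
    sequential_time_ns=100_000_000, concurrent_time_ns=95_000_000,
    concurrency_factor=10, concurrency_ratio=1.05,   # blocking: barely any overlap
)
optimized = ConcurrencyMetrics(
    sequential_time_ns=100_000_000, concurrent_time_ns=11_900_000,
    concurrency_factor=10, concurrency_ratio=8.4,    # non-blocking: calls overlap
)
concurrency_gain(original, optimized)  # (8.4 - 1.05) / 1.05 = 7.0, far above the 0.20 threshold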
def speedup_critic(
candidate_result: OptimizedCandidateResult,
original_code_runtime: int,
@ -44,10 +69,12 @@ def speedup_critic(
disable_gh_action_noise: bool = False,
original_async_throughput: int | None = None,
best_throughput_until_now: int | None = None,
original_concurrency_metrics: ConcurrencyMetrics | None = None,
best_concurrency_ratio_until_now: float | None = None,
) -> bool:
"""Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.
Evaluates both runtime performance and async throughput improvements.
Evaluates runtime performance, async throughput, and concurrency improvements.
For runtime performance:
- Ensures the optimization is actually faster than the original code, above the noise floor.
@ -58,6 +85,10 @@ def speedup_critic(
For async throughput (when available):
- Evaluates throughput improvements using MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD
- Throughput improvements complement runtime improvements for async functions
For concurrency (when available):
- Evaluates concurrency ratio improvements using MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD
- Concurrency improvements detect when blocking calls are replaced with non-blocking equivalents
"""
# Runtime performance evaluation
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
@ -86,14 +117,78 @@ def speedup_critic(
best_throughput_until_now is None or candidate_result.async_throughput > best_throughput_until_now
)
# Concurrency evaluation
concurrency_improved = False
concurrency_is_best = True
if original_concurrency_metrics is not None and candidate_result.concurrency_metrics is not None:
conc_gain = concurrency_gain(original_concurrency_metrics, candidate_result.concurrency_metrics)
concurrency_improved = conc_gain > MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD
concurrency_is_best = (
best_concurrency_ratio_until_now is None
or candidate_result.concurrency_metrics.concurrency_ratio > best_concurrency_ratio_until_now
)
# Accept if ANY of: runtime, throughput, or concurrency improves significantly
if original_async_throughput is not None and candidate_result.async_throughput is not None:
# When throughput data is available, accept if EITHER throughput OR runtime improves significantly
throughput_acceptance = throughput_improved and throughput_is_best
runtime_acceptance = runtime_improved and runtime_is_best
return throughput_acceptance or runtime_acceptance
concurrency_acceptance = concurrency_improved and concurrency_is_best
return throughput_acceptance or runtime_acceptance or concurrency_acceptance
return runtime_improved and runtime_is_best
def get_acceptance_reason(
original_runtime_ns: int,
optimized_runtime_ns: int,
*,
original_async_throughput: int | None = None,
optimized_async_throughput: int | None = None,
original_concurrency_metrics: ConcurrencyMetrics | None = None,
optimized_concurrency_metrics: ConcurrencyMetrics | None = None,
) -> AcceptanceReason:
"""Determine why an optimization was accepted.
Returns the primary reason for acceptance, with priority:
concurrency > throughput > runtime (for async code).
"""
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_runtime_ns < 10000 else MIN_IMPROVEMENT_THRESHOLD
if env_utils.is_ci():
noise_floor = noise_floor * 2
perf_gain = performance_gain(original_runtime_ns=original_runtime_ns, optimized_runtime_ns=optimized_runtime_ns)
runtime_improved = perf_gain > noise_floor
throughput_improved = False
if (
original_async_throughput is not None
and optimized_async_throughput is not None
and original_async_throughput > 0
):
throughput_gain_value = throughput_gain(
original_throughput=original_async_throughput, optimized_throughput=optimized_async_throughput
)
throughput_improved = throughput_gain_value > MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD
concurrency_improved = False
if original_concurrency_metrics is not None and optimized_concurrency_metrics is not None:
conc_gain = concurrency_gain(original_concurrency_metrics, optimized_concurrency_metrics)
concurrency_improved = conc_gain > MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD
# Return reason with priority: concurrency > throughput > runtime
if original_async_throughput is not None and optimized_async_throughput is not None:
if concurrency_improved:
return AcceptanceReason.CONCURRENCY
if throughput_improved:
return AcceptanceReason.THROUGHPUT
if runtime_improved:
return AcceptanceReason.RUNTIME
return AcceptanceReason.NONE
if runtime_improved:
return AcceptanceReason.RUNTIME
return AcceptanceReason.NONE
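To illustrate the priority order, reusing the hypothetical metrics from the sketch above (illustrative values only): even when the runtime also improved, a concurrency win above its threshold is reported as the acceptance reason.
reason = get_acceptance_reason(
    original_runtime_ns=2_000_000,
    optimized_runtime_ns=1_500_000,          # clearly faster than the noise floor
    original_async_throughput=100,
    optimized_async_throughput=105,          # +5%, below MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD
    original_concurrency_metrics=original,   # hypothetical metrics from the previous sketch
    optimized_concurrency_metrics=optimized,
)
assert reason is AcceptanceReason.CONCURRENCY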
def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | OriginalCodeBaseline) -> bool:
test_results = candidate_result.behavior_test_results
report = test_results.get_test_pass_fail_report_by_type()

View file

@ -11,8 +11,8 @@ from rich.table import Table
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.models.models import BenchmarkDetail, TestResults
from codeflash.result.critic import throughput_gain
from codeflash.models.models import BenchmarkDetail, ConcurrencyMetrics, TestResults
from codeflash.result.critic import AcceptanceReason, concurrency_gain, throughput_gain
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
@ -27,31 +27,44 @@ class Explanation:
benchmark_details: Optional[list[BenchmarkDetail]] = None
original_async_throughput: Optional[int] = None
best_async_throughput: Optional[int] = None
original_concurrency_metrics: Optional[ConcurrencyMetrics] = None
best_concurrency_metrics: Optional[ConcurrencyMetrics] = None
acceptance_reason: AcceptanceReason = AcceptanceReason.RUNTIME
@property
def perf_improvement_line(self) -> str:
# speedup property already handles choosing between runtime and throughput
improvement_type = {
AcceptanceReason.RUNTIME: "runtime",
AcceptanceReason.THROUGHPUT: "throughput",
AcceptanceReason.CONCURRENCY: "concurrency",
AcceptanceReason.NONE: "",
}.get(self.acceptance_reason, "")
if improvement_type:
return f"{self.speedup_pct} {improvement_type} improvement ({self.speedup_x} faster)."
return f"{self.speedup_pct} improvement ({self.speedup_x} faster)."
@property
def speedup(self) -> float:
runtime_improvement = (self.original_runtime_ns / self.best_runtime_ns) - 1
# Use throughput improvement if we have async metrics and throughput is better
"""Returns the improvement value for the metric that caused acceptance."""
if (
self.original_async_throughput is not None
self.acceptance_reason == AcceptanceReason.CONCURRENCY
and self.original_concurrency_metrics
and self.best_concurrency_metrics
):
return concurrency_gain(self.original_concurrency_metrics, self.best_concurrency_metrics)
if (
self.acceptance_reason == AcceptanceReason.THROUGHPUT
and self.original_async_throughput is not None
and self.best_async_throughput is not None
and self.original_async_throughput > 0
):
throughput_improvement = throughput_gain(
return throughput_gain(
original_throughput=self.original_async_throughput, optimized_throughput=self.best_async_throughput
)
# Use throughput metrics if throughput improvement is better or runtime got worse
if throughput_improvement > runtime_improvement or runtime_improvement <= 0:
return throughput_improvement
return runtime_improvement
return (self.original_runtime_ns / self.best_runtime_ns) - 1
@property
def speedup_x(self) -> str:
@ -108,7 +121,22 @@ class Explanation:
console.print(table)
benchmark_info = cast("StringIO", console.file).getvalue() + "\n" # Cast for mypy
if self.original_async_throughput is not None and self.best_async_throughput is not None:
if (
self.acceptance_reason == AcceptanceReason.CONCURRENCY
and self.original_concurrency_metrics
and self.best_concurrency_metrics
):
orig_ratio = self.original_concurrency_metrics.concurrency_ratio
best_ratio = self.best_concurrency_metrics.concurrency_ratio
performance_description = (
f"Concurrency ratio improved from {orig_ratio:.2f}x to {best_ratio:.2f}x "
f"(concurrent execution now {best_ratio:.2f}x faster than sequential)\n\n"
)
elif (
self.acceptance_reason == AcceptanceReason.THROUGHPUT
and self.original_async_throughput is not None
and self.best_async_throughput is not None
):
performance_description = (
f"Throughput improved from {self.original_async_throughput} to {self.best_async_throughput} operations/second "
f"(runtime: {original_runtime_human}{best_runtime_human})\n\n"

View file

@ -24,6 +24,7 @@ HAS_TORCH = find_spec("torch") is not None
HAS_JAX = find_spec("jax") is not None
HAS_XARRAY = find_spec("xarray") is not None
HAS_TENSORFLOW = find_spec("tensorflow") is not None
HAS_NUMBA = find_spec("numba") is not None
# Pattern to match pytest temp directories: /tmp/pytest-of-<user>/pytest-<N>/
# These paths vary between test runs but are logically equivalent
@ -156,6 +157,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
range,
slice,
OrderedDict,
types.GenericAlias,
),
):
return orig == new
@ -296,8 +298,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
# fails at "ufunc 'isfinite' not supported for the input types"
return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)])
if isinstance(orig, (np.floating, np.complex64, np.complex128)):
return np.isclose(orig, new)
if isinstance(orig, (np.floating, np.complexfloating)):
return np.isclose(orig, new, equal_nan=True)
if isinstance(orig, (np.integer, np.bool_, np.byte)):
return orig == new
@ -383,6 +385,42 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
if isinstance(orig, torch.device):
return orig == new
if HAS_NUMBA:
import numba # type: ignore # noqa: PGH003
from numba.core.dispatcher import Dispatcher # type: ignore # noqa: PGH003
from numba.typed import Dict as NumbaDict # type: ignore # noqa: PGH003
from numba.typed import List as NumbaList # type: ignore # noqa: PGH003
# Handle numba typed List
if isinstance(orig, NumbaList):
if len(orig) != len(new):
return False
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))
# Handle numba typed Dict
if isinstance(orig, NumbaDict):
if superset_obj:
# Allow new dict to have more keys, but all orig keys must exist with equal values
return all(key in new and comparator(orig[key], new[key], superset_obj) for key in orig)
if len(orig) != len(new):
return False
for key in orig:
if key not in new:
return False
if not comparator(orig[key], new[key], superset_obj):
return False
return True
# Handle numba type objects (e.g., numba.int64, numba.float64, numba.Array, etc.)
if isinstance(orig, numba.core.types.Type):
return orig == new
# Handle numba JIT-compiled functions (CPUDispatcher, etc.)
if isinstance(orig, Dispatcher):
# Compare by identity of the underlying Python function
# Two JIT functions are equal if they wrap the same Python function
return orig.py_func is new.py_func
if HAS_PYRSISTENT:
import pyrsistent # type: ignore # noqa: PGH003

View file

@ -20,7 +20,14 @@ from codeflash.code_utils.code_utils import (
module_name_from_file_path,
)
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType, VerificationType
from codeflash.models.models import (
ConcurrencyMetrics,
FunctionTestInvocation,
InvocationId,
TestResults,
TestType,
VerificationType,
)
from codeflash.verification.coverage_utils import CoverageUtils, JestCoverageUtils
if TYPE_CHECKING:
@ -70,6 +77,54 @@ def calculate_function_throughput_from_test_results(test_results: TestResults, f
return function_throughput
# Pattern for concurrency benchmark output:
# !@######CONC:module:class:test:function:loop_index:seq_time:conc_time:factor######@!
_concurrency_pattern = re.compile(r"!@######CONC:([^:]*):([^:]*):([^:]*):([^:]*):([^:]*):(\d+):(\d+):(\d+)######@!")
def parse_concurrency_metrics(test_results: TestResults, function_name: str) -> ConcurrencyMetrics | None:
"""Parse concurrency benchmark results from test output.
Format: !@######CONC:module:class:test:function:loop_index:seq_time:conc_time:factor######@!
Returns ConcurrencyMetrics with:
- sequential_time_ns: Total time for N sequential executions
- concurrent_time_ns: Total time for N concurrent executions
- concurrency_factor: N (number of concurrent executions)
- concurrency_ratio: sequential_time / concurrent_time (higher = better concurrency)
"""
if not test_results.perf_stdout:
return None
matches = _concurrency_pattern.findall(test_results.perf_stdout)
if not matches:
return None
# Aggregate metrics for the target function
total_seq, total_conc, factor, count = 0, 0, 0, 0
for match in matches:
# match[3] is function_name
if len(match) >= 8 and match[3] == function_name:
total_seq += int(match[5])
total_conc += int(match[6])
factor = int(match[7])
count += 1
if count == 0:
return None
avg_seq = total_seq / count
avg_conc = total_conc / count
ratio = avg_seq / avg_conc if avg_conc > 0 else 1.0
return ConcurrencyMetrics(
sequential_time_ns=int(avg_seq),
concurrent_time_ns=int(avg_conc),
concurrency_factor=factor,
concurrency_ratio=ratio,
)
def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> Path | None:
"""Resolve test file path from pytest's test class path.
@ -412,7 +467,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
# Default to GENERATED_REGRESSION for Jest tests when test type can't be determined
if test_type is None and is_jest:
test_type = TestType.GENERATED_REGRESSION
logger.debug(f"[PARSE-DEBUG] defaulting to GENERATED_REGRESSION (Jest)")
logger.debug("[PARSE-DEBUG] defaulting to GENERATED_REGRESSION (Jest)")
elif test_type is None:
# Skip results where test type cannot be determined
logger.debug(f"Skipping result for {test_function_name}: could not determine test type")
@ -504,7 +559,7 @@ def _extract_jest_console_output(suite_elem) -> str:
return raw_content
# ToDO: {Claude} we need to move to the support directory.
# TODO: {Claude} we need to move to the support directory.
def parse_jest_test_xml(
test_xml_file_path: Path,
test_files: TestFiles,

View file

@ -8,6 +8,7 @@ def run_test(expected_improvement_pct: int) -> bool:
config = TestConfig(
file_path="main.py",
min_improvement_x=0.1,
expected_acceptance_reason="concurrency",
coverage_expectations=[
CoverageExpectation(
function_name="retry_with_backoff",

View file

@ -37,6 +37,7 @@ class TestConfig:
benchmarks_root: Optional[pathlib.Path] = None
use_worktree: bool = False
no_gen_tests: bool = False
expected_acceptance_reason: Optional[str] = None # "runtime", "throughput", "concurrency"
def clear_directory(directory_path: str | pathlib.Path) -> None:
@ -176,7 +177,7 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
logging.error("Failed to find performance improvement message")
return False
improvement_match = re.search(r"📈 ([\d,]+)% improvement", stdout)
improvement_match = re.search(r"📈 ([\d,]+)% (?:(\w+) )?improvement", stdout)
if not improvement_match:
logging.error("Could not find improvement percentage in output")
return False
@ -193,6 +194,15 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
logging.error(f"Performance improvement rate {improvement_x}x not above {config.min_improvement_x}x")
return False
if config.expected_acceptance_reason is not None:
actual_reason = improvement_match.group(2)
if not actual_reason:
logging.error("Could not find acceptance reason type in output")
return False
if actual_reason != config.expected_acceptance_reason:
logging.error(f"Expected acceptance reason '{config.expected_acceptance_reason}', got '{actual_reason}'")
return False
if config.expected_unit_tests_count is not None:
# Match the global test discovery message from optimizer.py which counts test invocations
# Format: "Discovered X existing unit tests and Y replay tests in Z.Zs at /path/to/tests"
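A quick check of the widened regex (illustrative strings, not captured codeflash output): the optional middle group carries the acceptance reason when present and is None for the legacy message format.
import re

pattern = re.compile(r"📈 ([\d,]+)% (?:(\w+) )?improvement")
m = pattern.search("📈 340% concurrency improvement (4.40x faster)")
assert m.group(1) == "340" and m.group(2) == "concurrency"
m = pattern.search("📈 1,200% improvement")
assert m.group(1) == "1,200" and m.group(2) is None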

View file

@ -0,0 +1,304 @@
from __future__ import annotations
import asyncio
import os
import sys
import time
import pytest
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_concurrency_async
from codeflash.models.models import ConcurrencyMetrics, TestResults
from codeflash.verification.parse_test_output import parse_concurrency_metrics
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
class TestConcurrencyAsyncDecorator:
"""Integration tests for codeflash_concurrency_async decorator."""
@pytest.fixture
def concurrency_env_setup(self, request):
"""Set up environment variables for concurrency testing."""
original_env = {}
test_env = {
"CODEFLASH_LOOP_INDEX": "1",
"CODEFLASH_TEST_MODULE": __name__,
"CODEFLASH_TEST_CLASS": "TestConcurrencyAsyncDecorator",
"CODEFLASH_TEST_FUNCTION": request.node.name,
"CODEFLASH_CONCURRENCY_FACTOR": "5", # Use smaller factor for faster tests
}
for key, value in test_env.items():
original_env[key] = os.environ.get(key)
os.environ[key] = value
yield test_env
for key, original_value in original_env.items():
if original_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = original_value
@pytest.mark.asyncio
async def test_concurrency_decorator_nonblocking_function(self, concurrency_env_setup, capsys):
"""Test that non-blocking async functions show high concurrency ratio."""
@codeflash_concurrency_async
async def nonblocking_sleep(duration: float) -> str:
await asyncio.sleep(duration)
return "done"
result = await nonblocking_sleep(0.01)
assert result == "done"
captured = capsys.readouterr()
output = captured.out
# Verify the output format
assert "!@######CONC:" in output
assert "######@!" in output
# Parse the output manually to verify format
lines = [line for line in output.strip().split("\n") if "!@######CONC:" in line]
assert len(lines) == 1
line = lines[0]
# Format: !@######CONC:{test_module}:{test_class}:{test_function}:{function_name}:{loop_index}:{seq_time}:{conc_time}:{factor}######@!
assert "nonblocking_sleep" in line
assert ":5######@!" in line # concurrency factor
# Extract timing values
parts = line.replace("!@######CONC:", "").replace("######@!", "").split(":")
# parts should be: [test_module, test_class, test_function, function_name, loop_index, seq_time, conc_time, factor]
assert len(parts) == 8
seq_time = int(parts[5])
conc_time = int(parts[6])
factor = int(parts[7])
assert seq_time > 0
assert conc_time > 0
assert factor == 5
# For non-blocking async, concurrent time should be much less than sequential
# Sequential runs 5 iterations of 10ms = ~50ms
# Concurrent runs 5 iterations in parallel = ~10ms
# So ratio should be around 5 (with some overhead tolerance)
ratio = seq_time / conc_time if conc_time > 0 else 1.0
assert ratio > 2.0, f"Non-blocking function should have ratio > 2.0, got {ratio}"
@pytest.mark.asyncio
async def test_concurrency_decorator_blocking_function(self, concurrency_env_setup, capsys):
"""Test that blocking functions show low concurrency ratio (~1.0)."""
@codeflash_concurrency_async
async def blocking_sleep(duration: float) -> str:
time.sleep(duration) # Blocking sleep
return "done"
result = await blocking_sleep(0.005) # 5ms blocking
assert result == "done"
captured = capsys.readouterr()
output = captured.out
assert "!@######CONC:" in output
lines = [line for line in output.strip().split("\n") if "!@######CONC:" in line]
assert len(lines) == 1
line = lines[0]
parts = line.replace("!@######CONC:", "").replace("######@!", "").split(":")
assert len(parts) == 8
seq_time = int(parts[5])
conc_time = int(parts[6])
# For blocking code, sequential and concurrent times should be similar
# Because time.sleep blocks the entire event loop
ratio = seq_time / conc_time if conc_time > 0 else 1.0
# Blocking code should have ratio close to 1.0 (within reasonable tolerance)
assert ratio < 2.0, f"Blocking function should have ratio < 2.0, got {ratio}"
@pytest.mark.asyncio
async def test_concurrency_decorator_with_computation(self, concurrency_env_setup, capsys):
"""Test concurrency with CPU-bound computation."""
@codeflash_concurrency_async
async def compute_intensive(n: int) -> int:
# CPU-bound work (blocked by GIL in concurrent execution)
total = 0
for i in range(n):
total += i * i
return total
result = await compute_intensive(10000)
assert result == sum(i * i for i in range(10000))
captured = capsys.readouterr()
output = captured.out
assert "!@######CONC:" in output
assert "compute_intensive" in output
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
class TestParseConcurrencyMetrics:
"""Integration tests for parse_concurrency_metrics function."""
def test_parse_concurrency_metrics_from_real_output(self):
"""Test parsing concurrency metrics from simulated stdout."""
# Simulate stdout from codeflash_concurrency_async decorator
perf_stdout = """Some other output
!@######CONC:test_module:TestClass:test_func:my_async_func:1:50000000:10000000:5######@!
More output here
"""
test_results = TestResults(
test_results=[],
perf_stdout=perf_stdout,
)
metrics = parse_concurrency_metrics(test_results, "my_async_func")
assert metrics is not None
assert isinstance(metrics, ConcurrencyMetrics)
assert metrics.sequential_time_ns == 50000000
assert metrics.concurrent_time_ns == 10000000
assert metrics.concurrency_factor == 5
assert metrics.concurrency_ratio == 5.0 # 50M / 10M = 5.0
def test_parse_concurrency_metrics_multiple_entries(self):
"""Test parsing when multiple concurrency entries exist."""
perf_stdout = """!@######CONC:test_module:TestClass:test_func:target_func:1:40000000:10000000:5######@!
!@######CONC:test_module:TestClass:test_func:target_func:2:60000000:10000000:5######@!
!@######CONC:test_module:TestClass:test_func:other_func:1:30000000:15000000:5######@!
"""
test_results = TestResults(
test_results=[],
perf_stdout=perf_stdout,
)
metrics = parse_concurrency_metrics(test_results, "target_func")
assert metrics is not None
# Should average the two entries for target_func
# (40M + 60M) / 2 = 50M seq, (10M + 10M) / 2 = 10M conc
assert metrics.sequential_time_ns == 50000000
assert metrics.concurrent_time_ns == 10000000
assert metrics.concurrency_ratio == 5.0
def test_parse_concurrency_metrics_no_match(self):
"""Test parsing when function name doesn't match."""
perf_stdout = """!@######CONC:test_module:TestClass:test_func:other_func:1:50000000:10000000:5######@!
"""
test_results = TestResults(
test_results=[],
perf_stdout=perf_stdout,
)
metrics = parse_concurrency_metrics(test_results, "nonexistent_func")
assert metrics is None
def test_parse_concurrency_metrics_empty_stdout(self):
"""Test parsing with empty stdout."""
test_results = TestResults(
test_results=[],
perf_stdout="",
)
metrics = parse_concurrency_metrics(test_results, "any_func")
assert metrics is None
def test_parse_concurrency_metrics_none_stdout(self):
"""Test parsing with None stdout."""
test_results = TestResults(
test_results=[],
perf_stdout=None,
)
metrics = parse_concurrency_metrics(test_results, "any_func")
assert metrics is None
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
class TestConcurrencyRatioComparison:
"""Test comparing blocking vs non-blocking concurrency ratios."""
@pytest.fixture
def comparison_env_setup(self, request):
"""Set up environment variables for comparison testing."""
original_env = {}
test_env = {
"CODEFLASH_LOOP_INDEX": "1",
"CODEFLASH_TEST_MODULE": __name__,
"CODEFLASH_TEST_CLASS": "TestConcurrencyRatioComparison",
"CODEFLASH_TEST_FUNCTION": request.node.name,
"CODEFLASH_CONCURRENCY_FACTOR": "10",
}
for key, value in test_env.items():
original_env[key] = os.environ.get(key)
os.environ[key] = value
yield test_env
for key, original_value in original_env.items():
if original_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = original_value
@pytest.mark.asyncio
async def test_blocking_vs_nonblocking_comparison(self, comparison_env_setup, capsys):
"""Compare concurrency ratios between blocking and non-blocking implementations."""
@codeflash_concurrency_async
async def blocking_impl() -> str:
time.sleep(0.002) # 2ms blocking
return "blocking"
@codeflash_concurrency_async
async def nonblocking_impl() -> str:
await asyncio.sleep(0.002) # 2ms non-blocking
return "nonblocking"
# Run blocking version
await blocking_impl()
blocking_output = capsys.readouterr().out
# Run non-blocking version
await nonblocking_impl()
nonblocking_output = capsys.readouterr().out
# Parse blocking metrics
blocking_line = [l for l in blocking_output.split("\n") if "!@######CONC:" in l][0]
blocking_parts = blocking_line.replace("!@######CONC:", "").replace("######@!", "").split(":")
blocking_seq = int(blocking_parts[5])
blocking_conc = int(blocking_parts[6])
blocking_ratio = blocking_seq / blocking_conc if blocking_conc > 0 else 1.0
# Parse non-blocking metrics
nonblocking_line = [l for l in nonblocking_output.split("\n") if "!@######CONC:" in l][0]
nonblocking_parts = nonblocking_line.replace("!@######CONC:", "").replace("######@!", "").split(":")
nonblocking_seq = int(nonblocking_parts[5])
nonblocking_conc = int(nonblocking_parts[6])
nonblocking_ratio = nonblocking_seq / nonblocking_conc if nonblocking_conc > 0 else 1.0
# Non-blocking should have significantly higher concurrency ratio
assert nonblocking_ratio > blocking_ratio, (
f"Non-blocking ratio ({nonblocking_ratio:.2f}) should be greater than "
f"blocking ratio ({blocking_ratio:.2f})"
)
# The difference should be substantial (non-blocking should be at least 2x better)
ratio_improvement = nonblocking_ratio / blocking_ratio if blocking_ratio > 0 else 0
assert ratio_improvement > 2.0, (
f"Non-blocking should show >2x improvement in concurrency ratio, got {ratio_improvement:.2f}x"
)

View file

@ -17,7 +17,14 @@ import pytest
from codeflash.either import Failure, Success
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType
from codeflash.verification.comparator import comparator, _extract_exception_from_message, _get_wrapped_exception
from codeflash.verification.comparator import (
PYTEST_TEMP_PATH_PATTERN,
_extract_exception_from_message,
_get_wrapped_exception,
_is_temp_path,
_normalize_temp_path,
comparator,
)
from codeflash.verification.equivalence import compare_test_results
@ -2911,16 +2918,378 @@ def test_numpy_dtypes() -> None:
assert not comparator(dtypes.Int32DType(), np.dtype('float32'))
def test_numpy_extended_precision_types() -> None:
"""Test comparator for numpy extended precision types like clongdouble."""
try:
import numpy as np
except ImportError:
pytest.skip("numpy not available")
# Test np.clongdouble (extended precision complex)
c1 = np.clongdouble(1 + 2j)
c2 = np.clongdouble(1 + 2j)
c3 = np.clongdouble(1 + 3j)
assert comparator(c1, c2)
assert not comparator(c1, c3)
# Test np.longdouble (extended precision float)
l1 = np.longdouble(1.5)
l2 = np.longdouble(1.5)
l3 = np.longdouble(2.5)
assert comparator(l1, l2)
assert not comparator(l1, l3)
# Test NaN handling for extended precision complex
nan_c1 = np.clongdouble(complex(np.nan, 2))
nan_c2 = np.clongdouble(complex(np.nan, 2))
assert comparator(nan_c1, nan_c2)
# Test NaN handling for extended precision float
nan_l1 = np.longdouble(np.nan)
nan_l2 = np.longdouble(np.nan)
assert comparator(nan_l1, nan_l2)
def test_numpy_typing_types() -> None:
"""Test comparator for numpy.typing types like NDArray type aliases."""
try:
import numpy as np
import numpy.typing as npt
except ImportError:
pytest.skip("numpy or numpy.typing not available")
# Test NDArray type alias comparisons
arr_type1 = npt.NDArray[np.float64]
arr_type2 = npt.NDArray[np.float64]
arr_type3 = npt.NDArray[np.int64]
assert comparator(arr_type1, arr_type2)
assert not comparator(arr_type1, arr_type3)
# Test NBitBase (if it can be instantiated)
try:
nbit1 = npt.NBitBase()
nbit2 = npt.NBitBase()
# NBitBase instances with empty __dict__ should compare as equal
assert comparator(nbit1, nbit2)
# Also test with superset_obj=True
assert comparator(nbit1, nbit2, superset_obj=True)
except TypeError:
# NBitBase may not be instantiable in all numpy versions
pass
def test_numpy_typing_superset_obj() -> None:
"""Test comparator with superset_obj=True for numpy types."""
try:
import numpy as np
import numpy.typing as npt
except ImportError:
pytest.skip("numpy or numpy.typing not available")
# Test numpy arrays with object dtype containing dicts (superset scenario)
a1 = np.array([{'a': 1}], dtype=object)
a2 = np.array([{'a': 1, 'b': 2}], dtype=object) # superset
assert comparator(a1, a2, superset_obj=True)
assert not comparator(a1, a2, superset_obj=False)
# Test extended precision types with superset_obj=True
c1 = np.clongdouble(1 + 2j)
c2 = np.clongdouble(1 + 2j)
assert comparator(c1, c2, superset_obj=True)
l1 = np.longdouble(1.5)
l2 = np.longdouble(1.5)
assert comparator(l1, l2, superset_obj=True)
# Test NDArray type alias with superset_obj=True
arr_type1 = npt.NDArray[np.float64]
arr_type2 = npt.NDArray[np.float64]
assert comparator(arr_type1, arr_type2, superset_obj=True)
# Test numpy structured arrays (np.void) with superset_obj=True
dt = np.dtype([('name', 'S10'), ('age', np.int32)])
a_struct = np.array([('Alice', 25)], dtype=dt)
b_struct = np.array([('Alice', 25)], dtype=dt)
assert comparator(a_struct[0], b_struct[0], superset_obj=True)
# Test numpy random generators with superset_obj=True
rng1 = np.random.default_rng(seed=42)
rng2 = np.random.default_rng(seed=42)
assert comparator(rng1, rng2, superset_obj=True)
rs1 = np.random.RandomState(seed=42)
rs2 = np.random.RandomState(seed=42)
assert comparator(rs1, rs2, superset_obj=True)
def test_numba_typed_list() -> None:
"""Test comparator for numba.typed.List."""
try:
import numba
from numba.typed import List as NumbaList
except ImportError:
pytest.skip("numba not available")
# Test equal lists
a = NumbaList([1, 2, 3])
b = NumbaList([1, 2, 3])
assert comparator(a, b)
# Test different values
c = NumbaList([1, 2, 4])
assert not comparator(a, c)
# Test different lengths
d = NumbaList([1, 2, 3, 4])
assert not comparator(a, d)
# Test empty lists
e = NumbaList.empty_list(item_type=numba.int64)
f = NumbaList.empty_list(item_type=numba.int64)
assert comparator(e, f)
# Test nested values (floats)
g = NumbaList([1.0, 2.0, 3.0])
h = NumbaList([1.0, 2.0, 3.0])
assert comparator(g, h)
i = NumbaList([1.0, 2.0, 4.0])
assert not comparator(g, i)
def test_numba_typed_dict() -> None:
"""Test comparator for numba.typed.Dict."""
try:
import numba
from numba.typed import Dict as NumbaDict
except ImportError:
pytest.skip("numba not available")
# Test equal dicts
a = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
a["x"] = 1
a["y"] = 2
b = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
b["x"] = 1
b["y"] = 2
assert comparator(a, b)
# Test different values
c = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
c["x"] = 1
c["y"] = 3
assert not comparator(a, c)
# Test different keys
d = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
d["x"] = 1
d["z"] = 2
assert not comparator(a, d)
# Test different lengths
e = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
e["x"] = 1
assert not comparator(a, e)
# Test empty dicts
f = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
g = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
assert comparator(f, g)
def test_numba_types() -> None:
"""Test comparator for numba type objects."""
try:
import numba
from numba import types
except ImportError:
pytest.skip("numba not available")
# Test basic numeric types from numba module
assert comparator(numba.int64, numba.int64)
assert comparator(numba.float64, numba.float64)
assert comparator(numba.int32, numba.int32)
assert comparator(numba.float32, numba.float32)
# Test basic numeric types from numba.types module
assert comparator(types.int64, types.int64)
assert comparator(types.float64, types.float64)
assert comparator(types.int8, types.int8)
assert comparator(types.int16, types.int16)
assert comparator(types.uint8, types.uint8)
assert comparator(types.uint16, types.uint16)
assert comparator(types.uint32, types.uint32)
assert comparator(types.uint64, types.uint64)
assert comparator(types.complex64, types.complex64)
assert comparator(types.complex128, types.complex128)
# Test different types
assert not comparator(numba.int64, numba.float64)
assert not comparator(numba.int32, numba.int64)
assert not comparator(numba.float32, numba.float64)
assert not comparator(types.int8, types.int16)
assert not comparator(types.uint32, types.int32)
assert not comparator(types.complex64, types.complex128)
# Test boolean type
assert comparator(numba.boolean, numba.boolean)
assert comparator(types.boolean, types.boolean)
assert not comparator(numba.boolean, numba.int64)
# Test special types
assert comparator(types.none, types.none)
assert comparator(types.void, types.void)
assert comparator(types.pyobject, types.pyobject)
assert comparator(types.unicode_type, types.unicode_type)
# Note: types.none and types.void are the same object in numba
assert comparator(types.none, types.void)
assert not comparator(types.unicode_type, types.pyobject)
assert not comparator(types.none, types.int64)
# Test array types
arr_type1 = types.Array(numba.float64, 1, 'C')
arr_type2 = types.Array(numba.float64, 1, 'C')
arr_type3 = types.Array(numba.float64, 2, 'C')
arr_type4 = types.Array(numba.int64, 1, 'C')
arr_type5 = types.Array(numba.float64, 1, 'F') # Fortran order
assert comparator(arr_type1, arr_type2)
assert not comparator(arr_type1, arr_type3) # different ndim
assert not comparator(arr_type1, arr_type4) # different dtype
assert not comparator(arr_type1, arr_type5) # different layout
# Test tuple types
tuple_type1 = types.UniTuple(types.int64, 3)
tuple_type2 = types.UniTuple(types.int64, 3)
tuple_type3 = types.UniTuple(types.int64, 4)
tuple_type4 = types.UniTuple(types.float64, 3)
assert comparator(tuple_type1, tuple_type2)
assert not comparator(tuple_type1, tuple_type3) # different count
assert not comparator(tuple_type1, tuple_type4) # different dtype
# Test heterogeneous tuple types
hetero_tuple1 = types.Tuple([types.int64, types.float64])
hetero_tuple2 = types.Tuple([types.int64, types.float64])
hetero_tuple3 = types.Tuple([types.int64, types.int64])
assert comparator(hetero_tuple1, hetero_tuple2)
assert not comparator(hetero_tuple1, hetero_tuple3)
# Test ListType and DictType
list_type1 = types.ListType(types.int64)
list_type2 = types.ListType(types.int64)
list_type3 = types.ListType(types.float64)
assert comparator(list_type1, list_type2)
assert not comparator(list_type1, list_type3)
dict_type1 = types.DictType(types.unicode_type, types.int64)
dict_type2 = types.DictType(types.unicode_type, types.int64)
dict_type3 = types.DictType(types.unicode_type, types.float64)
dict_type4 = types.DictType(types.int64, types.int64)
assert comparator(dict_type1, dict_type2)
assert not comparator(dict_type1, dict_type3) # different value type
assert not comparator(dict_type1, dict_type4) # different key type
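# Editor's note: a hedged sketch of why the type assertions above hold. Numba type
# instances define structural equality, so plain `==` reproduces these behaviors;
# the helper name is illustrative, not codeflash API.
def _numba_type_equal_sketch(a, b):
    """Structural comparison of two numba type objects."""
    return a == b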
def test_numba_jit_functions() -> None:
"""Test comparator for numba JIT-compiled functions."""
try:
from numba import jit
except ImportError:
pytest.skip("numba not available")
@jit(nopython=True)
def add(x, y):
return x + y
@jit(nopython=True)
def add2(x, y):
return x + y
@jit(nopython=True)
def multiply(x, y):
return x * y
# Compile the functions by calling them
add(1, 2)
add2(1, 2)
multiply(1, 2)
# Same function should compare equal to itself
assert comparator(add, add)
# Different functions (even with same code) should not compare equal
# since they are distinct function objects
assert not comparator(add, add2)
# Different functions with different code should not compare equal
assert not comparator(add, multiply)
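# Editor's note: Dispatcher objects produced by @jit define no structural equality,
# so the assertions above match what plain object identity gives. Illustrative only.
def _jit_function_equal_sketch(a, b):
    return a is b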
def test_numba_superset_obj() -> None:
"""Test comparator for numba types with superset_obj=True."""
try:
import numba
from numba.typed import Dict as NumbaDict
from numba.typed import List as NumbaList
except ImportError:
pytest.skip("numba not available")
# Test NumbaDict with superset_obj=True
orig_dict = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
orig_dict["x"] = 1
orig_dict["y"] = 2
# New dict with same keys - should pass
new_dict_same = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
new_dict_same["x"] = 1
new_dict_same["y"] = 2
assert comparator(orig_dict, new_dict_same, superset_obj=True)
# New dict with extra keys - should pass with superset_obj=True
new_dict_superset = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
new_dict_superset["x"] = 1
new_dict_superset["y"] = 2
new_dict_superset["z"] = 3
assert comparator(orig_dict, new_dict_superset, superset_obj=True)
# But should fail with superset_obj=False
assert not comparator(orig_dict, new_dict_superset, superset_obj=False)
# New dict missing keys - should fail even with superset_obj=True
new_dict_subset = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
new_dict_subset["x"] = 1
assert not comparator(orig_dict, new_dict_subset, superset_obj=True)
# New dict with different values - should fail
new_dict_diff = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
new_dict_diff["x"] = 1
new_dict_diff["y"] = 99
assert not comparator(orig_dict, new_dict_diff, superset_obj=True)
# Test NumbaList with superset_obj=True (lists don't support superset semantics)
orig_list = NumbaList([1, 2, 3])
new_list_same = NumbaList([1, 2, 3])
new_list_longer = NumbaList([1, 2, 3, 4])
assert comparator(orig_list, new_list_same, superset_obj=True)
# Lists must have same length regardless of superset_obj
assert not comparator(orig_list, new_list_longer, superset_obj=True)
# Test empty dict with superset_obj=True
empty_orig = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
non_empty_new = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
non_empty_new["a"] = 1
# Empty orig should match any superset
assert comparator(empty_orig, non_empty_new, superset_obj=True)
assert not comparator(empty_orig, non_empty_new, superset_obj=False)
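# Editor's note: a hedged sketch (illustrative names, not the codeflash implementation)
# of the superset semantics the assertions above describe: with superset_obj=True the
# new dict may add keys, but every original key must still be present with an equal value.
def _numba_dict_superset_sketch(orig, new, superset_obj=False):
    if not superset_obj and len(orig) != len(new):
        return False
    return all(key in new and new[key] == orig[key] for key in orig)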
# =============================================================================
# Tests for pytest temp path normalization (lines 28-69 in comparator.py)
# =============================================================================
from codeflash.verification.comparator import (
PYTEST_TEMP_PATH_PATTERN,
_is_temp_path,
_normalize_temp_path,
)
class TestIsTempPath:
"""Tests for the _is_temp_path() function."""

View file

@@ -5,6 +5,7 @@ from unittest.mock import Mock
from codeflash.code_utils.env_utils import get_pr_number
from codeflash.models.models import (
CodeOptimizationContext,
ConcurrencyMetrics,
CoverageData,
CoverageStatus,
FunctionCoverage,
@@ -15,12 +16,14 @@ from codeflash.models.models import (
TestType,
)
from codeflash.result.critic import (
concurrency_gain,
coverage_critic,
performance_gain,
quantity_of_tests_critic,
speedup_critic,
throughput_gain,
)
from codeflash.verification.parse_test_output import parse_concurrency_metrics
def test_performance_gain() -> None:
@@ -569,3 +572,238 @@ def test_speedup_critic_with_async_throughput() -> None:
best_throughput_until_now=None,
disable_gh_action_noise=True
)
def test_concurrency_gain() -> None:
"""Test concurrency_gain calculation."""
# Test basic concurrency improvement (blocking -> non-blocking)
original = ConcurrencyMetrics(
sequential_time_ns=10_000_000, # 10ms
concurrent_time_ns=10_000_000, # 10ms (no speedup - blocking)
concurrency_factor=10,
concurrency_ratio=1.0, # sequential/concurrent = 1.0
)
optimized = ConcurrencyMetrics(
sequential_time_ns=10_000_000, # 10ms
concurrent_time_ns=1_000_000, # 1ms (10x speedup - non-blocking)
concurrency_factor=10,
concurrency_ratio=10.0, # sequential/concurrent = 10.0
)
# 900% improvement: (10 - 1) / 1 = 9.0
assert concurrency_gain(original, optimized) == 9.0
# Test no improvement
same = ConcurrencyMetrics(
sequential_time_ns=10_000_000,
concurrent_time_ns=10_000_000,
concurrency_factor=10,
concurrency_ratio=1.0,
)
assert concurrency_gain(original, same) == 0.0
# Test slight improvement
slightly_better = ConcurrencyMetrics(
sequential_time_ns=10_000_000,
concurrent_time_ns=8_000_000,
concurrency_factor=10,
concurrency_ratio=1.25,
)
# 25% improvement: (1.25 - 1.0) / 1.0 = 0.25
assert concurrency_gain(original, slightly_better) == 0.25
# Test zero original ratio (edge case)
zero_ratio = ConcurrencyMetrics(
sequential_time_ns=0,
concurrent_time_ns=1_000_000,
concurrency_factor=10,
concurrency_ratio=0.0,
)
assert concurrency_gain(zero_ratio, optimized) == 0.0
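# Editor's note: a minimal sketch of the gain formula these assertions imply: the
# relative change in concurrency_ratio, guarded against a zero baseline. Illustrative
# only; the real concurrency_gain lives in codeflash.result.critic.
def _concurrency_gain_sketch(original, optimized):
    if original.concurrency_ratio <= 0:
        return 0.0
    return (optimized.concurrency_ratio - original.concurrency_ratio) / original.concurrency_ratio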
def test_speedup_critic_with_concurrency_metrics() -> None:
"""Test speedup_critic with concurrency metrics evaluation."""
original_code_runtime = 10000 # 10 microseconds
original_async_throughput = 100
# Original concurrency metrics (blocking code - ratio ~= 1.0)
original_concurrency = ConcurrencyMetrics(
sequential_time_ns=10_000_000,
concurrent_time_ns=10_000_000,
concurrency_factor=10,
concurrency_ratio=1.0,
)
# Test case 1: Concurrency improves significantly (blocking -> non-blocking)
candidate_result = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=10000, # Same runtime
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=10000,
async_throughput=100, # Same throughput
concurrency_metrics=ConcurrencyMetrics(
sequential_time_ns=10_000_000,
concurrent_time_ns=1_000_000, # 10x faster concurrent execution
concurrency_factor=10,
concurrency_ratio=10.0, # 900% improvement
),
)
# Should pass due to concurrency improvement even though runtime/throughput unchanged
assert speedup_critic(
candidate_result=candidate_result,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
original_concurrency_metrics=original_concurrency,
best_concurrency_ratio_until_now=None,
disable_gh_action_noise=True,
)
# Test case 2: No concurrency improvement (should fall back to other metrics)
candidate_result_no_conc = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=8000, # 20% runtime improvement
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=8000,
async_throughput=100,
concurrency_metrics=ConcurrencyMetrics(
sequential_time_ns=10_000_000,
concurrent_time_ns=10_000_000,
concurrency_factor=10,
concurrency_ratio=1.0, # No improvement
),
)
# Should pass due to runtime improvement
assert speedup_critic(
candidate_result=candidate_result_no_conc,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
original_concurrency_metrics=original_concurrency,
best_concurrency_ratio_until_now=None,
disable_gh_action_noise=True,
)
# Test case 3: Concurrency below threshold (20% required)
candidate_result_below_threshold = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=10000, # Same runtime
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=10000,
async_throughput=100, # Same throughput
concurrency_metrics=ConcurrencyMetrics(
sequential_time_ns=10_000_000,
concurrent_time_ns=9_000_000, # Only 11% improvement
concurrency_factor=10,
concurrency_ratio=1.11,
),
)
# Should fail - no metric improves enough
assert not speedup_critic(
candidate_result=candidate_result_below_threshold,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
original_concurrency_metrics=original_concurrency,
best_concurrency_ratio_until_now=None,
disable_gh_action_noise=True,
)
# Test case 4: best_concurrency_ratio_until_now comparison
candidate_result_good = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=10000,
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=10000,
async_throughput=100,
concurrency_metrics=ConcurrencyMetrics(
sequential_time_ns=10_000_000,
concurrent_time_ns=2_000_000,
concurrency_factor=10,
concurrency_ratio=5.0,
),
)
# Should fail when there's a better concurrency ratio already
assert not speedup_critic(
candidate_result=candidate_result_good,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
original_concurrency_metrics=original_concurrency,
best_concurrency_ratio_until_now=10.0, # Better ratio already exists
disable_gh_action_noise=True,
)
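# Editor's note: a hedged sketch of the concurrency branch these four cases exercise.
# The ~20% threshold and argument names are inferred from the tests, not copied from
# critic.py: a candidate passes on concurrency alone only if the ratio gain clears the
# threshold and beats the best ratio seen so far; otherwise runtime/throughput decide.
def _concurrency_branch_sketch(gain, candidate_ratio, best_ratio_until_now, threshold=0.20):
    if gain < threshold:
        return False
    return best_ratio_until_now is None or candidate_ratio > best_ratio_until_now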
def test_concurrency_ratio_display_formatting() -> None:
"""Test display-string formatting for concurrency ratios."""
orig_ratio = 0.05
cand_ratio = 0.15
conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
display_string = f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)"
assert display_string == "Concurrency ratio: 0.05x → 0.15x (+200.0%)"
orig_ratio = 1.0
cand_ratio = 10.0
conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
display_string = f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)"
assert display_string == "Concurrency ratio: 1.00x → 10.00x (+900.0%)"
orig_ratio = 0.01
cand_ratio = 0.03
conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
display_string = f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)"
assert display_string == "Concurrency ratio: 0.01x → 0.03x (+200.0%)"
def test_parse_concurrency_metrics() -> None:
"""Test parse_concurrency_metrics function."""
# Test with valid concurrency output
stdout = (
"!@######CONC:test_module:TestClass:test_func:my_function:0:10000000:1000000:10######@!\n"
"!@######CONC:test_module:TestClass:test_func:my_function:1:10000000:1000000:10######@!\n"
)
test_results = TestResults(perf_stdout=stdout)
metrics = parse_concurrency_metrics(test_results, "my_function")
assert metrics is not None
assert metrics.sequential_time_ns == 10_000_000 # Average of both matches
assert metrics.concurrent_time_ns == 1_000_000
assert metrics.concurrency_factor == 10
assert metrics.concurrency_ratio == 10.0 # 10000000 / 1000000
# Test with no matching function
metrics_wrong_func = parse_concurrency_metrics(test_results, "other_function")
assert metrics_wrong_func is None
# Test with empty stdout
empty_results = TestResults(perf_stdout="")
metrics_empty = parse_concurrency_metrics(empty_results, "my_function")
assert metrics_empty is None
# Test with None stdout
none_results = TestResults(perf_stdout=None)
metrics_none = parse_concurrency_metrics(none_results, "my_function")
assert metrics_none is None
# Test with no class name
stdout_no_class = "!@######CONC:test_module::test_func:my_function:0:5000000:2500000:10######@!\n"
test_results_no_class = TestResults(perf_stdout=stdout_no_class)
metrics_no_class = parse_concurrency_metrics(test_results_no_class, "my_function")
assert metrics_no_class is not None
assert metrics_no_class.concurrency_ratio == 2.0 # 5000000 / 2500000
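# Editor's note: a hedged sketch of how the CONC markers above could be parsed. The
# field order (module:class:test_func:function:loop:sequential_ns:concurrent_ns:factor)
# is read off the fixtures; the regex and return shape are illustrative, not the real
# parse_concurrency_metrics.
import re

_CONC_RE_SKETCH = re.compile(
    r"!@######CONC:[^:]*:[^:]*:[^:]*:(?P<func>[^:]+):\d+:(?P<seq>\d+):(?P<conc>\d+):(?P<factor>\d+)######@!"
)

def _parse_concurrency_sketch(stdout, function_name):
    matches = [m for m in _CONC_RE_SKETCH.finditer(stdout or "") if m.group("func") == function_name]
    if not matches:
        return None
    seq = sum(int(m.group("seq")) for m in matches) // len(matches)
    conc = sum(int(m.group("conc")) for m in matches) // len(matches)
    return {
        "sequential_time_ns": seq,
        "concurrent_time_ns": conc,
        "concurrency_factor": int(matches[0].group("factor")),
        "concurrency_ratio": seq / conc if conc else 0.0,
    }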

View file

@@ -131,6 +131,46 @@ async def async_function(x: int, y: int) -> int:
assert modified_code.strip() == expected_decorated_code.strip()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_decorator_application_concurrency_mode(temp_dir):
"""Test that CONCURRENCY mode applies the codeflash_concurrency_async decorator."""
async_function_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_concurrency_async
@codeflash_concurrency_async
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
test_file = temp_dir / "test_async.py"
test_file.write_text(async_function_code)
func = FunctionToOptimize(
function_name="async_function", file_path=test_file, parents=[], is_async=True
)
decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.CONCURRENCY)
assert decorator_added
modified_code = test_file.read_text()
assert modified_code.strip() == expected_decorated_code.strip()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_class_method_decorator_application(temp_dir):
async_class_code = '''

View file

@@ -5640,11 +5640,14 @@ wheels = [
[[package]]
name = "wheel"
version = "0.45.1"
version = "0.46.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8a/98/2d9906746cdc6a6ef809ae6338005b3f21bb568bea3165cfc6a243fdc25c/wheel-0.45.1.tar.gz", hash = "sha256:661e1abd9198507b1409a20c02106d9670b2576e916d58f520316666abca6729", size = 107545, upload-time = "2024-11-23T00:18:23.513Z" }
dependencies = [
{ name = "packaging" },
]
sdist = { url = "https://files.pythonhosted.org/packages/9b/92/3a64fa9639b8e290fe8630d8067a66f7c5510845c6d73686ad880c9b04d9/wheel-0.46.2.tar.gz", hash = "sha256:3d79e48fde9847618a5a181f3cc35764c349c752e2fe911e65fa17faab9809b0", size = 60274, upload-time = "2026-01-21T23:55:25.838Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/0b/2c/87f3254fd8ffd29e4c02732eee68a83a1d3c346ae39bc6822dcbcb697f2b/wheel-0.45.1-py3-none-any.whl", hash = "sha256:708e7481cc80179af0e556bbf0cc00b8444c7321e2700b8d8580231d13017248", size = 72494, upload-time = "2024-11-23T00:18:21.207Z" },
{ url = "https://files.pythonhosted.org/packages/13/2c/5e079cefe955ae58e5a052fe037c850ce493eb7269dedeb960237e78fb0f/wheel-0.46.2-py3-none-any.whl", hash = "sha256:33ae60725d69eaa249bc1982e739943c23b34b58d51f1cb6253453773aca6e65", size = 29971, upload-time = "2026-01-21T23:55:24.447Z" },
]
[[package]]