Mirror of https://github.com/codeflash-ai/codeflash.git
Synced 2026-05-04 18:25:17 +00:00

Merge branch 'main' of github.com:codeflash-ai/codeflash into multi-language

commit 3ca29563b7
18 changed files with 1445 additions and 46 deletions
@@ -531,6 +531,10 @@ class AiServiceClient:
        optimized_throughput: str | None = None,
        throughput_improvement: str | None = None,
        function_references: str | None = None,
        acceptance_reason: str | None = None,
        original_concurrency_ratio: str | None = None,
        optimized_concurrency_ratio: str | None = None,
        concurrency_improvement: str | None = None,
        codeflash_version: str = codeflash_version,
    ) -> str:
        """Optimize the given python code for performance by making a request to the Django endpoint.

@@ -551,8 +555,12 @@ class AiServiceClient:
        - original_throughput: str | None - throughput for the baseline code (operations per second)
        - optimized_throughput: str | None - throughput for the optimized code (operations per second)
        - throughput_improvement: str | None - throughput improvement percentage
        - current codeflash version
        - function_references: str | None - where the function is called in the codebase
        - acceptance_reason: str | None - why the optimization was accepted (runtime, throughput, or concurrency)
        - original_concurrency_ratio: str | None - concurrency ratio for the baseline code
        - optimized_concurrency_ratio: str | None - concurrency ratio for the optimized code
        - concurrency_improvement: str | None - concurrency improvement percentage
        - codeflash_version: str - current codeflash version

        Returns
        -------

@@ -576,6 +584,10 @@ class AiServiceClient:
            "optimized_throughput": optimized_throughput,
            "throughput_improvement": throughput_improvement,
            "function_references": function_references,
            "acceptance_reason": acceptance_reason,
            "original_concurrency_ratio": original_concurrency_ratio,
            "optimized_concurrency_ratio": optimized_concurrency_ratio,
            "concurrency_improvement": concurrency_improvement,
            "codeflash_version": codeflash_version,
            "call_sequence": self.get_next_sequence(),
        }
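Illustrative example, not part of the diff: the new request fields are plain strings, and hypothetical values (following the formatting used elsewhere in this change, ratios as "N.NNx" and improvements as "NN.N%") might look like this.

payload_extra = {
    "acceptance_reason": "concurrency",
    "original_concurrency_ratio": "1.05x",
    "optimized_concurrency_ratio": "8.40x",
    "concurrency_improvement": "700.0%",   # (8.40 - 1.05) / 1.05 = 7.0
    "throughput_improvement": "183.3%",
}
print(payload_extra["concurrency_improvement"])  # 700.0%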
@@ -4,6 +4,7 @@ import asyncio
import gc
import os
import sqlite3
import time
from enum import Enum
from functools import wraps
from pathlib import Path

@@ -165,3 +166,45 @@ def codeflash_performance_async(func: F) -> F:
        return return_value

    return async_wrapper


def codeflash_concurrency_async(func: F) -> F:
    """Measures concurrent vs sequential execution performance for async functions."""

    @wraps(func)
    async def async_wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
        function_name = func.__name__
        concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10"))

        test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
        test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "")
        test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "")
        loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0")

        # Phase 1: Sequential execution timing
        gc.disable()
        try:
            seq_start = time.perf_counter_ns()
            for _ in range(concurrency_factor):
                result = await func(*args, **kwargs)
            sequential_time = time.perf_counter_ns() - seq_start
        finally:
            gc.enable()

        # Phase 2: Concurrent execution timing
        gc.disable()
        try:
            conc_start = time.perf_counter_ns()
            tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
            await asyncio.gather(*tasks)
            concurrent_time = time.perf_counter_ns() - conc_start
        finally:
            gc.enable()

        # Output parseable metrics
        tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}"
        print(f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!")

        return result

    return async_wrapper
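A minimal usage sketch of the new decorator (import path and printed format are taken from this diff; the environment values are assumptions made for the example):

import asyncio
import os

from codeflash.code_utils.codeflash_wrap_decorator import codeflash_concurrency_async

# Assumed setup: the decorator reads these variables to build its output tag.
os.environ.setdefault("CODEFLASH_CONCURRENCY_FACTOR", "5")
os.environ.setdefault("CODEFLASH_TEST_MODULE", "example_module")
os.environ.setdefault("CODEFLASH_TEST_FUNCTION", "example_test")

@codeflash_concurrency_async
async def fetch(delay: float) -> str:
    await asyncio.sleep(delay)  # non-blocking: the concurrent phase should be ~5x faster
    return "ok"

# Running this prints one parseable line such as:
# !@######CONC:example_module::example_test:fetch:0:<seq_ns>:<conc_ns>:5######@!
asyncio.run(fetch(0.01))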
@@ -10,12 +10,20 @@ import sentry_sdk

from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_temp_dir

# Known CrossHair limitations that produce invalid Python syntax in generated tests:
# - "<locals>" - higher-order functions returning nested functions
# - " object at 0x" - objects with default __repr__
# - "<list_iterator" - iterator objects
CROSSHAIR_KNOWN_LIMITATION_PATTERNS = ("<locals>", " object at 0x", "<list_iterator")


def is_valid_concolic_test(test_code: str, project_root: Optional[str] = None) -> bool:
    try:
        ast.parse(test_code)
    except SyntaxError:
        sentry_sdk.capture_message(f"CrossHair generated test with syntax error:\n{test_code}")
        is_known_limitation = any(pattern in test_code for pattern in CROSSHAIR_KNOWN_LIMITATION_PATTERNS)
        if not is_known_limitation:
            sentry_sdk.capture_message(f"CrossHair generated test with syntax error:\n{test_code}")
        return False

    temp_path = (codeflash_temp_dir / f"concolic_test_{uuid.uuid4().hex}.py").resolve()
@@ -10,6 +10,8 @@ INDIVIDUAL_TESTCASE_TIMEOUT = 15
MAX_FUNCTION_TEST_SECONDS = 60
MIN_IMPROVEMENT_THRESHOLD = 0.05
MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10  # 10% minimum improvement for async throughput
MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD = 0.20  # 20% concurrency ratio improvement required
CONCURRENCY_FACTOR = 10  # Number of concurrent executions for concurrency benchmark
MAX_TEST_FUNCTION_RUNS = 50
MAX_CUMULATIVE_TEST_RUNTIME_NANOSECONDS = 100e6  # 100ms
TOTAL_LOOPING_TIME = 10.0  # 10 second candidate benchmarking budget
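A quick illustration, not part of the diff, of how the two new thresholds gate a candidate (numbers invented; the gain formulas follow the critic functions later in this change):

MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10
MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD = 0.20

throughput_improvement = (3400 - 3000) / 3000   # ~0.133 -> clears the 10% throughput bar
concurrency_improvement = (8.4 - 8.0) / 8.0     # 0.05  -> fails the 20% concurrency bar
print(throughput_improvement > MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD)    # True
print(concurrency_improvement > MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD)  # False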
@@ -1439,9 +1439,12 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
        self.added_decorator = False

        # Choose decorator based on mode
        self.decorator_name = (
            "codeflash_behavior_async" if mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
        )
        if mode == TestingMode.BEHAVIOR:
            self.decorator_name = "codeflash_behavior_async"
        elif mode == TestingMode.CONCURRENCY:
            self.decorator_name = "codeflash_concurrency_async"
        else:
            self.decorator_name = "codeflash_performance_async"

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        # Track when we enter a class
@@ -1484,12 +1487,14 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
                "codeflash_trace_async",
                "codeflash_behavior_async",
                "codeflash_performance_async",
                "codeflash_concurrency_async",
            }
        if isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name):
            return decorator_node.func.value in {
                "codeflash_trace_async",
                "codeflash_behavior_async",
                "codeflash_performance_async",
                "codeflash_concurrency_async",
            }
        return False

@@ -1501,6 +1506,14 @@ class AsyncDecoratorImportAdder(cst.CSTTransformer):
        self.mode = mode
        self.has_import = False

    def _get_decorator_name(self) -> str:
        """Get the decorator name based on the testing mode."""
        if self.mode == TestingMode.BEHAVIOR:
            return "codeflash_behavior_async"
        if self.mode == TestingMode.CONCURRENCY:
            return "codeflash_concurrency_async"
        return "codeflash_performance_async"

    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
        # Check if the async decorator import is already present
        if (
@@ -1512,9 +1525,7 @@ class AsyncDecoratorImportAdder(cst.CSTTransformer):
            and node.module.attr.value == "codeflash_wrap_decorator"
            and not isinstance(node.names, cst.ImportStar)
        ):
            decorator_name = (
                "codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
            )
            decorator_name = self._get_decorator_name()
            for import_alias in node.names:
                if import_alias.name.value == decorator_name:
                    self.has_import = True
@@ -1525,9 +1536,7 @@ class AsyncDecoratorImportAdder(cst.CSTTransformer):
            return updated_node

        # Choose import based on mode
        decorator_name = (
            "codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
        )
        decorator_name = self._get_decorator_name()

        # Parse the import statement into a CST node
        import_node = cst.parse_statement(f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}")
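For reference, a small standalone check (assuming libcst is installed) of what this generates when the mode maps to the new concurrency decorator:

import libcst as cst

decorator_name = "codeflash_concurrency_async"  # value _get_decorator_name() returns for CONCURRENCY mode
import_node = cst.parse_statement(
    f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}"
)
print(cst.Module(body=[import_node]).code.strip())
# from codeflash.code_utils.codeflash_wrap_decorator import codeflash_concurrency_async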
@@ -173,6 +173,7 @@ class BestOptimization(BaseModel):
    winning_replay_benchmarking_test_results: Optional[TestResults] = None
    line_profiler_test_results: dict
    async_throughput: Optional[int] = None
    concurrency_metrics: Optional[ConcurrencyMetrics] = None


@dataclass(frozen=True)
@@ -184,6 +185,14 @@ class BenchmarkKey:
        return f"{self.module_path}::{self.function_name}"


@dataclass
class ConcurrencyMetrics:
    sequential_time_ns: int
    concurrent_time_ns: int
    concurrency_factor: int
    concurrency_ratio: float  # sequential_time / concurrent_time


@dataclass
class BenchmarkDetail:
    benchmark_name: str
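An illustrative instantiation of the new dataclass (values invented, not from the diff):

from codeflash.models.models import ConcurrencyMetrics

# 10 sequential runs took 50 ms; the same 10 runs via asyncio.gather took 10 ms.
metrics = ConcurrencyMetrics(
    sequential_time_ns=50_000_000,
    concurrent_time_ns=10_000_000,
    concurrency_factor=10,
    concurrency_ratio=50_000_000 / 10_000_000,  # 5.0: concurrent execution is 5x faster
)
print(metrics.concurrency_ratio)  # 5.0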
@@ -381,6 +390,7 @@ class OptimizedCandidateResult(BaseModel):
    optimization_candidate_index: int
    total_candidate_timing: int
    async_throughput: Optional[int] = None
    concurrency_metrics: Optional[ConcurrencyMetrics] = None


class GeneratedTests(BaseModel):
@@ -602,6 +612,7 @@ class OriginalCodeBaseline(BaseModel):
    runtime: int
    coverage_results: Optional[CoverageData]
    async_throughput: Optional[int] = None
    concurrency_metrics: Optional[ConcurrencyMetrics] = None


class CoverageStatus(Enum):
@@ -693,6 +704,7 @@ class TestingMode(enum.Enum):
    BEHAVIOR = "behavior"
    PERFORMANCE = "performance"
    LINE_PROFILE = "line_profile"
    CONCURRENCY = "concurrency"


# TODO this class is duplicated in codeflash_capture
@@ -103,7 +103,9 @@ from codeflash.models.models import (
)
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
from codeflash.result.critic import (
    concurrency_gain,
    coverage_critic,
    get_acceptance_reason,
    performance_gain,
    quantity_of_tests_critic,
    speedup_critic,
@@ -115,7 +117,11 @@ from codeflash.verification.concolic_testing import generate_concolic_tests
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
from codeflash.verification.parse_test_output import calculate_function_throughput_from_test_results, parse_test_results
from codeflash.verification.parse_test_output import (
    calculate_function_throughput_from_test_results,
    parse_concurrency_metrics,
    parse_test_results,
)
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_line_profile_tests
from codeflash.verification.verification_utils import get_test_file_path
from codeflash.verification.verifier import generate_tests
@@ -128,6 +134,7 @@ if TYPE_CHECKING:
    from codeflash.models.models import (
        BenchmarkKey,
        CodeStringsMarkdown,
        ConcurrencyMetrics,
        CoverageData,
        FunctionCalledInTest,
        FunctionSource,
@@ -833,6 +840,13 @@ class FunctionOptimizer:
                tree.add(f"Optimized async throughput: {candidate_result.async_throughput} executions")
                tree.add(f"Throughput improvement: {throughput_gain_value * 100:.1f}%")
                tree.add(f"Throughput ratio: {throughput_gain_value + 1:.3f}X")

                # Display concurrency metrics if available
                if candidate_result.concurrency_metrics and original_code_baseline.concurrency_metrics:
                    orig_ratio = original_code_baseline.concurrency_metrics.concurrency_ratio
                    cand_ratio = candidate_result.concurrency_metrics.concurrency_ratio
                    conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
                    tree.add(f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)")
            else:
                tree.add("This candidate is faster than the original code. 🚀")
                tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")
@@ -851,6 +865,14 @@ class FunctionOptimizer:
                )
                tree.add(f"Async throughput: {candidate_result.async_throughput} executions")
                tree.add(f"Throughput change: {throughput_gain_value * 100:.1f}%")

                # Display concurrency metrics if available
                if candidate_result.concurrency_metrics and original_code_baseline.concurrency_metrics:
                    orig_ratio = original_code_baseline.concurrency_metrics.concurrency_ratio
                    cand_ratio = candidate_result.concurrency_metrics.concurrency_ratio
                    conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
                    tree.add(f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)")

            tree.add(
                f"(Runtime for reference: {humanize_runtime(candidate_result.best_test_runtime)} over "
                f"{candidate_result.max_loop_count} loop{'s' if candidate_result.max_loop_count > 1 else ''})"
@@ -917,6 +939,7 @@ class FunctionOptimizer:
            winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
            winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results,
            async_throughput=candidate_result.async_throughput,
            concurrency_metrics=candidate_result.concurrency_metrics,
        )

        return best_optimization, benchmark_tree
@@ -961,6 +984,7 @@ class FunctionOptimizer:
                winning_benchmarking_test_results=valid_opt.winning_benchmarking_test_results,
                winning_replay_benchmarking_test_results=valid_opt.winning_replay_benchmarking_test_results,
                async_throughput=valid_opt.async_throughput,
                concurrency_metrics=valid_opt.concurrency_metrics,
            )
            valid_candidates_with_shorter_code.append(new_best_opt)
            diff_lens_list.append(
@@ -1113,6 +1137,8 @@ class FunctionOptimizer:
                best_runtime_until_now=None,
                original_async_throughput=original_code_baseline.async_throughput,
                best_throughput_until_now=None,
                original_concurrency_metrics=original_code_baseline.concurrency_metrics,
                best_concurrency_ratio_until_now=None,
            ) and quantity_of_tests_critic(candidate_result)

            tree = self.build_runtime_info_tree(
@@ -1996,6 +2022,14 @@ class FunctionOptimizer:
            fto_benchmark_timings=self.function_benchmark_timings,
            total_benchmark_timings=self.total_benchmark_timings,
        )
        acceptance_reason = get_acceptance_reason(
            original_runtime_ns=original_code_baseline.runtime,
            optimized_runtime_ns=best_optimization.runtime,
            original_async_throughput=original_code_baseline.async_throughput,
            optimized_async_throughput=best_optimization.async_throughput,
            original_concurrency_metrics=original_code_baseline.concurrency_metrics,
            optimized_concurrency_metrics=best_optimization.concurrency_metrics,
        )
        explanation = Explanation(
            raw_explanation_message=best_optimization.candidate.explanation,
            winning_behavior_test_results=best_optimization.winning_behavior_test_results,
@@ -2007,6 +2041,9 @@ class FunctionOptimizer:
            benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None,
            original_async_throughput=original_code_baseline.async_throughput,
            best_async_throughput=best_optimization.async_throughput,
            original_concurrency_metrics=original_code_baseline.concurrency_metrics,
            best_concurrency_metrics=best_optimization.concurrency_metrics,
            acceptance_reason=acceptance_reason,
        )

        self.replace_function_and_helpers_with_optimized_code(
@@ -2106,6 +2143,9 @@ class FunctionOptimizer:
        original_throughput_str = None
        optimized_throughput_str = None
        throughput_improvement_str = None
        original_concurrency_ratio_str = None
        optimized_concurrency_ratio_str = None
        concurrency_improvement_str = None

        if (
            self.function_to_optimize.is_async
@@ -2120,6 +2160,14 @@ class FunctionOptimizer:
            )
            throughput_improvement_str = f"{throughput_improvement_value * 100:.1f}%"

        if original_code_baseline.concurrency_metrics is not None and best_optimization.concurrency_metrics is not None:
            original_concurrency_ratio_str = f"{original_code_baseline.concurrency_metrics.concurrency_ratio:.2f}x"
            optimized_concurrency_ratio_str = f"{best_optimization.concurrency_metrics.concurrency_ratio:.2f}x"
            conc_improvement_value = concurrency_gain(
                original_code_baseline.concurrency_metrics, best_optimization.concurrency_metrics
            )
            concurrency_improvement_str = f"{conc_improvement_value * 100:.1f}%"

        new_explanation_raw_str = self.aiservice_client.get_new_explanation(
            source_code=code_context.read_writable_code.flat,
            dependency_code=code_context.read_only_context_code,
@@ -2137,6 +2185,10 @@ class FunctionOptimizer:
            optimized_throughput=optimized_throughput_str,
            throughput_improvement=throughput_improvement_str,
            function_references=function_references,
            acceptance_reason=explanation.acceptance_reason.value,
            original_concurrency_ratio=original_concurrency_ratio_str,
            optimized_concurrency_ratio=optimized_concurrency_ratio_str,
            concurrency_improvement=concurrency_improvement_str,
        )
        new_explanation = Explanation(
            raw_explanation_message=new_explanation_raw_str or explanation.raw_explanation_message,
@@ -2149,6 +2201,9 @@ class FunctionOptimizer:
            benchmark_details=explanation.benchmark_details,
            original_async_throughput=explanation.original_async_throughput,
            best_async_throughput=explanation.best_async_throughput,
            original_concurrency_metrics=explanation.original_concurrency_metrics,
            best_concurrency_metrics=explanation.best_concurrency_metrics,
            acceptance_reason=explanation.acceptance_reason,
        )
        self.log_successful_optimization(new_explanation, generated_tests, exp_type)

@@ -2391,12 +2446,22 @@ class FunctionOptimizer:
        logger.debug(f"Total original code runtime (ns): {total_timing}")

        async_throughput = None
        concurrency_metrics = None
        if self.function_to_optimize.is_async:
            async_throughput = calculate_function_throughput_from_test_results(
                benchmarking_results, self.function_to_optimize.function_name
            )
            logger.debug(f"Original async function throughput: {async_throughput} calls/second")

            concurrency_metrics = self.run_concurrency_benchmark(
                code_context=code_context, original_helper_code=original_helper_code, test_env=test_env
            )
            if concurrency_metrics:
                logger.debug(
                    f"Original concurrency metrics: ratio={concurrency_metrics.concurrency_ratio:.2f}, "
                    f"seq={concurrency_metrics.sequential_time_ns}ns, conc={concurrency_metrics.concurrent_time_ns}ns"
                )

        if self.args.benchmark:
            replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
                self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
@@ -2411,6 +2476,7 @@ class FunctionOptimizer:
                coverage_results=coverage_results,
                line_profile_results=line_profile_results,
                async_throughput=async_throughput,
                concurrency_metrics=concurrency_metrics,
            ),
            functions_to_remove,
        )
@@ -2607,12 +2673,23 @@ class FunctionOptimizer:
        logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")

        candidate_async_throughput = None
        candidate_concurrency_metrics = None
        if self.function_to_optimize.is_async:
            candidate_async_throughput = calculate_function_throughput_from_test_results(
                candidate_benchmarking_results, self.function_to_optimize.function_name
            )
            logger.debug(f"Candidate async function throughput: {candidate_async_throughput} calls/second")

            # Run concurrency benchmark for candidate
            candidate_concurrency_metrics = self.run_concurrency_benchmark(
                code_context=code_context, original_helper_code=candidate_helper_code, test_env=test_env
            )
            if candidate_concurrency_metrics:
                logger.debug(
                    f"Candidate concurrency metrics: ratio={candidate_concurrency_metrics.concurrency_ratio:.2f}, "
                    f"seq={candidate_concurrency_metrics.sequential_time_ns}ns, conc={candidate_concurrency_metrics.concurrent_time_ns}ns"
                )

        if self.args.benchmark:
            candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(
                self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
@@ -2633,6 +2710,7 @@ class FunctionOptimizer:
                    optimization_candidate_index=optimization_candidate_index,
                    total_candidate_timing=total_candidate_timing,
                    async_throughput=candidate_async_throughput,
                    concurrency_metrics=candidate_concurrency_metrics,
                )
            )

@@ -2912,3 +2990,57 @@ class FunctionOptimizer:
                f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
            )
        return line_profile_results

    def run_concurrency_benchmark(
        self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], test_env: dict[str, str]
    ) -> ConcurrencyMetrics | None:
        """Run concurrency benchmark to measure sequential vs concurrent execution for async functions.

        This benchmark detects blocking vs non-blocking async code by comparing:
        - Sequential execution time (running N iterations one after another)
        - Concurrent execution time (running N iterations in parallel with asyncio.gather)

        Blocking code (like time.sleep) will have similar sequential and concurrent times.
        Non-blocking code (like asyncio.sleep) will be much faster when run concurrently.

        Returns:
            ConcurrencyMetrics if benchmark ran successfully, None otherwise.

        """
        if not self.function_to_optimize.is_async:
            return None

        from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function

        try:
            # Add concurrency decorator to the source function
            add_async_decorator_to_function(
                self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.CONCURRENCY
            )

            # Run the concurrency benchmark tests
            concurrency_results, _ = self.run_and_parse_tests(
                testing_type=TestingMode.PERFORMANCE,  # Use performance mode for running
                test_env=test_env,
                test_files=self.test_files,
                optimization_iteration=0,
                testing_time=5.0,  # Short benchmark time
                enable_coverage=False,
                code_context=code_context,
                pytest_min_loops=1,
                pytest_max_loops=3,
            )
        except Exception as e:
            logger.debug(f"Concurrency benchmark failed: {e}")
            return None
        finally:
            # Restore original code
            self.write_code_and_helpers(
                self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
            )

        # Parse concurrency metrics from stdout
        if concurrency_results and concurrency_results.perf_stdout:
            return parse_concurrency_metrics(concurrency_results, self.function_to_optimize.function_name)

        return None
@@ -1,10 +1,12 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING

from codeflash.code_utils import env_utils
from codeflash.code_utils.config_consts import (
    COVERAGE_THRESHOLD,
    MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD,
    MIN_IMPROVEMENT_THRESHOLD,
    MIN_TESTCASE_PASSED_THRESHOLD,
    MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD,
@@ -12,7 +14,14 @@ from codeflash.code_utils.config_consts import (
from codeflash.models import models

if TYPE_CHECKING:
    from codeflash.models.models import CoverageData, OptimizedCandidateResult, OriginalCodeBaseline
    from codeflash.models.models import ConcurrencyMetrics, CoverageData, OptimizedCandidateResult, OriginalCodeBaseline


class AcceptanceReason(Enum):
    RUNTIME = "runtime"
    THROUGHPUT = "throughput"
    CONCURRENCY = "concurrency"
    NONE = "none"


def performance_gain(*, original_runtime_ns: int, optimized_runtime_ns: int) -> float:
@@ -36,6 +45,22 @@ def throughput_gain(*, original_throughput: int, optimized_throughput: int) -> f
    return (optimized_throughput - original_throughput) / original_throughput


def concurrency_gain(original_metrics: ConcurrencyMetrics, optimized_metrics: ConcurrencyMetrics) -> float:
    """Calculate concurrency ratio improvement.

    Returns the relative improvement in concurrency ratio.
    Higher is better - means the optimized code scales better with concurrent execution.

    concurrency_ratio = sequential_time / concurrent_time
    A ratio of 10 means concurrent execution is 10x faster than sequential.
    """
    if original_metrics.concurrency_ratio == 0:
        return 0.0
    return (
        optimized_metrics.concurrency_ratio - original_metrics.concurrency_ratio
    ) / original_metrics.concurrency_ratio


def speedup_critic(
    candidate_result: OptimizedCandidateResult,
    original_code_runtime: int,
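A worked example of the gain formula above (values invented): a baseline ratio of 1.1 (essentially blocking) versus an optimized ratio of 8.8 gives (8.8 - 1.1) / 1.1 = 7.0, i.e. a 700% improvement, well above MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD.

from codeflash.models.models import ConcurrencyMetrics
from codeflash.result.critic import concurrency_gain

original = ConcurrencyMetrics(sequential_time_ns=55_000_000, concurrent_time_ns=50_000_000,
                              concurrency_factor=10, concurrency_ratio=1.1)
optimized = ConcurrencyMetrics(sequential_time_ns=44_000_000, concurrent_time_ns=5_000_000,
                               concurrency_factor=10, concurrency_ratio=8.8)
print(concurrency_gain(original, optimized))  # 7.0 -> +700%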
@@ -44,10 +69,12 @@ def speedup_critic(
    disable_gh_action_noise: bool = False,
    original_async_throughput: int | None = None,
    best_throughput_until_now: int | None = None,
    original_concurrency_metrics: ConcurrencyMetrics | None = None,
    best_concurrency_ratio_until_now: float | None = None,
) -> bool:
    """Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.

    Evaluates both runtime performance and async throughput improvements.
    Evaluates runtime performance, async throughput, and concurrency improvements.

    For runtime performance:
    - Ensures the optimization is actually faster than the original code, above the noise floor.
@@ -58,6 +85,10 @@ def speedup_critic(
    For async throughput (when available):
    - Evaluates throughput improvements using MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD
    - Throughput improvements complement runtime improvements for async functions

    For concurrency (when available):
    - Evaluates concurrency ratio improvements using MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD
    - Concurrency improvements detect when blocking calls are replaced with non-blocking equivalents
    """
    # Runtime performance evaluation
    noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
@@ -86,14 +117,78 @@ def speedup_critic(
            best_throughput_until_now is None or candidate_result.async_throughput > best_throughput_until_now
        )

    # Concurrency evaluation
    concurrency_improved = False
    concurrency_is_best = True
    if original_concurrency_metrics is not None and candidate_result.concurrency_metrics is not None:
        conc_gain = concurrency_gain(original_concurrency_metrics, candidate_result.concurrency_metrics)
        concurrency_improved = conc_gain > MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD
        concurrency_is_best = (
            best_concurrency_ratio_until_now is None
            or candidate_result.concurrency_metrics.concurrency_ratio > best_concurrency_ratio_until_now
        )

    # Accept if ANY of: runtime, throughput, or concurrency improves significantly
    if original_async_throughput is not None and candidate_result.async_throughput is not None:
        # When throughput data is available, accept if EITHER throughput OR runtime improves significantly
        throughput_acceptance = throughput_improved and throughput_is_best
        runtime_acceptance = runtime_improved and runtime_is_best
        return throughput_acceptance or runtime_acceptance
        concurrency_acceptance = concurrency_improved and concurrency_is_best
        return throughput_acceptance or runtime_acceptance or concurrency_acceptance
    return runtime_improved and runtime_is_best

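Restated as a small sketch (illustrative booleans, not the actual function): a candidate whose runtime regressed can still be accepted because its concurrency ratio improved, since any one of the three criteria suffices.

runtime_improved, runtime_is_best = False, True
throughput_improved, throughput_is_best = False, True
concurrency_improved, concurrency_is_best = True, True

accept = (
    (throughput_improved and throughput_is_best)
    or (runtime_improved and runtime_is_best)
    or (concurrency_improved and concurrency_is_best)
)
print(accept)  # True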
def get_acceptance_reason(
    original_runtime_ns: int,
    optimized_runtime_ns: int,
    *,
    original_async_throughput: int | None = None,
    optimized_async_throughput: int | None = None,
    original_concurrency_metrics: ConcurrencyMetrics | None = None,
    optimized_concurrency_metrics: ConcurrencyMetrics | None = None,
) -> AcceptanceReason:
    """Determine why an optimization was accepted.

    Returns the primary reason for acceptance, with priority:
    concurrency > throughput > runtime (for async code).
    """
    noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_runtime_ns < 10000 else MIN_IMPROVEMENT_THRESHOLD
    if env_utils.is_ci():
        noise_floor = noise_floor * 2

    perf_gain = performance_gain(original_runtime_ns=original_runtime_ns, optimized_runtime_ns=optimized_runtime_ns)
    runtime_improved = perf_gain > noise_floor

    throughput_improved = False
    if (
        original_async_throughput is not None
        and optimized_async_throughput is not None
        and original_async_throughput > 0
    ):
        throughput_gain_value = throughput_gain(
            original_throughput=original_async_throughput, optimized_throughput=optimized_async_throughput
        )
        throughput_improved = throughput_gain_value > MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD

    concurrency_improved = False
    if original_concurrency_metrics is not None and optimized_concurrency_metrics is not None:
        conc_gain = concurrency_gain(original_concurrency_metrics, optimized_concurrency_metrics)
        concurrency_improved = conc_gain > MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD

    # Return reason with priority: concurrency > throughput > runtime
    if original_async_throughput is not None and optimized_async_throughput is not None:
        if concurrency_improved:
            return AcceptanceReason.CONCURRENCY
        if throughput_improved:
            return AcceptanceReason.THROUGHPUT
        if runtime_improved:
            return AcceptanceReason.RUNTIME
        return AcceptanceReason.NONE

    if runtime_improved:
        return AcceptanceReason.RUNTIME
    return AcceptanceReason.NONE


def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | OriginalCodeBaseline) -> bool:
    test_results = candidate_result.behavior_test_results
    report = test_results.get_test_pass_fail_report_by_type()
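A hypothetical end-to-end check of the priority ordering (values invented; imports and signature follow this diff). Runtime halved and the concurrency ratio jumped, while throughput rose only 5% (below the 10% bar), so the priority picks CONCURRENCY.

from codeflash.models.models import ConcurrencyMetrics
from codeflash.result.critic import AcceptanceReason, get_acceptance_reason

reason = get_acceptance_reason(
    2_000_000,  # original_runtime_ns
    1_000_000,  # optimized_runtime_ns
    original_async_throughput=1_000,
    optimized_async_throughput=1_050,
    original_concurrency_metrics=ConcurrencyMetrics(
        sequential_time_ns=50_000_000, concurrent_time_ns=45_000_000,
        concurrency_factor=10, concurrency_ratio=50 / 45,
    ),
    optimized_concurrency_metrics=ConcurrencyMetrics(
        sequential_time_ns=50_000_000, concurrent_time_ns=10_000_000,
        concurrency_factor=10, concurrency_ratio=5.0,
    ),
)
assert reason is AcceptanceReason.CONCURRENCY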
@@ -11,8 +11,8 @@ from rich.table import Table

from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.models.models import BenchmarkDetail, TestResults
from codeflash.result.critic import throughput_gain
from codeflash.models.models import BenchmarkDetail, ConcurrencyMetrics, TestResults
from codeflash.result.critic import AcceptanceReason, concurrency_gain, throughput_gain


@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
@@ -27,31 +27,44 @@ class Explanation:
    benchmark_details: Optional[list[BenchmarkDetail]] = None
    original_async_throughput: Optional[int] = None
    best_async_throughput: Optional[int] = None
    original_concurrency_metrics: Optional[ConcurrencyMetrics] = None
    best_concurrency_metrics: Optional[ConcurrencyMetrics] = None
    acceptance_reason: AcceptanceReason = AcceptanceReason.RUNTIME

    @property
    def perf_improvement_line(self) -> str:
        # speedup property already handles choosing between runtime and throughput
        improvement_type = {
            AcceptanceReason.RUNTIME: "runtime",
            AcceptanceReason.THROUGHPUT: "throughput",
            AcceptanceReason.CONCURRENCY: "concurrency",
            AcceptanceReason.NONE: "",
        }.get(self.acceptance_reason, "")

        if improvement_type:
            return f"{self.speedup_pct} {improvement_type} improvement ({self.speedup_x} faster)."
        return f"{self.speedup_pct} improvement ({self.speedup_x} faster)."

    @property
    def speedup(self) -> float:
        runtime_improvement = (self.original_runtime_ns / self.best_runtime_ns) - 1

        # Use throughput improvement if we have async metrics and throughput is better
        """Returns the improvement value for the metric that caused acceptance."""
        if (
            self.original_async_throughput is not None
            self.acceptance_reason == AcceptanceReason.CONCURRENCY
            and self.original_concurrency_metrics
            and self.best_concurrency_metrics
        ):
            return concurrency_gain(self.original_concurrency_metrics, self.best_concurrency_metrics)

        if (
            self.acceptance_reason == AcceptanceReason.THROUGHPUT
            and self.original_async_throughput is not None
            and self.best_async_throughput is not None
            and self.original_async_throughput > 0
        ):
            throughput_improvement = throughput_gain(
            return throughput_gain(
                original_throughput=self.original_async_throughput, optimized_throughput=self.best_async_throughput
            )

            # Use throughput metrics if throughput improvement is better or runtime got worse
            if throughput_improvement > runtime_improvement or runtime_improvement <= 0:
                return throughput_improvement

        return runtime_improvement
        return (self.original_runtime_ns / self.best_runtime_ns) - 1

    @property
    def speedup_x(self) -> str:
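As a small arithmetic aside on the runtime fallback above (values invented): halving the runtime yields a speedup value of 1.0, which presumably surfaces as a 100% improvement.

original_runtime_ns, best_runtime_ns = 2_000_000, 1_000_000
print((original_runtime_ns / best_runtime_ns) - 1)  # 1.0 -> optimized code runs in half the time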
@@ -108,7 +121,22 @@ class Explanation:
            console.print(table)
            benchmark_info = cast("StringIO", console.file).getvalue() + "\n"  # Cast for mypy

        if self.original_async_throughput is not None and self.best_async_throughput is not None:
        if (
            self.acceptance_reason == AcceptanceReason.CONCURRENCY
            and self.original_concurrency_metrics
            and self.best_concurrency_metrics
        ):
            orig_ratio = self.original_concurrency_metrics.concurrency_ratio
            best_ratio = self.best_concurrency_metrics.concurrency_ratio
            performance_description = (
                f"Concurrency ratio improved from {orig_ratio:.2f}x to {best_ratio:.2f}x "
                f"(concurrent execution now {best_ratio:.2f}x faster than sequential)\n\n"
            )
        elif (
            self.acceptance_reason == AcceptanceReason.THROUGHPUT
            and self.original_async_throughput is not None
            and self.best_async_throughput is not None
        ):
            performance_description = (
                f"Throughput improved from {self.original_async_throughput} to {self.best_async_throughput} operations/second "
                f"(runtime: {original_runtime_human} → {best_runtime_human})\n\n"
@@ -24,6 +24,7 @@ HAS_TORCH = find_spec("torch") is not None
HAS_JAX = find_spec("jax") is not None
HAS_XARRAY = find_spec("xarray") is not None
HAS_TENSORFLOW = find_spec("tensorflow") is not None
HAS_NUMBA = find_spec("numba") is not None

# Pattern to match pytest temp directories: /tmp/pytest-of-<user>/pytest-<N>/
# These paths vary between test runs but are logically equivalent
@@ -156,6 +157,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001
            range,
            slice,
            OrderedDict,
            types.GenericAlias,
        ),
    ):
        return orig == new
@@ -296,8 +298,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001
            # fails at "ufunc 'isfinite' not supported for the input types"
            return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)])

        if isinstance(orig, (np.floating, np.complex64, np.complex128)):
            return np.isclose(orig, new)
        if isinstance(orig, (np.floating, np.complexfloating)):
            return np.isclose(orig, new, equal_nan=True)

        if isinstance(orig, (np.integer, np.bool_, np.byte)):
            return orig == new
@@ -383,6 +385,42 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001
        if isinstance(orig, torch.device):
            return orig == new

    if HAS_NUMBA:
        import numba  # type: ignore # noqa: PGH003
        from numba.core.dispatcher import Dispatcher  # type: ignore # noqa: PGH003
        from numba.typed import Dict as NumbaDict  # type: ignore # noqa: PGH003
        from numba.typed import List as NumbaList  # type: ignore # noqa: PGH003

        # Handle numba typed List
        if isinstance(orig, NumbaList):
            if len(orig) != len(new):
                return False
            return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))

        # Handle numba typed Dict
        if isinstance(orig, NumbaDict):
            if superset_obj:
                # Allow new dict to have more keys, but all orig keys must exist with equal values
                return all(key in new and comparator(orig[key], new[key], superset_obj) for key in orig)
            if len(orig) != len(new):
                return False
            for key in orig:
                if key not in new:
                    return False
                if not comparator(orig[key], new[key], superset_obj):
                    return False
            return True

        # Handle numba type objects (e.g., numba.int64, numba.float64, numba.Array, etc.)
        if isinstance(orig, numba.core.types.Type):
            return orig == new

        # Handle numba JIT-compiled functions (CPUDispatcher, etc.)
        if isinstance(orig, Dispatcher):
            # Compare by identity of the underlying Python function
            # Two JIT functions are equal if they wrap the same Python function
            return orig.py_func is new.py_func

    if HAS_PYRSISTENT:
        import pyrsistent  # type: ignore # noqa: PGH003

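A hedged usage sketch of the new numba handling (runs only if numba is installed; constructing numba.typed.List from an iterable is an assumption about the numba API):

from codeflash.verification.comparator import comparator

try:
    from numba.typed import List as NumbaList  # optional dependency
except ImportError:
    NumbaList = None

if NumbaList is not None:
    a = NumbaList([1, 2, 3])
    b = NumbaList([1, 2, 3])
    c = NumbaList([1, 2, 4])
    print(comparator(a, b))  # True: same length, element-wise equal
    print(comparator(a, c))  # False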
@@ -20,7 +20,14 @@ from codeflash.code_utils.code_utils import (
    module_name_from_file_path,
)
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType, VerificationType
from codeflash.models.models import (
    ConcurrencyMetrics,
    FunctionTestInvocation,
    InvocationId,
    TestResults,
    TestType,
    VerificationType,
)
from codeflash.verification.coverage_utils import CoverageUtils, JestCoverageUtils

if TYPE_CHECKING:
@@ -70,6 +77,54 @@ def calculate_function_throughput_from_test_results(test_results: TestResults, f
    return function_throughput


# Pattern for concurrency benchmark output:
# !@######CONC:module:class:test:function:loop_index:seq_time:conc_time:factor######@!
_concurrency_pattern = re.compile(r"!@######CONC:([^:]*):([^:]*):([^:]*):([^:]*):([^:]*):(\d+):(\d+):(\d+)######@!")


def parse_concurrency_metrics(test_results: TestResults, function_name: str) -> ConcurrencyMetrics | None:
    """Parse concurrency benchmark results from test output.

    Format: !@######CONC:module:class:test:function:loop_index:seq_time:conc_time:factor######@!

    Returns ConcurrencyMetrics with:
    - sequential_time_ns: Total time for N sequential executions
    - concurrent_time_ns: Total time for N concurrent executions
    - concurrency_factor: N (number of concurrent executions)
    - concurrency_ratio: sequential_time / concurrent_time (higher = better concurrency)
    """
    if not test_results.perf_stdout:
        return None

    matches = _concurrency_pattern.findall(test_results.perf_stdout)
    if not matches:
        return None

    # Aggregate metrics for the target function
    total_seq, total_conc, factor, count = 0, 0, 0, 0
    for match in matches:
        # match[3] is function_name
        if len(match) >= 8 and match[3] == function_name:
            total_seq += int(match[5])
            total_conc += int(match[6])
            factor = int(match[7])
            count += 1

    if count == 0:
        return None

    avg_seq = total_seq / count
    avg_conc = total_conc / count
    ratio = avg_seq / avg_conc if avg_conc > 0 else 1.0

    return ConcurrencyMetrics(
        sequential_time_ns=int(avg_seq),
        concurrent_time_ns=int(avg_conc),
        concurrency_factor=factor,
        concurrency_ratio=ratio,
    )


def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> Path | None:
    """Resolve test file path from pytest's test class path.

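A self-contained sketch of parsing one emitted line (output format and TestResults usage follow this diff; the values are invented):

from codeflash.models.models import TestResults
from codeflash.verification.parse_test_output import parse_concurrency_metrics

stdout = "!@######CONC:pkg.mod:TestCase:test_fn:fetch:1:50000000:10000000:5######@!\n"
metrics = parse_concurrency_metrics(TestResults(test_results=[], perf_stdout=stdout), "fetch")
assert metrics is not None
print(metrics.concurrency_ratio)  # 5.0 (50 ms sequential vs 10 ms concurrent)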
@@ -412,7 +467,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
        # Default to GENERATED_REGRESSION for Jest tests when test type can't be determined
        if test_type is None and is_jest:
            test_type = TestType.GENERATED_REGRESSION
            logger.debug(f"[PARSE-DEBUG] defaulting to GENERATED_REGRESSION (Jest)")
            logger.debug("[PARSE-DEBUG] defaulting to GENERATED_REGRESSION (Jest)")
        elif test_type is None:
            # Skip results where test type cannot be determined
            logger.debug(f"Skipping result for {test_function_name}: could not determine test type")
@@ -504,7 +559,7 @@ def _extract_jest_console_output(suite_elem) -> str:

    return raw_content

# ToDO: {Claude} we need to move to the support directory.
# TODO: {Claude} we need to move to the support directory.
def parse_jest_test_xml(
    test_xml_file_path: Path,
    test_files: TestFiles,
@@ -8,6 +8,7 @@ def run_test(expected_improvement_pct: int) -> bool:
    config = TestConfig(
        file_path="main.py",
        min_improvement_x=0.1,
        expected_acceptance_reason="concurrency",
        coverage_expectations=[
            CoverageExpectation(
                function_name="retry_with_backoff",
@@ -37,6 +37,7 @@ class TestConfig:
    benchmarks_root: Optional[pathlib.Path] = None
    use_worktree: bool = False
    no_gen_tests: bool = False
    expected_acceptance_reason: Optional[str] = None  # "runtime", "throughput", "concurrency"


def clear_directory(directory_path: str | pathlib.Path) -> None:
@@ -176,7 +177,7 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
        logging.error("Failed to find performance improvement message")
        return False

    improvement_match = re.search(r"📈 ([\d,]+)% improvement", stdout)
    improvement_match = re.search(r"📈 ([\d,]+)% (?:(\w+) )?improvement", stdout)
    if not improvement_match:
        logging.error("Could not find improvement percentage in output")
        return False
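A quick standalone check of the widened pattern (sample strings invented): the second group captures the optional improvement-type word and stays None for the old wording.

import re

pattern = r"📈 ([\d,]+)% (?:(\w+) )?improvement"
m = re.search(pattern, "📈 250% concurrency improvement (3.50x faster).")
print(m.group(1), m.group(2))  # 250 concurrency
m2 = re.search(pattern, "📈 42% improvement")
print(m2.group(1), m2.group(2))  # 42 None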
@@ -193,6 +194,15 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
        logging.error(f"Performance improvement rate {improvement_x}x not above {config.min_improvement_x}x")
        return False

    if config.expected_acceptance_reason is not None:
        actual_reason = improvement_match.group(2)
        if not actual_reason:
            logging.error("Could not find acceptance reason type in output")
            return False
        if actual_reason != config.expected_acceptance_reason:
            logging.error(f"Expected acceptance reason '{config.expected_acceptance_reason}', got '{actual_reason}'")
            return False

    if config.expected_unit_tests_count is not None:
        # Match the global test discovery message from optimizer.py which counts test invocations
        # Format: "Discovered X existing unit tests and Y replay tests in Z.Zs at /path/to/tests"
tests/test_async_concurrency_decorator.py (new file, 304 lines)

@@ -0,0 +1,304 @@
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_concurrency_async
|
||||
from codeflash.models.models import ConcurrencyMetrics, TestResults
|
||||
from codeflash.verification.parse_test_output import parse_concurrency_metrics
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
class TestConcurrencyAsyncDecorator:
|
||||
"""Integration tests for codeflash_concurrency_async decorator."""
|
||||
|
||||
@pytest.fixture
|
||||
def concurrency_env_setup(self, request):
|
||||
"""Set up environment variables for concurrency testing."""
|
||||
original_env = {}
|
||||
test_env = {
|
||||
"CODEFLASH_LOOP_INDEX": "1",
|
||||
"CODEFLASH_TEST_MODULE": __name__,
|
||||
"CODEFLASH_TEST_CLASS": "TestConcurrencyAsyncDecorator",
|
||||
"CODEFLASH_TEST_FUNCTION": request.node.name,
|
||||
"CODEFLASH_CONCURRENCY_FACTOR": "5", # Use smaller factor for faster tests
|
||||
}
|
||||
|
||||
for key, value in test_env.items():
|
||||
original_env[key] = os.environ.get(key)
|
||||
os.environ[key] = value
|
||||
|
||||
yield test_env
|
||||
|
||||
for key, original_value in original_env.items():
|
||||
if original_value is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = original_value
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrency_decorator_nonblocking_function(self, concurrency_env_setup, capsys):
|
||||
"""Test that non-blocking async functions show high concurrency ratio."""
|
||||
|
||||
@codeflash_concurrency_async
|
||||
async def nonblocking_sleep(duration: float) -> str:
|
||||
await asyncio.sleep(duration)
|
||||
return "done"
|
||||
|
||||
result = await nonblocking_sleep(0.01)
|
||||
|
||||
assert result == "done"
|
||||
|
||||
captured = capsys.readouterr()
|
||||
output = captured.out
|
||||
|
||||
# Verify the output format
|
||||
assert "!@######CONC:" in output
|
||||
assert "######@!" in output
|
||||
|
||||
# Parse the output manually to verify format
|
||||
lines = [line for line in output.strip().split("\n") if "!@######CONC:" in line]
|
||||
assert len(lines) == 1
|
||||
|
||||
line = lines[0]
|
||||
# Format: !@######CONC:{test_module}:{test_class}:{test_function}:{function_name}:{loop_index}:{seq_time}:{conc_time}:{factor}######@!
|
||||
assert "nonblocking_sleep" in line
|
||||
assert ":5######@!" in line # concurrency factor
|
||||
|
||||
# Extract timing values
|
||||
parts = line.replace("!@######CONC:", "").replace("######@!", "").split(":")
|
||||
# parts should be: [test_module, test_class, test_function, function_name, loop_index, seq_time, conc_time, factor]
|
||||
assert len(parts) == 8
|
||||
|
||||
seq_time = int(parts[5])
|
||||
conc_time = int(parts[6])
|
||||
factor = int(parts[7])
|
||||
|
||||
assert seq_time > 0
|
||||
assert conc_time > 0
|
||||
assert factor == 5
|
||||
|
||||
# For non-blocking async, concurrent time should be much less than sequential
|
||||
# Sequential runs 5 iterations of 10ms = ~50ms
|
||||
# Concurrent runs 5 iterations in parallel = ~10ms
|
||||
# So ratio should be around 5 (with some overhead tolerance)
|
||||
ratio = seq_time / conc_time if conc_time > 0 else 1.0
|
||||
assert ratio > 2.0, f"Non-blocking function should have ratio > 2.0, got {ratio}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrency_decorator_blocking_function(self, concurrency_env_setup, capsys):
|
||||
"""Test that blocking functions show low concurrency ratio (~1.0)."""
|
||||
|
||||
@codeflash_concurrency_async
|
||||
async def blocking_sleep(duration: float) -> str:
|
||||
time.sleep(duration) # Blocking sleep
|
||||
return "done"
|
||||
|
||||
result = await blocking_sleep(0.005) # 5ms blocking
|
||||
|
||||
assert result == "done"
|
||||
|
||||
captured = capsys.readouterr()
|
||||
output = captured.out
|
||||
|
||||
assert "!@######CONC:" in output
|
||||
|
||||
lines = [line for line in output.strip().split("\n") if "!@######CONC:" in line]
|
||||
assert len(lines) == 1
|
||||
|
||||
line = lines[0]
|
||||
parts = line.replace("!@######CONC:", "").replace("######@!", "").split(":")
|
||||
assert len(parts) == 8
|
||||
|
||||
seq_time = int(parts[5])
|
||||
conc_time = int(parts[6])
|
||||
|
||||
# For blocking code, sequential and concurrent times should be similar
|
||||
# Because time.sleep blocks the entire event loop
|
||||
ratio = seq_time / conc_time if conc_time > 0 else 1.0
|
||||
# Blocking code should have ratio close to 1.0 (within reasonable tolerance)
|
||||
assert ratio < 2.0, f"Blocking function should have ratio < 2.0, got {ratio}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrency_decorator_with_computation(self, concurrency_env_setup, capsys):
|
||||
"""Test concurrency with CPU-bound computation."""
|
||||
|
||||
@codeflash_concurrency_async
|
||||
async def compute_intensive(n: int) -> int:
|
||||
# CPU-bound work (blocked by GIL in concurrent execution)
|
||||
total = 0
|
||||
for i in range(n):
|
||||
total += i * i
|
||||
return total
|
||||
|
||||
result = await compute_intensive(10000)
|
||||
|
||||
assert result == sum(i * i for i in range(10000))
|
||||
|
||||
captured = capsys.readouterr()
|
||||
output = captured.out
|
||||
|
||||
assert "!@######CONC:" in output
|
||||
assert "compute_intensive" in output
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
|
||||
class TestParseConcurrencyMetrics:
|
||||
"""Integration tests for parse_concurrency_metrics function."""
|
||||
|
||||
def test_parse_concurrency_metrics_from_real_output(self):
|
||||
"""Test parsing concurrency metrics from simulated stdout."""
|
||||
# Simulate stdout from codeflash_concurrency_async decorator
|
||||
perf_stdout = """Some other output
|
||||
!@######CONC:test_module:TestClass:test_func:my_async_func:1:50000000:10000000:5######@!
|
||||
More output here
|
||||
"""
|
||||
test_results = TestResults(
|
||||
test_results=[],
|
||||
perf_stdout=perf_stdout,
|
||||
)
|
||||
|
||||
metrics = parse_concurrency_metrics(test_results, "my_async_func")
|
||||
|
||||
assert metrics is not None
|
||||
assert isinstance(metrics, ConcurrencyMetrics)
|
||||
assert metrics.sequential_time_ns == 50000000
|
||||
assert metrics.concurrent_time_ns == 10000000
|
||||
assert metrics.concurrency_factor == 5
|
||||
assert metrics.concurrency_ratio == 5.0 # 50M / 10M = 5.0
|
||||
|
||||
def test_parse_concurrency_metrics_multiple_entries(self):
|
||||
"""Test parsing when multiple concurrency entries exist."""
|
||||
perf_stdout = """!@######CONC:test_module:TestClass:test_func:target_func:1:40000000:10000000:5######@!
|
||||
!@######CONC:test_module:TestClass:test_func:target_func:2:60000000:10000000:5######@!
|
||||
!@######CONC:test_module:TestClass:test_func:other_func:1:30000000:15000000:5######@!
|
||||
"""
|
||||
test_results = TestResults(
|
||||
test_results=[],
|
||||
perf_stdout=perf_stdout,
|
||||
)
|
||||
|
||||
metrics = parse_concurrency_metrics(test_results, "target_func")
|
||||
|
||||
assert metrics is not None
|
||||
# Should average the two entries for target_func
|
||||
# (40M + 60M) / 2 = 50M seq, (10M + 10M) / 2 = 10M conc
|
||||
assert metrics.sequential_time_ns == 50000000
|
||||
assert metrics.concurrent_time_ns == 10000000
|
||||
        assert metrics.concurrency_ratio == 5.0

    def test_parse_concurrency_metrics_no_match(self):
        """Test parsing when function name doesn't match."""
        perf_stdout = """!@######CONC:test_module:TestClass:test_func:other_func:1:50000000:10000000:5######@!
"""
        test_results = TestResults(
            test_results=[],
            perf_stdout=perf_stdout,
        )

        metrics = parse_concurrency_metrics(test_results, "nonexistent_func")

        assert metrics is None

    def test_parse_concurrency_metrics_empty_stdout(self):
        """Test parsing with empty stdout."""
        test_results = TestResults(
            test_results=[],
            perf_stdout="",
        )

        metrics = parse_concurrency_metrics(test_results, "any_func")

        assert metrics is None

    def test_parse_concurrency_metrics_none_stdout(self):
        """Test parsing with None stdout."""
        test_results = TestResults(
            test_results=[],
            perf_stdout=None,
        )

        metrics = parse_concurrency_metrics(test_results, "any_func")

        assert metrics is None


@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
class TestConcurrencyRatioComparison:
    """Test comparing blocking vs non-blocking concurrency ratios."""

    @pytest.fixture
    def comparison_env_setup(self, request):
        """Set up environment variables for comparison testing."""
        original_env = {}
        test_env = {
            "CODEFLASH_LOOP_INDEX": "1",
            "CODEFLASH_TEST_MODULE": __name__,
            "CODEFLASH_TEST_CLASS": "TestConcurrencyRatioComparison",
            "CODEFLASH_TEST_FUNCTION": request.node.name,
            "CODEFLASH_CONCURRENCY_FACTOR": "10",
        }

        for key, value in test_env.items():
            original_env[key] = os.environ.get(key)
            os.environ[key] = value

        yield test_env

        for key, original_value in original_env.items():
            if original_value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = original_value

    @pytest.mark.asyncio
    async def test_blocking_vs_nonblocking_comparison(self, comparison_env_setup, capsys):
        """Compare concurrency ratios between blocking and non-blocking implementations."""

        @codeflash_concurrency_async
        async def blocking_impl() -> str:
            time.sleep(0.002)  # 2ms blocking
            return "blocking"

        @codeflash_concurrency_async
        async def nonblocking_impl() -> str:
            await asyncio.sleep(0.002)  # 2ms non-blocking
            return "nonblocking"

        # Run blocking version
        await blocking_impl()
        blocking_output = capsys.readouterr().out

        # Run non-blocking version
        await nonblocking_impl()
        nonblocking_output = capsys.readouterr().out

        # Parse blocking metrics
        blocking_line = [l for l in blocking_output.split("\n") if "!@######CONC:" in l][0]
        blocking_parts = blocking_line.replace("!@######CONC:", "").replace("######@!", "").split(":")
        blocking_seq = int(blocking_parts[5])
        blocking_conc = int(blocking_parts[6])
        blocking_ratio = blocking_seq / blocking_conc if blocking_conc > 0 else 1.0

        # Parse non-blocking metrics
        nonblocking_line = [l for l in nonblocking_output.split("\n") if "!@######CONC:" in l][0]
        nonblocking_parts = nonblocking_line.replace("!@######CONC:", "").replace("######@!", "").split(":")
        nonblocking_seq = int(nonblocking_parts[5])
        nonblocking_conc = int(nonblocking_parts[6])
        nonblocking_ratio = nonblocking_seq / nonblocking_conc if nonblocking_conc > 0 else 1.0

        # Non-blocking should have significantly higher concurrency ratio
        assert nonblocking_ratio > blocking_ratio, (
            f"Non-blocking ratio ({nonblocking_ratio:.2f}) should be greater than "
            f"blocking ratio ({blocking_ratio:.2f})"
        )

        # The difference should be substantial (non-blocking should be at least 2x better)
        ratio_improvement = nonblocking_ratio / blocking_ratio if blocking_ratio > 0 else 0
        assert ratio_improvement > 2.0, (
            f"Non-blocking should show >2x improvement in concurrency ratio, got {ratio_improvement:.2f}x"
        )
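
Aside for readers skimming the diff: the !@######CONC: sentinel parsed above appears to carry eight colon-separated fields — test module, test class, test function, target function, loop index, sequential time (ns), concurrent time (ns), and concurrency factor — which is why the test indexes fields 5 and 6. A minimal sketch of recovering a ratio from one marker line, assuming that field order (the helper name is illustrative, not part of this diff):

def parse_conc_marker_ratio(line: str) -> float:
    # Strip the sentinel prefix/suffix, then split the eight expected fields.
    fields = line.replace("!@######CONC:", "").replace("######@!", "").split(":")
    sequential_ns = int(fields[5])  # summed time when the calls were awaited one at a time
    concurrent_ns = int(fields[6])  # time when the same calls were run concurrently
    return sequential_ns / concurrent_ns if concurrent_ns > 0 else 1.0
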

@ -17,7 +17,14 @@ import pytest

from codeflash.either import Failure, Success
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType
from codeflash.verification.comparator import comparator, _extract_exception_from_message, _get_wrapped_exception
from codeflash.verification.comparator import (
    PYTEST_TEMP_PATH_PATTERN,
    _extract_exception_from_message,
    _get_wrapped_exception,
    _is_temp_path,
    _normalize_temp_path,
    comparator,
)
from codeflash.verification.equivalence import compare_test_results

@ -2911,16 +2918,378 @@ def test_numpy_dtypes() -> None:
    assert not comparator(dtypes.Int32DType(), np.dtype('float32'))


def test_numpy_extended_precision_types() -> None:
    """Test comparator for numpy extended precision types like clongdouble."""
    try:
        import numpy as np
    except ImportError:
        pytest.skip("numpy not available")

    # Test np.clongdouble (extended precision complex)
    c1 = np.clongdouble(1 + 2j)
    c2 = np.clongdouble(1 + 2j)
    c3 = np.clongdouble(1 + 3j)
    assert comparator(c1, c2)
    assert not comparator(c1, c3)

    # Test np.longdouble (extended precision float)
    l1 = np.longdouble(1.5)
    l2 = np.longdouble(1.5)
    l3 = np.longdouble(2.5)
    assert comparator(l1, l2)
    assert not comparator(l1, l3)

    # Test NaN handling for extended precision complex
    nan_c1 = np.clongdouble(complex(np.nan, 2))
    nan_c2 = np.clongdouble(complex(np.nan, 2))
    assert comparator(nan_c1, nan_c2)

    # Test NaN handling for extended precision float
    nan_l1 = np.longdouble(np.nan)
    nan_l2 = np.longdouble(np.nan)
    assert comparator(nan_l1, nan_l2)


def test_numpy_typing_types() -> None:
    """Test comparator for numpy.typing types like NDArray type aliases."""
    try:
        import numpy as np
        import numpy.typing as npt
    except ImportError:
        pytest.skip("numpy or numpy.typing not available")

    # Test NDArray type alias comparisons
    arr_type1 = npt.NDArray[np.float64]
    arr_type2 = npt.NDArray[np.float64]
    arr_type3 = npt.NDArray[np.int64]
    assert comparator(arr_type1, arr_type2)
    assert not comparator(arr_type1, arr_type3)

    # Test NBitBase (if it can be instantiated)
    try:
        nbit1 = npt.NBitBase()
        nbit2 = npt.NBitBase()
        # NBitBase instances with empty __dict__ should compare as equal
        assert comparator(nbit1, nbit2)
        # Also test with superset_obj=True
        assert comparator(nbit1, nbit2, superset_obj=True)
    except TypeError:
        # NBitBase may not be instantiable in all numpy versions
        pass


def test_numpy_typing_superset_obj() -> None:
    """Test comparator with superset_obj=True for numpy types."""
    try:
        import numpy as np
        import numpy.typing as npt
    except ImportError:
        pytest.skip("numpy or numpy.typing not available")

    # Test numpy arrays with object dtype containing dicts (superset scenario)
    a1 = np.array([{'a': 1}], dtype=object)
    a2 = np.array([{'a': 1, 'b': 2}], dtype=object)  # superset
    assert comparator(a1, a2, superset_obj=True)
    assert not comparator(a1, a2, superset_obj=False)

    # Test extended precision types with superset_obj=True
    c1 = np.clongdouble(1 + 2j)
    c2 = np.clongdouble(1 + 2j)
    assert comparator(c1, c2, superset_obj=True)

    l1 = np.longdouble(1.5)
    l2 = np.longdouble(1.5)
    assert comparator(l1, l2, superset_obj=True)

    # Test NDArray type alias with superset_obj=True
    arr_type1 = npt.NDArray[np.float64]
    arr_type2 = npt.NDArray[np.float64]
    assert comparator(arr_type1, arr_type2, superset_obj=True)

    # Test numpy structured arrays (np.void) with superset_obj=True
    dt = np.dtype([('name', 'S10'), ('age', np.int32)])
    a_struct = np.array([('Alice', 25)], dtype=dt)
    b_struct = np.array([('Alice', 25)], dtype=dt)
    assert comparator(a_struct[0], b_struct[0], superset_obj=True)

    # Test numpy random generators with superset_obj=True
    rng1 = np.random.default_rng(seed=42)
    rng2 = np.random.default_rng(seed=42)
    assert comparator(rng1, rng2, superset_obj=True)

    rs1 = np.random.RandomState(seed=42)
    rs2 = np.random.RandomState(seed=42)
    assert comparator(rs1, rs2, superset_obj=True)
def test_numba_typed_list() -> None:
    """Test comparator for numba.typed.List."""
    try:
        import numba
        from numba.typed import List as NumbaList
    except ImportError:
        pytest.skip("numba not available")

    # Test equal lists
    a = NumbaList([1, 2, 3])
    b = NumbaList([1, 2, 3])
    assert comparator(a, b)

    # Test different values
    c = NumbaList([1, 2, 4])
    assert not comparator(a, c)

    # Test different lengths
    d = NumbaList([1, 2, 3, 4])
    assert not comparator(a, d)

    # Test empty lists
    e = NumbaList.empty_list(item_type=numba.int64)
    f = NumbaList.empty_list(item_type=numba.int64)
    assert comparator(e, f)

    # Test nested values (floats)
    g = NumbaList([1.0, 2.0, 3.0])
    h = NumbaList([1.0, 2.0, 3.0])
    assert comparator(g, h)

    i = NumbaList([1.0, 2.0, 4.0])
    assert not comparator(g, i)


def test_numba_typed_dict() -> None:
    """Test comparator for numba.typed.Dict."""
    try:
        import numba
        from numba.typed import Dict as NumbaDict
    except ImportError:
        pytest.skip("numba not available")

    # Test equal dicts
    a = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    a["x"] = 1
    a["y"] = 2

    b = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    b["x"] = 1
    b["y"] = 2
    assert comparator(a, b)

    # Test different values
    c = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    c["x"] = 1
    c["y"] = 3
    assert not comparator(a, c)

    # Test different keys
    d = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    d["x"] = 1
    d["z"] = 2
    assert not comparator(a, d)

    # Test different lengths
    e = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    e["x"] = 1
    assert not comparator(a, e)

    # Test empty dicts
    f = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    g = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    assert comparator(f, g)


def test_numba_types() -> None:
    """Test comparator for numba type objects."""
    try:
        import numba
        from numba import types
    except ImportError:
        pytest.skip("numba not available")

    # Test basic numeric types from numba module
    assert comparator(numba.int64, numba.int64)
    assert comparator(numba.float64, numba.float64)
    assert comparator(numba.int32, numba.int32)
    assert comparator(numba.float32, numba.float32)

    # Test basic numeric types from numba.types module
    assert comparator(types.int64, types.int64)
    assert comparator(types.float64, types.float64)
    assert comparator(types.int8, types.int8)
    assert comparator(types.int16, types.int16)
    assert comparator(types.uint8, types.uint8)
    assert comparator(types.uint16, types.uint16)
    assert comparator(types.uint32, types.uint32)
    assert comparator(types.uint64, types.uint64)
    assert comparator(types.complex64, types.complex64)
    assert comparator(types.complex128, types.complex128)

    # Test different types
    assert not comparator(numba.int64, numba.float64)
    assert not comparator(numba.int32, numba.int64)
    assert not comparator(numba.float32, numba.float64)
    assert not comparator(types.int8, types.int16)
    assert not comparator(types.uint32, types.int32)
    assert not comparator(types.complex64, types.complex128)

    # Test boolean type
    assert comparator(numba.boolean, numba.boolean)
    assert comparator(types.boolean, types.boolean)
    assert not comparator(numba.boolean, numba.int64)

    # Test special types
    assert comparator(types.none, types.none)
    assert comparator(types.void, types.void)
    assert comparator(types.pyobject, types.pyobject)
    assert comparator(types.unicode_type, types.unicode_type)
    # Note: types.none and types.void are the same object in numba
    assert comparator(types.none, types.void)
    assert not comparator(types.unicode_type, types.pyobject)
    assert not comparator(types.none, types.int64)

    # Test array types
    arr_type1 = types.Array(numba.float64, 1, 'C')
    arr_type2 = types.Array(numba.float64, 1, 'C')
    arr_type3 = types.Array(numba.float64, 2, 'C')
    arr_type4 = types.Array(numba.int64, 1, 'C')
    arr_type5 = types.Array(numba.float64, 1, 'F')  # Fortran order

    assert comparator(arr_type1, arr_type2)
    assert not comparator(arr_type1, arr_type3)  # different ndim
    assert not comparator(arr_type1, arr_type4)  # different dtype
    assert not comparator(arr_type1, arr_type5)  # different layout

    # Test tuple types
    tuple_type1 = types.UniTuple(types.int64, 3)
    tuple_type2 = types.UniTuple(types.int64, 3)
    tuple_type3 = types.UniTuple(types.int64, 4)
    tuple_type4 = types.UniTuple(types.float64, 3)

    assert comparator(tuple_type1, tuple_type2)
    assert not comparator(tuple_type1, tuple_type3)  # different count
    assert not comparator(tuple_type1, tuple_type4)  # different dtype

    # Test heterogeneous tuple types
    hetero_tuple1 = types.Tuple([types.int64, types.float64])
    hetero_tuple2 = types.Tuple([types.int64, types.float64])
    hetero_tuple3 = types.Tuple([types.int64, types.int64])

    assert comparator(hetero_tuple1, hetero_tuple2)
    assert not comparator(hetero_tuple1, hetero_tuple3)

    # Test ListType and DictType
    list_type1 = types.ListType(types.int64)
    list_type2 = types.ListType(types.int64)
    list_type3 = types.ListType(types.float64)

    assert comparator(list_type1, list_type2)
    assert not comparator(list_type1, list_type3)

    dict_type1 = types.DictType(types.unicode_type, types.int64)
    dict_type2 = types.DictType(types.unicode_type, types.int64)
    dict_type3 = types.DictType(types.unicode_type, types.float64)
    dict_type4 = types.DictType(types.int64, types.int64)

    assert comparator(dict_type1, dict_type2)
    assert not comparator(dict_type1, dict_type3)  # different value type
    assert not comparator(dict_type1, dict_type4)  # different key type


def test_numba_jit_functions() -> None:
    """Test comparator for numba JIT-compiled functions."""
    try:
        from numba import jit
    except ImportError:
        pytest.skip("numba not available")

    @jit(nopython=True)
    def add(x, y):
        return x + y

    @jit(nopython=True)
    def add2(x, y):
        return x + y

    @jit(nopython=True)
    def multiply(x, y):
        return x * y

    # Compile the functions by calling them
    add(1, 2)
    add2(1, 2)
    multiply(1, 2)

    # Same function should compare equal to itself
    assert comparator(add, add)

    # Different functions (even with same code) should not compare equal
    # since they are distinct function objects
    assert not comparator(add, add2)

    # Different functions with different code should not compare equal
    assert not comparator(add, multiply)


def test_numba_superset_obj() -> None:
    """Test comparator for numba types with superset_obj=True."""
    try:
        import numba
        from numba.typed import Dict as NumbaDict
        from numba.typed import List as NumbaList
    except ImportError:
        pytest.skip("numba not available")

    # Test NumbaDict with superset_obj=True
    orig_dict = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    orig_dict["x"] = 1
    orig_dict["y"] = 2

    # New dict with same keys - should pass
    new_dict_same = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    new_dict_same["x"] = 1
    new_dict_same["y"] = 2
    assert comparator(orig_dict, new_dict_same, superset_obj=True)

    # New dict with extra keys - should pass with superset_obj=True
    new_dict_superset = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    new_dict_superset["x"] = 1
    new_dict_superset["y"] = 2
    new_dict_superset["z"] = 3
    assert comparator(orig_dict, new_dict_superset, superset_obj=True)
    # But should fail with superset_obj=False
    assert not comparator(orig_dict, new_dict_superset, superset_obj=False)

    # New dict missing keys - should fail even with superset_obj=True
    new_dict_subset = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    new_dict_subset["x"] = 1
    assert not comparator(orig_dict, new_dict_subset, superset_obj=True)

    # New dict with different values - should fail
    new_dict_diff = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    new_dict_diff["x"] = 1
    new_dict_diff["y"] = 99
    assert not comparator(orig_dict, new_dict_diff, superset_obj=True)

    # Test NumbaList with superset_obj=True (lists don't support superset semantics)
    orig_list = NumbaList([1, 2, 3])
    new_list_same = NumbaList([1, 2, 3])
    new_list_longer = NumbaList([1, 2, 3, 4])

    assert comparator(orig_list, new_list_same, superset_obj=True)
    # Lists must have same length regardless of superset_obj
    assert not comparator(orig_list, new_list_longer, superset_obj=True)

    # Test empty dict with superset_obj=True
    empty_orig = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    non_empty_new = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    non_empty_new["a"] = 1
    # Empty orig should match any superset
    assert comparator(empty_orig, non_empty_new, superset_obj=True)
    assert not comparator(empty_orig, non_empty_new, superset_obj=False)


# =============================================================================
# Tests for pytest temp path normalization (lines 28-69 in comparator.py)
# =============================================================================

from codeflash.verification.comparator import (
    PYTEST_TEMP_PATH_PATTERN,
    _is_temp_path,
    _normalize_temp_path,
)


class TestIsTempPath:
    """Tests for the _is_temp_path() function."""

@ -5,6 +5,7 @@ from unittest.mock import Mock
from codeflash.code_utils.env_utils import get_pr_number
from codeflash.models.models import (
    CodeOptimizationContext,
    ConcurrencyMetrics,
    CoverageData,
    CoverageStatus,
    FunctionCoverage,

@ -15,12 +16,14 @@ from codeflash.models.models import (
    TestType,
)
from codeflash.result.critic import (
    concurrency_gain,
    coverage_critic,
    performance_gain,
    quantity_of_tests_critic,
    speedup_critic,
    throughput_gain,
)
from codeflash.verification.parse_test_output import parse_concurrency_metrics


def test_performance_gain() -> None:

@ -569,3 +572,238 @@ def test_speedup_critic_with_async_throughput() -> None:
        best_throughput_until_now=None,
        disable_gh_action_noise=True
    )


def test_concurrency_gain() -> None:
    """Test concurrency_gain calculation."""
    # Test basic concurrency improvement (blocking -> non-blocking)
    original = ConcurrencyMetrics(
        sequential_time_ns=10_000_000,  # 10ms
        concurrent_time_ns=10_000_000,  # 10ms (no speedup - blocking)
        concurrency_factor=10,
        concurrency_ratio=1.0,  # sequential/concurrent = 1.0
    )
    optimized = ConcurrencyMetrics(
        sequential_time_ns=10_000_000,  # 10ms
        concurrent_time_ns=1_000_000,  # 1ms (10x speedup - non-blocking)
        concurrency_factor=10,
        concurrency_ratio=10.0,  # sequential/concurrent = 10.0
    )
    # 900% improvement: (10 - 1) / 1 = 9.0
    assert concurrency_gain(original, optimized) == 9.0

    # Test no improvement
    same = ConcurrencyMetrics(
        sequential_time_ns=10_000_000,
        concurrent_time_ns=10_000_000,
        concurrency_factor=10,
        concurrency_ratio=1.0,
    )
    assert concurrency_gain(original, same) == 0.0

    # Test slight improvement
    slightly_better = ConcurrencyMetrics(
        sequential_time_ns=10_000_000,
        concurrent_time_ns=8_000_000,
        concurrency_factor=10,
        concurrency_ratio=1.25,
    )
    # 25% improvement: (1.25 - 1.0) / 1.0 = 0.25
    assert concurrency_gain(original, slightly_better) == 0.25

    # Test zero original ratio (edge case)
    zero_ratio = ConcurrencyMetrics(
        sequential_time_ns=0,
        concurrent_time_ns=1_000_000,
        concurrency_factor=10,
        concurrency_ratio=0.0,
    )
    assert concurrency_gain(zero_ratio, optimized) == 0.0
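
The arithmetic in the comments above is the whole story: concurrency_gain reads as the relative change of the sequential/concurrent ratio, with a non-positive original ratio short-circuiting to zero. A minimal sketch consistent with these assertions (the real implementation lives in codeflash/result/critic.py and may differ in detail):

def concurrency_gain_sketch(original_ratio: float, optimized_ratio: float) -> float:
    # Relative improvement of the concurrency ratio; guard the degenerate zero case.
    if original_ratio <= 0:
        return 0.0
    return (optimized_ratio - original_ratio) / original_ratio

# 1.0 -> 10.0 yields 9.0 (+900%), and 1.0 -> 1.25 yields 0.25 (+25%), matching the asserts above.
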


def test_speedup_critic_with_concurrency_metrics() -> None:
    """Test speedup_critic with concurrency metrics evaluation."""
    original_code_runtime = 10000  # 10 microseconds
    original_async_throughput = 100

    # Original concurrency metrics (blocking code - ratio ~= 1.0)
    original_concurrency = ConcurrencyMetrics(
        sequential_time_ns=10_000_000,
        concurrent_time_ns=10_000_000,
        concurrency_factor=10,
        concurrency_ratio=1.0,
    )

    # Test case 1: Concurrency improves significantly (blocking -> non-blocking)
    candidate_result = OptimizedCandidateResult(
        max_loop_count=5,
        best_test_runtime=10000,  # Same runtime
        behavior_test_results=TestResults(),
        benchmarking_test_results=TestResults(),
        optimization_candidate_index=0,
        total_candidate_timing=10000,
        async_throughput=100,  # Same throughput
        concurrency_metrics=ConcurrencyMetrics(
            sequential_time_ns=10_000_000,
            concurrent_time_ns=1_000_000,  # 10x faster concurrent execution
            concurrency_factor=10,
            concurrency_ratio=10.0,  # 900% improvement
        ),
    )

    # Should pass due to concurrency improvement even though runtime/throughput unchanged
    assert speedup_critic(
        candidate_result=candidate_result,
        original_code_runtime=original_code_runtime,
        best_runtime_until_now=None,
        original_async_throughput=original_async_throughput,
        best_throughput_until_now=None,
        original_concurrency_metrics=original_concurrency,
        best_concurrency_ratio_until_now=None,
        disable_gh_action_noise=True,
    )

    # Test case 2: No concurrency improvement (should fall back to other metrics)
    candidate_result_no_conc = OptimizedCandidateResult(
        max_loop_count=5,
        best_test_runtime=8000,  # 20% runtime improvement
        behavior_test_results=TestResults(),
        benchmarking_test_results=TestResults(),
        optimization_candidate_index=0,
        total_candidate_timing=8000,
        async_throughput=100,
        concurrency_metrics=ConcurrencyMetrics(
            sequential_time_ns=10_000_000,
            concurrent_time_ns=10_000_000,
            concurrency_factor=10,
            concurrency_ratio=1.0,  # No improvement
        ),
    )

    # Should pass due to runtime improvement
    assert speedup_critic(
        candidate_result=candidate_result_no_conc,
        original_code_runtime=original_code_runtime,
        best_runtime_until_now=None,
        original_async_throughput=original_async_throughput,
        best_throughput_until_now=None,
        original_concurrency_metrics=original_concurrency,
        best_concurrency_ratio_until_now=None,
        disable_gh_action_noise=True,
    )

    # Test case 3: Concurrency below threshold (20% required)
    candidate_result_below_threshold = OptimizedCandidateResult(
        max_loop_count=5,
        best_test_runtime=10000,  # Same runtime
        behavior_test_results=TestResults(),
        benchmarking_test_results=TestResults(),
        optimization_candidate_index=0,
        total_candidate_timing=10000,
        async_throughput=100,  # Same throughput
        concurrency_metrics=ConcurrencyMetrics(
            sequential_time_ns=10_000_000,
            concurrent_time_ns=9_000_000,  # Only 11% improvement
            concurrency_factor=10,
            concurrency_ratio=1.11,
        ),
    )

    # Should fail - no metric improves enough
    assert not speedup_critic(
        candidate_result=candidate_result_below_threshold,
        original_code_runtime=original_code_runtime,
        best_runtime_until_now=None,
        original_async_throughput=original_async_throughput,
        best_throughput_until_now=None,
        original_concurrency_metrics=original_concurrency,
        best_concurrency_ratio_until_now=None,
        disable_gh_action_noise=True,
    )

    # Test case 4: best_concurrency_ratio_until_now comparison
    candidate_result_good = OptimizedCandidateResult(
        max_loop_count=5,
        best_test_runtime=10000,
        behavior_test_results=TestResults(),
        benchmarking_test_results=TestResults(),
        optimization_candidate_index=0,
        total_candidate_timing=10000,
        async_throughput=100,
        concurrency_metrics=ConcurrencyMetrics(
            sequential_time_ns=10_000_000,
            concurrent_time_ns=2_000_000,
            concurrency_factor=10,
            concurrency_ratio=5.0,
        ),
    )

    # Should fail when there's a better concurrency ratio already
    assert not speedup_critic(
        candidate_result=candidate_result_good,
        original_code_runtime=original_code_runtime,
        best_runtime_until_now=None,
        original_async_throughput=original_async_throughput,
        best_throughput_until_now=None,
        original_concurrency_metrics=original_concurrency,
        best_concurrency_ratio_until_now=10.0,  # Better ratio already exists
        disable_gh_action_noise=True,
    )


def test_concurrency_ratio_display_formatting() -> None:
    orig_ratio = 0.05
    cand_ratio = 0.15
    conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
    display_string = f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)"
    assert display_string == "Concurrency ratio: 0.05x → 0.15x (+200.0%)"

    orig_ratio = 1.0
    cand_ratio = 10.0
    conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
    display_string = f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)"
    assert display_string == "Concurrency ratio: 1.00x → 10.00x (+900.0%)"

    orig_ratio = 0.01
    cand_ratio = 0.03
    conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
    display_string = f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)"
    assert display_string == "Concurrency ratio: 0.01x → 0.03x (+200.0%)"


def test_parse_concurrency_metrics() -> None:
    """Test parse_concurrency_metrics function."""
    # Test with valid concurrency output
    stdout = (
        "!@######CONC:test_module:TestClass:test_func:my_function:0:10000000:1000000:10######@!\n"
        "!@######CONC:test_module:TestClass:test_func:my_function:1:10000000:1000000:10######@!\n"
    )
    test_results = TestResults(perf_stdout=stdout)

    metrics = parse_concurrency_metrics(test_results, "my_function")
    assert metrics is not None
    assert metrics.sequential_time_ns == 10_000_000  # Average of both matches
    assert metrics.concurrent_time_ns == 1_000_000
    assert metrics.concurrency_factor == 10
    assert metrics.concurrency_ratio == 10.0  # 10000000 / 1000000

    # Test with no matching function
    metrics_wrong_func = parse_concurrency_metrics(test_results, "other_function")
    assert metrics_wrong_func is None

    # Test with empty stdout
    empty_results = TestResults(perf_stdout="")
    metrics_empty = parse_concurrency_metrics(empty_results, "my_function")
    assert metrics_empty is None

    # Test with None stdout
    none_results = TestResults(perf_stdout=None)
    metrics_none = parse_concurrency_metrics(none_results, "my_function")
    assert metrics_none is None

    # Test with no class name
    stdout_no_class = "!@######CONC:test_module::test_func:my_function:0:5000000:2500000:10######@!\n"
    test_results_no_class = TestResults(perf_stdout=stdout_no_class)
    metrics_no_class = parse_concurrency_metrics(test_results_no_class, "my_function")
    assert metrics_no_class is not None
    assert metrics_no_class.concurrency_ratio == 2.0  # 5000000 / 2500000
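
Read together, the four speedup_critic cases above pin down the acceptance shape for the concurrency path: a candidate is kept when its concurrency ratio improves on the original by at least roughly 20% and also beats the best ratio seen so far; otherwise the critic falls back to the runtime and throughput checks. A rough sketch of just that branch, written only to summarize the expectations encoded in these tests (the names and the 0.2 threshold are inferred from the cases, not copied from the implementation):

def passes_concurrency_check(original_ratio: float, candidate_ratio: float,
                             best_ratio_until_now: float | None,
                             min_relative_gain: float = 0.2) -> bool:
    # Case 4: a candidate that does not beat the best ratio seen so far is rejected.
    if best_ratio_until_now is not None and candidate_ratio <= best_ratio_until_now:
        return False
    # Case 3: an 11% improvement is below the bar; case 1: +900% clears it easily.
    gain = (candidate_ratio - original_ratio) / original_ratio if original_ratio > 0 else 0.0
    return gain >= min_relative_gain

Case 2 passes on the 20% runtime improvement alone, which is why a flat concurrency ratio by itself is not disqualifying.
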

@ -131,6 +131,46 @@ async def async_function(x: int, y: int) -> int:
    assert modified_code.strip() == expected_decorated_code.strip()


@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_decorator_application_concurrency_mode(temp_dir):
    """Test that CONCURRENCY mode applies the codeflash_concurrency_async decorator."""
    async_function_code = '''
import asyncio

async def async_function(x: int, y: int) -> int:
    """Simple async function for testing."""
    await asyncio.sleep(0.01)
    return x * y
'''

    expected_decorated_code = '''
import asyncio

from codeflash.code_utils.codeflash_wrap_decorator import \\
    codeflash_concurrency_async


@codeflash_concurrency_async
async def async_function(x: int, y: int) -> int:
    """Simple async function for testing."""
    await asyncio.sleep(0.01)
    return x * y
'''

    test_file = temp_dir / "test_async.py"
    test_file.write_text(async_function_code)

    func = FunctionToOptimize(
        function_name="async_function", file_path=test_file, parents=[], is_async=True
    )

    decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.CONCURRENCY)

    assert decorator_added
    modified_code = test_file.read_text()
    assert modified_code.strip() == expected_decorated_code.strip()


@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_class_method_decorator_application(temp_dir):
    async_class_code = '''

uv.lock

@ -5640,11 +5640,14 @@ wheels = [

[[package]]
name = "wheel"
version = "0.45.1"
version = "0.46.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8a/98/2d9906746cdc6a6ef809ae6338005b3f21bb568bea3165cfc6a243fdc25c/wheel-0.45.1.tar.gz", hash = "sha256:661e1abd9198507b1409a20c02106d9670b2576e916d58f520316666abca6729", size = 107545, upload-time = "2024-11-23T00:18:23.513Z" }
dependencies = [
    { name = "packaging" },
]
sdist = { url = "https://files.pythonhosted.org/packages/9b/92/3a64fa9639b8e290fe8630d8067a66f7c5510845c6d73686ad880c9b04d9/wheel-0.46.2.tar.gz", hash = "sha256:3d79e48fde9847618a5a181f3cc35764c349c752e2fe911e65fa17faab9809b0", size = 60274, upload-time = "2026-01-21T23:55:25.838Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/0b/2c/87f3254fd8ffd29e4c02732eee68a83a1d3c346ae39bc6822dcbcb697f2b/wheel-0.45.1-py3-none-any.whl", hash = "sha256:708e7481cc80179af0e556bbf0cc00b8444c7321e2700b8d8580231d13017248", size = 72494, upload-time = "2024-11-23T00:18:21.207Z" },
    { url = "https://files.pythonhosted.org/packages/13/2c/5e079cefe955ae58e5a052fe037c850ce493eb7269dedeb960237e78fb0f/wheel-0.46.2-py3-none-any.whl", hash = "sha256:33ae60725d69eaa249bc1982e739943c23b34b58d51f1cb6253453773aca6e65", size = 29971, upload-time = "2026-01-21T23:55:24.447Z" },
]

[[package]]