Merge branch 'main' into jit-docs

Aseem Saxena 2026-01-26 21:16:06 -08:00 committed by GitHub
commit 2dcfba6949
46 changed files with 4160 additions and 1657 deletions


@@ -22,11 +22,10 @@ jobs:
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
version: "0.5.30"
- name: sync uv
run: |
uv venv --seed
uv sync

.gitignore

@@ -259,3 +259,6 @@ WARP.MD
.mcp.json
.tessl/
tessl.json
# Tessl auto-generates AGENTS.md on install; ignore to avoid cluttering git status
AGENTS.md


@@ -88,6 +88,10 @@ else:
- Commit message body should be concise (1-2 sentences max)
- PR titles should also use conventional format
<!-- Section below is auto-generated by `tessl install` - do not edit manually -->
# Agent Rules <!-- tessl-managed -->
@.tessl/RULES.md follow the [instructions](.tessl/RULES.md)
@AGENTS.md


@@ -3,7 +3,7 @@ Business Source License 1.1
Parameters
Licensor: CodeFlash Inc.
Licensed Work: Codeflash Client version 0.19.x
Licensed Work: Codeflash Client version 0.20.x
The Licensed Work is (c) 2024 CodeFlash Inc.
Additional Use Grant: None. Production use of the Licensed Work is only permitted
@@ -13,7 +13,7 @@ Additional Use Grant: None. Production use of the Licensed Work is only permitte
Platform. Please visit codeflash.ai for further
information.
Change Date: 2029-12-21
Change Date: 2030-01-26
Change License: MIT


@@ -224,7 +224,7 @@ class AiServiceClient:
logger.info("!lsp|Rewriting as a JIT function…")
console.rule()
try:
response = self.make_ai_service_request("/rewrite_jit", payload=payload, timeout=60)
response = self.make_ai_service_request("/rewrite_jit", payload=payload, timeout=self.timeout)
except requests.exceptions.RequestException as e:
logger.exception(f"Error generating jit rewritten candidate: {e}")
ph("cli-jit-rewrite-error-caught", {"error": str(e)})
@@ -460,6 +460,10 @@ class AiServiceClient:
optimized_throughput: str | None = None,
throughput_improvement: str | None = None,
function_references: str | None = None,
acceptance_reason: str | None = None,
original_concurrency_ratio: str | None = None,
optimized_concurrency_ratio: str | None = None,
concurrency_improvement: str | None = None,
codeflash_version: str = codeflash_version,
) -> str:
"""Optimize the given python code for performance by making a request to the Django endpoint.
@@ -480,8 +484,12 @@ class AiServiceClient:
- original_throughput: str | None - throughput for the baseline code (operations per second)
- optimized_throughput: str | None - throughput for the optimized code (operations per second)
- throughput_improvement: str | None - throughput improvement percentage
- current codeflash version
- function_references: str | None - where the function is called in the codebase
- acceptance_reason: str | None - why the optimization was accepted (runtime, throughput, or concurrency)
- original_concurrency_ratio: str | None - concurrency ratio for the baseline code
- optimized_concurrency_ratio: str | None - concurrency ratio for the optimized code
- concurrency_improvement: str | None - concurrency improvement percentage
- codeflash_version: str - current codeflash version
Returns
-------
@@ -505,6 +513,10 @@ class AiServiceClient:
"optimized_throughput": optimized_throughput,
"throughput_improvement": throughput_improvement,
"function_references": function_references,
"acceptance_reason": acceptance_reason,
"original_concurrency_ratio": original_concurrency_ratio,
"optimized_concurrency_ratio": optimized_concurrency_ratio,
"concurrency_improvement": concurrency_improvement,
"codeflash_version": codeflash_version,
"call_sequence": self.get_next_sequence(),
}


@@ -84,6 +84,9 @@ def parse_args() -> Namespace:
parser.add_argument(
"--no-gen-tests", action="store_true", help="Do not generate tests, use only existing tests for optimization."
)
parser.add_argument(
"--no-jit-opts", action="store_true", help="Do not generate JIT-compiled optimizations for numerical code."
)
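# Illustrative sketch (not part of the diff): the new flag lands on the parsed
# namespace like any other store_true option, e.g. when invoked as
# `codeflash --no-jit-opts`:
#     args = parse_args()
#     if args.no_jit_opts:
#         ...  # caller skips generating JIT rewrite candidates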
parser.add_argument("--staging-review", action="store_true", help="Upload optimizations to staging for review")
parser.add_argument(
"--verify-setup",


@@ -25,12 +25,117 @@ if TYPE_CHECKING:
from codeflash.models.models import FunctionSource
class GlobalFunctionCollector(cst.CSTVisitor):
"""Collects all module-level function definitions (not inside classes or other functions)."""
def __init__(self) -> None:
super().__init__()
self.functions: dict[str, cst.FunctionDef] = {}
self.function_order: list[str] = []
self.scope_depth = 0
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
if self.scope_depth == 0:
# Module-level function
name = node.name.value
self.functions[name] = node
if name not in self.function_order:
self.function_order.append(name)
self.scope_depth += 1
return True
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
self.scope_depth -= 1
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: # noqa: ARG002
self.scope_depth += 1
return True
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
self.scope_depth -= 1
class GlobalFunctionTransformer(cst.CSTTransformer):
"""Transforms/adds module-level functions from the new file to the original file."""
def __init__(self, new_functions: dict[str, cst.FunctionDef], new_function_order: list[str]) -> None:
super().__init__()
self.new_functions = new_functions
self.new_function_order = new_function_order
self.processed_functions: set[str] = set()
self.scope_depth = 0
def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
self.scope_depth += 1
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
self.scope_depth -= 1
if self.scope_depth > 0:
return updated_node
# Check if this is a module-level function we need to replace
name = original_node.name.value
if name in self.new_functions:
self.processed_functions.add(name)
return self.new_functions[name]
return updated_node
def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
self.scope_depth += 1
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
self.scope_depth -= 1
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
# Add any new functions that weren't in the original file
new_statements = list(updated_node.body)
functions_to_append = [
self.new_functions[name]
for name in self.new_function_order
if name not in self.processed_functions and name in self.new_functions
]
if functions_to_append:
# Find the position of the last function or class definition
insert_index = find_insertion_index_after_imports(updated_node)
for i, stmt in enumerate(new_statements):
if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)):
insert_index = i + 1
# Add empty line before each new function
function_nodes = []
for func in functions_to_append:
func_with_empty_line = func.with_changes(leading_lines=[cst.EmptyLine(), *func.leading_lines])
function_nodes.append(func_with_empty_line)
new_statements = list(chain(new_statements[:insert_index], function_nodes, new_statements[insert_index:]))
return updated_node.with_changes(body=new_statements)
def collect_referenced_names(node: cst.CSTNode) -> set[str]:
"""Collect all names referenced in a CST node using recursive traversal."""
names: set[str] = set()
def _collect(n: cst.CSTNode) -> None:
if isinstance(n, cst.Name):
names.add(n.value)
# Recursively process all children
for child in n.children:
_collect(child)
_collect(node)
return names
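# Illustrative usage of the helper above (set order varies):
#     >>> import libcst as cst
#     >>> collect_referenced_names(cst.parse_module("total = price * tax.rate\n"))
#     {'total', 'price', 'tax', 'rate'}
# Note that it collects every cst.Name, including the "rate" attribute name.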
class GlobalAssignmentCollector(cst.CSTVisitor):
"""Collects all global assignment statements."""
def __init__(self) -> None:
super().__init__()
self.assignments: dict[str, cst.Assign] = {}
self.assignments: dict[str, cst.Assign | cst.AnnAssign] = {}
self.assignment_order: list[str] = []
# Track scope depth to identify global assignments
self.scope_depth = 0
@@ -72,6 +177,21 @@ class GlobalAssignmentCollector(cst.CSTVisitor):
self.assignment_order.append(name)
return True
def visit_AnnAssign(self, node: cst.AnnAssign) -> Optional[bool]:
# Handle annotated assignments like: _CACHE: Dict[str, int] = {}
# Only process module-level annotated assignments with a value
if (
self.scope_depth == 0
and self.if_else_depth == 0
and isinstance(node.target, cst.Name)
and node.value is not None
):
name = node.target.value
self.assignments[name] = node
if name not in self.assignment_order:
self.assignment_order.append(name)
return True
def find_insertion_index_after_imports(node: cst.Module) -> int:
"""Find the position of the last import statement in the top-level of the module."""
@@ -103,7 +223,7 @@ def find_insertion_index_after_imports(node: cst.Module) -> int:
class GlobalAssignmentTransformer(cst.CSTTransformer):
"""Transforms global assignments in the original file with those from the new file."""
def __init__(self, new_assignments: dict[str, cst.Assign], new_assignment_order: list[str]) -> None:
def __init__(self, new_assignments: dict[str, cst.Assign | cst.AnnAssign], new_assignment_order: list[str]) -> None:
super().__init__()
self.new_assignments = new_assignments
self.new_assignment_order = new_assignment_order
@@ -150,38 +270,120 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
return updated_node
def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.CSTNode:
if self.scope_depth > 0 or self.if_else_depth > 0:
return updated_node
# Check if this is a global annotated assignment we need to replace
if isinstance(original_node.target, cst.Name):
name = original_node.target.value
if name in self.new_assignments:
self.processed_assignments.add(name)
return self.new_assignments[name]
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
# Add any new assignments that weren't in the original file
new_statements = list(updated_node.body)
# Find assignments to append
assignments_to_append = [
self.new_assignments[name]
(name, self.new_assignments[name])
for name in self.new_assignment_order
if name not in self.processed_assignments and name in self.new_assignments
]
if assignments_to_append:
# after last top-level imports
if not assignments_to_append:
return updated_node.with_changes(body=new_statements)
# Collect all class and function names defined in the module
# These are the names that assignments might reference
module_defined_names: set[str] = set()
for stmt in new_statements:
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
module_defined_names.add(stmt.name.value)
# Partition assignments: those that reference module definitions go at the end,
# those that don't can go right after imports
assignments_after_imports: list[tuple[str, cst.Assign | cst.AnnAssign]] = []
assignments_after_definitions: list[tuple[str, cst.Assign | cst.AnnAssign]] = []
for name, assignment in assignments_to_append:
# Get the value being assigned
if isinstance(assignment, (cst.Assign, cst.AnnAssign)) and assignment.value is not None:
value_node = assignment.value
else:
# No value to analyze, safe to place after imports
assignments_after_imports.append((name, assignment))
continue
# Collect names referenced in the assignment value
referenced_names = collect_referenced_names(value_node)
# Check if any referenced names are module-level definitions
if referenced_names & module_defined_names:
# This assignment references a class/function, place it after definitions
assignments_after_definitions.append((name, assignment))
else:
# Safe to place right after imports
assignments_after_imports.append((name, assignment))
# Insert assignments that don't depend on module definitions right after imports
if assignments_after_imports:
insert_index = find_insertion_index_after_imports(updated_node)
assignment_lines = [
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
for _, assignment in assignments_after_imports
]
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
# Insert assignments that depend on module definitions after all class/function definitions
if assignments_after_definitions:
# Find the position after the last function or class definition
insert_index = find_insertion_index_after_imports(cst.Module(body=new_statements))
for i, stmt in enumerate(new_statements):
if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)):
insert_index = i + 1
assignment_lines = [
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
for assignment in assignments_to_append
for _, assignment in assignments_after_definitions
]
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
# Add a blank line after the last assignment if needed
after_index = insert_index + len(assignment_lines)
if after_index < len(new_statements):
next_stmt = new_statements[after_index]
# If there's no empty line, add one
has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines)
if not has_empty:
new_statements[after_index] = next_stmt.with_changes(
leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines]
)
return updated_node.with_changes(body=new_statements)
class GlobalStatementTransformer(cst.CSTTransformer):
"""Transformer that appends global statements at the end of the module.
This ensures that global statements (like function calls at module level) are placed
after all functions, classes, and assignments they might reference, preventing NameError
at module load time.
This transformer should be run LAST after GlobalFunctionTransformer and
GlobalAssignmentTransformer have already added their content.
"""
def __init__(self, global_statements: list[cst.SimpleStatementLine]) -> None:
super().__init__()
self.global_statements = global_statements
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
if not self.global_statements:
return updated_node
new_statements = list(updated_node.body)
# Add empty line before each statement for readability
statement_lines = [
stmt.with_changes(leading_lines=[cst.EmptyLine(), *stmt.leading_lines]) for stmt in self.global_statements
]
# Append statements at the end of the module
# This ensures they come after all functions, classes, and assignments
new_statements.extend(statement_lines)
return updated_node.with_changes(body=new_statements)
@@ -213,8 +415,8 @@ class GlobalStatementCollector(cst.CSTVisitor):
def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
if not self.in_function_or_class:
for statement in node.body:
# Skip imports
if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign)):
# Skip imports and assignments (both regular and annotated)
if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign, cst.AnnAssign)):
self.global_statements.append(node)
break
@@ -309,40 +511,6 @@ class DottedImportCollector(cst.CSTVisitor):
self._collect_imports_from_block(node.body)
class ImportInserter(cst.CSTTransformer):
"""Transformer that inserts global statements after the last import."""
def __init__(self, global_statements: list[cst.SimpleStatementLine], last_import_line: int) -> None:
super().__init__()
self.global_statements = global_statements
self.last_import_line = last_import_line
self.current_line = 0
self.inserted = False
def leave_SimpleStatementLine(
self,
original_node: cst.SimpleStatementLine, # noqa: ARG002
updated_node: cst.SimpleStatementLine,
) -> cst.Module:
self.current_line += 1
# If we're right after the last import and haven't inserted yet
if self.current_line == self.last_import_line and not self.inserted:
self.inserted = True
return cst.Module(body=[updated_node, *self.global_statements])
return cst.Module(body=[updated_node])
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
# If there were no imports, add at the beginning of the module
if self.last_import_line == 0 and not self.inserted:
updated_body = list(updated_node.body)
for stmt in reversed(self.global_statements):
updated_body.insert(0, stmt)
return updated_node.with_changes(body=updated_body)
return updated_node
def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]:
"""Extract global statements from source code."""
module = cst.parse_module(source_code)
@@ -394,34 +562,58 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
continue
unique_global_statements.append(stmt)
mod_dst_code = dst_module_code
# Insert unique global statements if any
if unique_global_statements:
last_import_line = find_last_import_line(dst_module_code)
# Reuse already-parsed dst_module
transformer = ImportInserter(unique_global_statements, last_import_line)
# Use visit inplace, don't parse again
modified_module = dst_module.visit(transformer)
mod_dst_code = modified_module.code
# Parse the code after insertion
original_module = cst.parse_module(mod_dst_code)
else:
# No new statements to insert, reuse already-parsed dst_module
original_module = dst_module
# Reuse already-parsed dst_module
original_module = dst_module
# Parse the src_module_code once only (already done above: src_module)
# Collect assignments from the new file
new_collector = GlobalAssignmentCollector()
src_module.visit(new_collector)
# Only create transformer if there are assignments to insert/transform
if not new_collector.assignments: # nothing to transform
return mod_dst_code
new_assignment_collector = GlobalAssignmentCollector()
src_module.visit(new_assignment_collector)
# Transform the original destination module
transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
transformed_module = original_module.visit(transformer)
# Collect module-level functions from both source and destination
src_function_collector = GlobalFunctionCollector()
src_module.visit(src_function_collector)
return transformed_module.code
dst_function_collector = GlobalFunctionCollector()
original_module.visit(dst_function_collector)
# Filter out functions that already exist in the destination (only add truly new functions)
new_functions = {
name: func
for name, func in src_function_collector.functions.items()
if name not in dst_function_collector.functions
}
new_function_order = [name for name in src_function_collector.function_order if name in new_functions]
# If there are no assignments, no new functions, and no global statements, return unchanged
if not new_assignment_collector.assignments and not new_functions and not unique_global_statements:
return dst_module_code
# The order of transformations matters:
# 1. Functions first - so assignments and statements can reference them
# 2. Assignments second - so they come after functions but before statements
# 3. Global statements last - so they can reference both functions and assignments
# Transform functions if any
if new_functions:
function_transformer = GlobalFunctionTransformer(new_functions, new_function_order)
original_module = original_module.visit(function_transformer)
# Transform assignments if any
if new_assignment_collector.assignments:
transformer = GlobalAssignmentTransformer(
new_assignment_collector.assignments, new_assignment_collector.assignment_order
)
original_module = original_module.visit(transformer)
# Insert global statements (like function calls at module level) LAST,
# after all functions and assignments are added, to ensure they can reference any
# functions or variables defined in the module
if unique_global_statements:
statement_transformer = GlobalStatementTransformer(unique_global_statements)
original_module = original_module.visit(statement_transformer)
return original_module.code
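# Illustrative sketch (not part of the diff) of the resulting merge order for
# a hypothetical src/dst pair:
#   src: import functools; _CACHE: dict = {}; def helper(): ...;
#        _TABLE = make_table(helper); warm_up()
#   dst: import os; def existing(): ...
# Output order: imports, then _CACHE (its value references no module-level
# definition, so it goes right after the imports), then existing() and
# helper(), then _TABLE (its value references helper, so it goes after the
# definitions), and finally the warm_up() call appended last by
# GlobalStatementTransformer.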
def resolve_star_import(module_name: str, project_root: Path) -> set[str]:


@@ -4,6 +4,7 @@ import asyncio
import gc
import os
import sqlite3
import time
from enum import Enum
from functools import wraps
from pathlib import Path
@@ -165,3 +166,45 @@ def codeflash_performance_async(func: F) -> F:
return return_value
return async_wrapper
def codeflash_concurrency_async(func: F) -> F:
"""Measures concurrent vs sequential execution performance for async functions."""
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
function_name = func.__name__
concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10"))
test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "")
test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "")
loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0")
# Phase 1: Sequential execution timing
gc.disable()
try:
seq_start = time.perf_counter_ns()
for _ in range(concurrency_factor):
result = await func(*args, **kwargs)
sequential_time = time.perf_counter_ns() - seq_start
finally:
gc.enable()
# Phase 2: Concurrent execution timing
gc.disable()
try:
conc_start = time.perf_counter_ns()
tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
await asyncio.gather(*tasks)
concurrent_time = time.perf_counter_ns() - conc_start
finally:
gc.enable()
# Output parseable metrics
tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}"
print(f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!")
return result
return async_wrapper
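# Illustrative sketch (not part of the diff): one way a harness could consume
# the marker line printed above. The ratio definition here is an assumption.
def parse_concurrency_ratio(line: str) -> float | None:
    prefix, suffix = "!@######CONC:", "######@!"
    if not (line.startswith(prefix) and line.endswith(suffix)):
        return None
    # Tag fields may themselves contain ":", so split the three metrics off
    # the right: ...:sequential_ns:concurrent_ns:concurrency_factor
    _tag, seq_ns, conc_ns, _factor = line[len(prefix) : -len(suffix)].rsplit(":", 3)
    # A ratio > 1.0 means N gathered runs finished faster than N sequential awaits.
    return int(seq_ns) / int(conc_ns) if int(conc_ns) else None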


@@ -10,12 +10,20 @@ import sentry_sdk
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_temp_dir
# Known CrossHair limitations that produce invalid Python syntax in generated tests:
# - "<locals>" - higher-order functions returning nested functions
# - " object at 0x" - objects with default __repr__
# - "<list_iterator" - iterator objects
CROSSHAIR_KNOWN_LIMITATION_PATTERNS = ("<locals>", " object at 0x", "<list_iterator")
def is_valid_concolic_test(test_code: str, project_root: Optional[str] = None) -> bool:
try:
ast.parse(test_code)
except SyntaxError:
sentry_sdk.capture_message(f"CrossHair generated test with syntax error:\n{test_code}")
is_known_limitation = any(pattern in test_code for pattern in CROSSHAIR_KNOWN_LIMITATION_PATTERNS)
if not is_known_limitation:
sentry_sdk.capture_message(f"CrossHair generated test with syntax error:\n{test_code}")
return False
temp_path = (codeflash_temp_dir / f"concolic_test_{uuid.uuid4().hex}.py").resolve()
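# Context note: the skipped patterns correspond to reprs CrossHair can embed
# verbatim into generated asserts, which can never parse, e.g.:
#     assert make_adder(2) == <function make_adder.<locals>.add at 0x7f...>
# ast.parse() raises SyntaxError on such output; since these are known CrossHair
# limitations rather than regressions, they are no longer reported to Sentry.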


@@ -10,6 +10,8 @@ INDIVIDUAL_TESTCASE_TIMEOUT = 15
MAX_FUNCTION_TEST_SECONDS = 60
MIN_IMPROVEMENT_THRESHOLD = 0.05
MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10 # 10% minimum improvement for async throughput
MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD = 0.20 # 20% concurrency ratio improvement required
CONCURRENCY_FACTOR = 10 # Number of concurrent executions for concurrency benchmark
MAX_TEST_FUNCTION_RUNS = 50
MAX_CUMULATIVE_TEST_RUNTIME_NANOSECONDS = 100e6 # 100ms
TOTAL_LOOPING_TIME = 10.0 # 10 second candidate benchmarking budget
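# Illustrative sketch (not part of the diff): how the new constant could gate
# acceptance; the ratio inputs are assumptions based on the fields added above.
def passes_concurrency_gate(original_ratio: float, optimized_ratio: float) -> bool:
    improvement = (optimized_ratio - original_ratio) / original_ratio
    return improvement >= MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD  # i.e. at least 20%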


@@ -1439,9 +1439,12 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
self.added_decorator = False
# Choose decorator based on mode
self.decorator_name = (
"codeflash_behavior_async" if mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
)
if mode == TestingMode.BEHAVIOR:
self.decorator_name = "codeflash_behavior_async"
elif mode == TestingMode.CONCURRENCY:
self.decorator_name = "codeflash_concurrency_async"
else:
self.decorator_name = "codeflash_performance_async"
def visit_ClassDef(self, node: cst.ClassDef) -> None:
# Track when we enter a class
@@ -1484,12 +1487,14 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
"codeflash_trace_async",
"codeflash_behavior_async",
"codeflash_performance_async",
"codeflash_concurrency_async",
}
if isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name):
return decorator_node.func.value in {
"codeflash_trace_async",
"codeflash_behavior_async",
"codeflash_performance_async",
"codeflash_concurrency_async",
}
return False
@@ -1501,6 +1506,14 @@ class AsyncDecoratorImportAdder(cst.CSTTransformer):
self.mode = mode
self.has_import = False
def _get_decorator_name(self) -> str:
"""Get the decorator name based on the testing mode."""
if self.mode == TestingMode.BEHAVIOR:
return "codeflash_behavior_async"
if self.mode == TestingMode.CONCURRENCY:
return "codeflash_concurrency_async"
return "codeflash_performance_async"
def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
# Check if the async decorator import is already present
if (
@@ -1512,9 +1525,7 @@ class AsyncDecoratorImportAdder(cst.CSTTransformer):
and node.module.attr.value == "codeflash_wrap_decorator"
and not isinstance(node.names, cst.ImportStar)
):
decorator_name = (
"codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
)
decorator_name = self._get_decorator_name()
for import_alias in node.names:
if import_alias.name.value == decorator_name:
self.has_import = True
@@ -1525,9 +1536,7 @@ class AsyncDecoratorImportAdder(cst.CSTTransformer):
return updated_node
# Choose import based on mode
decorator_name = (
"codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
)
decorator_name = self._get_decorator_name()
# Parse the import statement into a CST node
import_node = cst.parse_statement(f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}")


@@ -5,6 +5,7 @@ import hashlib
import os
from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, cast
import libcst as cst
@@ -16,6 +17,7 @@ from codeflash.code_utils.config_consts import OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
from codeflash.context.unused_definition_remover import (
collect_top_level_defs_with_usages,
extract_names_from_targets,
get_section_names,
remove_unused_definitions_by_function_names,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
@@ -29,14 +31,44 @@ from codeflash.models.models import (
from codeflash.optimization.function_context import belongs_to_function_qualified
if TYPE_CHECKING:
from pathlib import Path
from jedi.api.classes import Name
from libcst import CSTNode
from codeflash.context.unused_definition_remover import UsageInfo
def build_testgen_context(
helpers_of_fto_dict: dict[Path, set[FunctionSource]],
helpers_of_helpers_dict: dict[Path, set[FunctionSource]],
project_root_path: Path,
remove_docstrings: bool, # noqa: FBT001
include_imported_classes: bool, # noqa: FBT001
) -> CodeStringsMarkdown:
"""Build testgen context with optional imported class definitions and external base inits."""
testgen_context = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=remove_docstrings,
code_context_type=CodeContextType.TESTGEN,
)
if include_imported_classes:
imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
if imported_class_context.code_strings:
testgen_context = CodeStringsMarkdown(
code_strings=testgen_context.code_strings + imported_class_context.code_strings
)
external_base_inits = get_external_base_class_inits(testgen_context, project_root_path)
if external_base_inits.code_strings:
testgen_context = CodeStringsMarkdown(
code_strings=testgen_context.code_strings + external_base_inits.code_strings
)
return testgen_context
def get_code_optimization_context(
function_to_optimize: FunctionToOptimize,
project_root_path: Path,
@@ -120,55 +152,37 @@ def get_code_optimization_context(
logger.debug("Code context has exceeded token limit, removing read-only code")
read_only_context_code = ""
# Extract code context for testgen
testgen_context = extract_code_markdown_context_from_files(
# Extract code context for testgen with progressive fallback for token limits
# Try in order: full context -> remove docstrings -> remove imported classes
testgen_context = build_testgen_context(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=False,
code_context_type=CodeContextType.TESTGEN,
include_imported_classes=True,
)
# Extract class definitions for imported types from project modules
# This helps the LLM understand class constructors and structure
imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
if imported_class_context.code_strings:
# Merge imported class definitions into testgen context
testgen_context = CodeStringsMarkdown(
code_strings=testgen_context.code_strings + imported_class_context.code_strings
)
testgen_markdown_code = testgen_context.markdown
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
if testgen_code_token_length > testgen_token_limit:
# First try removing docstrings
testgen_context = extract_code_markdown_context_from_files(
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
logger.debug("Testgen context exceeded token limit, removing docstrings")
testgen_context = build_testgen_context(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=True,
code_context_type=CodeContextType.TESTGEN,
include_imported_classes=True,
)
# Re-extract imported classes (they may still fit)
imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
if imported_class_context.code_strings:
testgen_context = CodeStringsMarkdown(
code_strings=testgen_context.code_strings + imported_class_context.code_strings
)
testgen_markdown_code = testgen_context.markdown
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
if testgen_code_token_length > testgen_token_limit:
# If still over limit, try without imported class definitions
testgen_context = extract_code_markdown_context_from_files(
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
logger.debug("Testgen context still exceeded token limit, removing imported class definitions")
testgen_context = build_testgen_context(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=True,
code_context_type=CodeContextType.TESTGEN,
include_imported_classes=False,
)
testgen_markdown_code = testgen_context.markdown
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
if testgen_code_token_length > testgen_token_limit:
if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
code_hash_context = hashing_code_context.markdown
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
@@ -184,114 +198,6 @@ def get_code_optimization_context(
)
def extract_code_string_context_from_files(
helpers_of_fto: dict[Path, set[FunctionSource]],
helpers_of_helpers: dict[Path, set[FunctionSource]],
project_root_path: Path,
remove_docstrings: bool = False, # noqa: FBT001, FBT002
code_context_type: CodeContextType = CodeContextType.READ_ONLY,
) -> CodeString:
"""Extract code context from files containing target functions and their helpers.
This function processes two sets of files:
1. Files containing the function to optimize (fto) and their first-degree helpers
2. Files containing only helpers of helpers (with no overlap with the first set).
For each file, it extracts relevant code based on the specified context type, adds necessary
imports, and combines them.
Args:
----
helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers
helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions
project_root_path: Root path of the project
remove_docstrings: Whether to remove docstrings from the extracted code
code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN)
Returns:
-------
CodeString containing the extracted code context with necessary imports
""" # noqa: D205
# Rearrange to remove overlaps, so we only access each file path once
helpers_of_helpers_no_overlap = defaultdict(set)
for file_path, function_sources in helpers_of_helpers.items():
if file_path in helpers_of_fto:
# Remove duplicates within the same file path, in case a helper of helper is also a helper of fto
helpers_of_helpers[file_path] -= helpers_of_fto[file_path]
else:
helpers_of_helpers_no_overlap[file_path] = function_sources
final_code_string_context = ""
# Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files
for file_path, function_sources in helpers_of_fto.items():
try:
original_code = file_path.read_text("utf8")
except Exception as e:
logger.exception(f"Error while parsing {file_path}: {e}")
continue
try:
qualified_function_names = {func.qualified_name for func in function_sources}
helpers_of_helpers_qualified_names = {
func.qualified_name for func in helpers_of_helpers.get(file_path, set())
}
code_without_unused_defs = remove_unused_definitions_by_function_names(
original_code, qualified_function_names | helpers_of_helpers_qualified_names
)
code_context = parse_code_and_prune_cst(
code_without_unused_defs,
code_context_type,
qualified_function_names,
helpers_of_helpers_qualified_names,
remove_docstrings,
)
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
continue
if code_context.strip():
final_code_string_context += f"\n{code_context}"
final_code_string_context = add_needed_imports_from_module(
src_module_code=original_code,
dst_module_code=final_code_string_context,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions=list(helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())),
)
if code_context_type == CodeContextType.READ_WRITABLE:
return CodeString(code=final_code_string_context)
# Extract code from file paths containing helpers of helpers
for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items():
try:
original_code = file_path.read_text("utf8")
except Exception as e:
logger.exception(f"Error while parsing {file_path}: {e}")
continue
try:
qualified_helper_function_names = {func.qualified_name for func in helper_function_sources}
code_without_unused_defs = remove_unused_definitions_by_function_names(
original_code, qualified_helper_function_names
)
code_context = parse_code_and_prune_cst(
code_without_unused_defs, code_context_type, set(), qualified_helper_function_names, remove_docstrings
)
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
continue
if code_context.strip():
final_code_string_context += f"\n{code_context}"
final_code_string_context = add_needed_imports_from_module(
src_module_code=original_code,
dst_module_code=final_code_string_context,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())),
)
return CodeString(code=final_code_string_context)
def extract_code_markdown_context_from_files(
helpers_of_fto: dict[Path, set[FunctionSource]],
helpers_of_helpers: dict[Path, set[FunctionSource]],
@@ -526,6 +432,10 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
the LLM understand the actual class structure (constructors, methods, inheritance)
rather than just seeing import statements.
Also recursively extracts base classes when a class inherits from another class
in the same module, ensuring the full inheritance chain is available for
understanding constructor signatures.
Args:
code_context: The already extracted code context containing imports
project_root_path: Root path of the project
@@ -568,6 +478,68 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
class_code_strings: list[CodeString] = []
module_cache: dict[Path, tuple[str, ast.Module]] = {}
def get_module_source_and_tree(module_path: Path) -> tuple[str, ast.Module] | None:
if module_path in module_cache:
return module_cache[module_path]
try:
module_source = module_path.read_text(encoding="utf-8")
module_tree = ast.parse(module_source)
except Exception:
return None
else:
module_cache[module_path] = (module_source, module_tree)
return module_source, module_tree
def extract_class_and_bases(
class_name: str, module_path: Path, module_source: str, module_tree: ast.Module
) -> None:
"""Extract a class and its base classes recursively from the same module."""
# Skip if already extracted
if (module_path, class_name) in extracted_classes:
return
# Find the class definition in the module
class_node = None
for node in ast.walk(module_tree):
if isinstance(node, ast.ClassDef) and node.name == class_name:
class_node = node
break
if class_node is None:
return
# First, recursively extract base classes from the same module
for base in class_node.bases:
base_name = None
if isinstance(base, ast.Name):
base_name = base.id
elif isinstance(base, ast.Attribute):
# For module.ClassName, we skip (cross-module inheritance)
continue
if base_name and base_name not in existing_definitions:
# Check if base class is defined in the same module
extract_class_and_bases(base_name, module_path, module_source, module_tree)
# Now extract this class (after its bases, so base classes appear first)
if (module_path, class_name) in extracted_classes:
return # Already added by another path
lines = module_source.split("\n")
start_line = class_node.lineno
if class_node.decorator_list:
start_line = min(d.lineno for d in class_node.decorator_list)
class_source = "\n".join(lines[start_line - 1 : class_node.end_lineno])
# Extract imports for the class
class_imports = extract_imports_for_class(module_tree, class_node, module_source)
full_source = class_imports + "\n\n" + class_source if class_imports else class_source
class_code_strings.append(CodeString(code=full_source, file_path=module_path))
extracted_classes.add((module_path, class_name))
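# Illustrative sketch (hypothetical module): given `class Base: ...` and
# `class Child(Base): ...` in the same file, extract_class_and_bases("Child", ...)
# recurses into "Base" first, so class_code_strings is ordered [Base, Child]
# and the full constructor chain appears in definition order.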
for name, module_name in imported_names.items():
# Skip if already defined in context
if name in existing_definitions:
@@ -593,28 +565,14 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
if path_belongs_to_site_packages(module_path):
continue
# Skip if we've already extracted this class
if (module_path, name) in extracted_classes:
# Get module source and tree
result = get_module_source_and_tree(module_path)
if result is None:
continue
module_source, module_tree = result
# Parse the module to find the class definition
module_source = module_path.read_text(encoding="utf-8")
module_tree = ast.parse(module_source)
for node in ast.walk(module_tree):
if isinstance(node, ast.ClassDef) and node.name == name:
# Extract the class source code
lines = module_source.split("\n")
class_source = "\n".join(lines[node.lineno - 1 : node.end_lineno])
# Also extract any necessary imports for the class (base classes, type hints)
class_imports = _extract_imports_for_class(module_tree, node, module_source)
full_source = class_imports + "\n\n" + class_source if class_imports else class_source
class_code_strings.append(CodeString(code=full_source, file_path=module_path))
extracted_classes.add((module_path, name))
break
# Extract the class and its base classes
extract_class_and_bases(name, module_path, module_source, module_tree)
except Exception:
logger.debug(f"Error extracting class definition for {name} from {module_name}")
@@ -623,10 +581,111 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
return CodeStringsMarkdown(code_strings=class_code_strings)
def _extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown:
"""Extract __init__ methods from external library base classes.
Scans the code context for classes that inherit from external libraries and extracts
just their __init__ methods. This helps the LLM understand constructor signatures
for mocking or instantiation.
"""
import importlib
import inspect
import textwrap
all_code = "\n".join(cs.code for cs in code_context.code_strings)
try:
tree = ast.parse(all_code)
except SyntaxError:
return CodeStringsMarkdown(code_strings=[])
imported_names: dict[str, str] = {}
external_bases: list[tuple[str, str]] = []
for node in ast.walk(tree):
if isinstance(node, ast.ImportFrom) and node.module:
for alias in node.names:
if alias.name != "*":
imported_name = alias.asname if alias.asname else alias.name
imported_names[imported_name] = node.module
elif isinstance(node, ast.ClassDef):
for base in node.bases:
base_name = None
if isinstance(base, ast.Name):
base_name = base.id
elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name):
base_name = base.attr
if base_name and base_name in imported_names:
module_name = imported_names[base_name]
if not _is_project_module(module_name, project_root_path):
external_bases.append((base_name, module_name))
if not external_bases:
return CodeStringsMarkdown(code_strings=[])
code_strings: list[CodeString] = []
extracted: set[tuple[str, str]] = set()
for base_name, module_name in external_bases:
if (module_name, base_name) in extracted:
continue
try:
module = importlib.import_module(module_name)
base_class = getattr(module, base_name, None)
if base_class is None:
continue
init_method = getattr(base_class, "__init__", None)
if init_method is None:
continue
try:
init_source = inspect.getsource(init_method)
init_source = textwrap.dedent(init_source)
class_file = Path(inspect.getfile(base_class))
parts = class_file.parts
if "site-packages" in parts:
idx = parts.index("site-packages")
class_file = Path(*parts[idx + 1 :])
except (OSError, TypeError):
continue
class_source = f"class {base_name}:\n" + textwrap.indent(init_source, " ")
code_strings.append(CodeString(code=class_source, file_path=class_file))
extracted.add((module_name, base_name))
except (ImportError, ModuleNotFoundError, AttributeError):
logger.debug(f"Failed to extract __init__ for {module_name}.{base_name}")
continue
return CodeStringsMarkdown(code_strings=code_strings)
def _is_project_module(module_name: str, project_root_path: Path) -> bool:
"""Check if a module is part of the project (not external/stdlib)."""
import importlib.util
try:
spec = importlib.util.find_spec(module_name)
except (ImportError, ModuleNotFoundError, ValueError):
return False
else:
if spec is None or spec.origin is None:
return False
module_path = Path(spec.origin)
# Check if the module is in site-packages (external dependency)
# This must be checked first because .venv/site-packages is under project root
if path_belongs_to_site_packages(module_path):
return False
# Check if the module is within the project root
return str(module_path).startswith(str(project_root_path) + os.sep)
def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
"""Extract import statements needed for a class definition.
This extracts imports for base classes and commonly used type annotations.
This extracts imports for base classes, decorators, and type annotations.
"""
needed_names: set[str] = set()
@@ -638,35 +697,139 @@ def _extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef
# For things like abc.ABC, we need the module name
needed_names.add(base.value.id)
# Get decorator names (e.g., dataclass, field)
for decorator in class_node.decorator_list:
if isinstance(decorator, ast.Name):
needed_names.add(decorator.id)
elif isinstance(decorator, ast.Call):
if isinstance(decorator.func, ast.Name):
needed_names.add(decorator.func.id)
elif isinstance(decorator.func, ast.Attribute) and isinstance(decorator.func.value, ast.Name):
needed_names.add(decorator.func.value.id)
# Get type annotation names from class body (for dataclass fields)
for item in ast.walk(class_node):
if isinstance(item, ast.AnnAssign) and item.annotation:
collect_names_from_annotation(item.annotation, needed_names)
# Also check for field() calls which are common in dataclasses
if isinstance(item, ast.Call) and isinstance(item.func, ast.Name):
needed_names.add(item.func.id)
# Find imports that provide these names
import_lines: list[str] = []
source_lines = module_source.split("\n")
added_imports: set[int] = set() # Track line numbers to avoid duplicates
for node in module_tree.body:
if isinstance(node, ast.Import):
for alias in node.names:
name = alias.asname if alias.asname else alias.name.split(".")[0]
if name in needed_names:
if name in needed_names and node.lineno not in added_imports:
import_lines.append(source_lines[node.lineno - 1])
added_imports.add(node.lineno)
break
elif isinstance(node, ast.ImportFrom):
for alias in node.names:
name = alias.asname if alias.asname else alias.name
if name in needed_names:
if name in needed_names and node.lineno not in added_imports:
import_lines.append(source_lines[node.lineno - 1])
added_imports.add(node.lineno)
break
return "\n".join(import_lines)
def collect_names_from_annotation(node: ast.expr, names: set[str]) -> None:
"""Recursively collect type annotation names from an AST node."""
if isinstance(node, ast.Name):
names.add(node.id)
elif isinstance(node, ast.Subscript):
collect_names_from_annotation(node.value, names)
collect_names_from_annotation(node.slice, names)
elif isinstance(node, ast.Tuple):
for elt in node.elts:
collect_names_from_annotation(elt, names)
elif isinstance(node, ast.BinOp): # For Union types with | syntax
collect_names_from_annotation(node.left, names)
collect_names_from_annotation(node.right, names)
elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
names.add(node.value.id)
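# Illustrative note: for the annotation `dict[str, list[Foo]] | None`, this
# collects {"dict", "str", "list", "Foo"}; `None` parses as an ast.Constant
# rather than an ast.Name, so it is never collected.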
def is_dunder_method(name: str) -> bool:
return len(name) > 4 and name.isascii() and name.startswith("__") and name.endswith("__")
def get_section_names(node: cst.CSTNode) -> list[str]:
"""Returns the section attribute names (e.g., body, orelse) for a given node if they exist.""" # noqa: D401
possible_sections = ["body", "orelse", "finalbody", "handlers"]
return [sec for sec in possible_sections if hasattr(node, sec)]
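# Illustrative usage: a cst.Try node exposes all four attributes (missing
# sections are simply None), so:
#     >>> get_section_names(cst.parse_statement("try:\n    pass\nexcept Exception:\n    pass\n"))
#     ['body', 'orelse', 'finalbody', 'handlers']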
class UsedNameCollector(cst.CSTVisitor):
"""Collects all base names referenced in code (for import preservation)."""
def __init__(self) -> None:
self.used_names: set[str] = set()
self.defined_names: set[str] = set()
def visit_Name(self, node: cst.Name) -> None:
self.used_names.add(node.value)
def visit_Attribute(self, node: cst.Attribute) -> bool | None:
base = node.value
while isinstance(base, cst.Attribute):
base = base.value
if isinstance(base, cst.Name):
self.used_names.add(base.value)
return True
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
self.defined_names.add(node.name.value)
return True
def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
self.defined_names.add(node.name.value)
return True
def visit_Assign(self, node: cst.Assign) -> bool | None:
for target in node.targets:
names = extract_names_from_targets(target.target)
self.defined_names.update(names)
return True
def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None:
names = extract_names_from_targets(node.target)
self.defined_names.update(names)
return True
def get_external_names(self) -> set[str]:
return self.used_names - self.defined_names - {"self", "cls"}
def get_imported_names(import_node: cst.Import | cst.ImportFrom) -> set[str]:
"""Extract the names made available by an import statement."""
names: set[str] = set()
if isinstance(import_node, cst.Import):
if isinstance(import_node.names, cst.ImportStar):
return {"*"}
for alias in import_node.names:
if isinstance(alias, cst.ImportAlias):
if alias.asname and isinstance(alias.asname.name, cst.Name):
names.add(alias.asname.name.value)
elif isinstance(alias.name, cst.Name):
names.add(alias.name.value)
elif isinstance(alias.name, cst.Attribute):
# import foo.bar -> accessible as "foo"
base: cst.BaseExpression = alias.name
while isinstance(base, cst.Attribute):
base = base.value
if isinstance(base, cst.Name):
names.add(base.value)
elif isinstance(import_node, cst.ImportFrom):
if isinstance(import_node.names, cst.ImportStar):
return {"*"}
for alias in import_node.names:
if isinstance(alias, cst.ImportAlias):
if alias.asname and isinstance(alias.asname.name, cst.Name):
names.add(alias.asname.name.value)
elif isinstance(alias.name, cst.Name):
names.add(alias.name.value)
return names
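# Illustrative usage (set order varies):
#     >>> get_imported_names(cst.parse_statement("import foo.bar, json as j").body[0])
#     {'foo', 'j'}          # a dotted import binds only its base name
#     >>> get_imported_names(cst.parse_statement("from pkg import alpha as a, beta").body[0])
#     {'a', 'beta'}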
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
@@ -693,12 +856,22 @@ def parse_code_and_prune_cst(
if code_context_type == CodeContextType.READ_WRITABLE:
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions, defs_with_usages)
elif code_context_type == CodeContextType.READ_ONLY:
filtered_node, found_target = prune_cst_for_read_only_code(
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
filtered_node, found_target = prune_cst_for_context(
module,
target_functions,
helpers_of_helper_functions,
remove_docstrings=remove_docstrings,
include_target_in_output=False,
include_init_dunder=False,
)
elif code_context_type == CodeContextType.TESTGEN:
filtered_node, found_target = prune_cst_for_testgen_code(
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
filtered_node, found_target = prune_cst_for_context(
module,
target_functions,
helpers_of_helper_functions,
remove_docstrings=remove_docstrings,
include_target_in_output=True,
include_init_dunder=True,
)
elif code_context_type == CodeContextType.HASHING:
filtered_node, found_target = prune_cst_for_code_hashing(module, target_functions)
@@ -740,10 +913,29 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
# Do not recurse into nested classes
if prefix:
return None, False
class_name = node.name.value
# Assuming always an IndentedBlock
if not isinstance(node.body, cst.IndentedBlock):
raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
class_prefix = f"{prefix}.{class_name}" if prefix else class_name
# Check if this class contains any target functions
has_target_functions = any(
isinstance(stmt, cst.FunctionDef) and f"{class_prefix}.{stmt.name.value}" in target_functions
for stmt in node.body.body
)
# If the class is used as a dependency (not containing target functions), keep it entirely
# This handles cases like enums, dataclasses, and other types used by the target function
if (
not has_target_functions
and class_name in defs_with_usages
and defs_with_usages[class_name].used_by_qualified_function
):
return node, True
new_body = []
found_target = False
@@ -903,17 +1095,29 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
return (node.with_changes(**updates) if updates else node), True
def prune_cst_for_read_only_code( # noqa: PLR0911
def prune_cst_for_context( # noqa: PLR0911
node: cst.CSTNode,
target_functions: set[str],
helpers_of_helper_functions: set[str],
prefix: str = "",
remove_docstrings: bool = False, # noqa: FBT001, FBT002
include_target_in_output: bool = False, # noqa: FBT001, FBT002
include_init_dunder: bool = False, # noqa: FBT001, FBT002
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node for read-only context.
"""Recursively filter the node for code context extraction.
Returns
-------
Args:
node: The CST node to filter
target_functions: Set of qualified function names that are targets
helpers_of_helper_functions: Set of helper function qualified names
prefix: Current qualified name prefix (for class methods)
remove_docstrings: Whether to remove docstrings from output
include_target_in_output: If True, include target functions in output (testgen mode)
If False, exclude target functions (read-only mode)
include_init_dunder: If True, include __init__ in dunder methods (testgen mode)
If False, exclude __init__ from dunder methods (read-only mode)
Returns:
(filtered_node, found_target):
filtered_node: The modified CST node or None if it should be removed.
found_target: True if a target function was found in this node's subtree.
@@ -924,17 +1128,28 @@ def prune_cst_for_read_only_code( # noqa: PLR0911
if isinstance(node, cst.FunctionDef):
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
# If it's a target function, remove it but mark found_target = True
# Check if it's a helper of helper function
if qualified_name in helpers_of_helper_functions:
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
return node.with_changes(body=remove_docstring_from_body(node.body)), True
return node, True
# Check if it's a target function
if qualified_name in target_functions:
if include_target_in_output:
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
return node.with_changes(body=remove_docstring_from_body(node.body)), True
return node, True
return None, True
# Keep only dunder methods
if is_dunder_method(node.name.value) and node.name.value != "__init__":
# Check dunder methods
# For read-only mode, exclude __init__; for testgen mode, include all dunders
if is_dunder_method(node.name.value) and (include_init_dunder or node.name.value != "__init__"):
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
new_body = remove_docstring_from_body(node.body)
return node.with_changes(body=new_body), False
return node.with_changes(body=remove_docstring_from_body(node.body)), False
return node, False
return None, False
if isinstance(node, cst.ClassDef):
@@ -951,8 +1166,14 @@ def prune_cst_for_read_only_code( # noqa: PLR0911
found_in_class = False
new_class_body: list[CSTNode] = []
for stmt in node.body.body:
filtered, found_target = prune_cst_for_read_only_code(
stmt, target_functions, helpers_of_helper_functions, class_prefix, remove_docstrings=remove_docstrings
filtered, found_target = prune_cst_for_context(
stmt,
target_functions,
helpers_of_helper_functions,
class_prefix,
remove_docstrings=remove_docstrings,
include_target_in_output=include_target_in_output,
include_init_dunder=include_init_dunder,
)
found_in_class |= found_target
if filtered:
@@ -981,8 +1202,14 @@ def prune_cst_for_read_only_code( # noqa: PLR0911
new_children = []
section_found_target = False
for child in original_content:
filtered, found_target = prune_cst_for_read_only_code(
child, target_functions, helpers_of_helper_functions, prefix, remove_docstrings=remove_docstrings
filtered, found_target = prune_cst_for_context(
child,
target_functions,
helpers_of_helper_functions,
prefix,
remove_docstrings=remove_docstrings,
include_target_in_output=include_target_in_output,
include_init_dunder=include_init_dunder,
)
if filtered:
new_children.append(filtered)
@@ -992,122 +1219,19 @@ def prune_cst_for_read_only_code( # noqa: PLR0911
found_any_target |= section_found_target
updates[section] = new_children
elif original_content is not None:
filtered, found_target = prune_cst_for_read_only_code(
original_content,
target_functions,
helpers_of_helper_functions,
prefix,
remove_docstrings=remove_docstrings,
)
found_any_target |= found_target
if filtered:
updates[section] = filtered
if updates:
return (node.with_changes(**updates), found_any_target)
return None, False
def prune_cst_for_testgen_code( # noqa: PLR0911
node: cst.CSTNode,
target_functions: set[str],
helpers_of_helper_functions: set[str],
prefix: str = "",
remove_docstrings: bool = False, # noqa: FBT001, FBT002
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node for testgen context.
Returns
-------
(filtered_node, found_target):
filtered_node: The modified CST node or None if it should be removed.
found_target: True if a target function was found in this node's subtree.
"""
if isinstance(node, (cst.Import, cst.ImportFrom)):
return None, False
if isinstance(node, cst.FunctionDef):
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
# If it's a target function, remove it but mark found_target = True
if qualified_name in helpers_of_helper_functions or qualified_name in target_functions:
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
new_body = remove_docstring_from_body(node.body)
return node.with_changes(body=new_body), True
return node, True
# Keep all dunder methods
if is_dunder_method(node.name.value):
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
new_body = remove_docstring_from_body(node.body)
return node.with_changes(body=new_body), False
return node, False
return None, False
if isinstance(node, cst.ClassDef):
# Do not recurse into nested classes
if prefix:
return None, False
# Assuming always an IndentedBlock
if not isinstance(node.body, cst.IndentedBlock):
raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
# First pass: detect if there is a target function in the class
found_in_class = False
new_class_body: list[CSTNode] = []
for stmt in node.body.body:
filtered, found_target = prune_cst_for_testgen_code(
stmt, target_functions, helpers_of_helper_functions, class_prefix, remove_docstrings=remove_docstrings
)
found_in_class |= found_target
if filtered:
new_class_body.append(filtered)
if not found_in_class:
return None, False
if remove_docstrings:
return node.with_changes(
body=remove_docstring_from_body(node.body.with_changes(body=new_class_body))
) if new_class_body else None, True
return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True
# For other nodes, keep the node and recursively filter children
section_names = get_section_names(node)
if not section_names:
return node, False
updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
found_any_target = False
for section in section_names:
original_content = getattr(node, section, None)
if isinstance(original_content, (list, tuple)):
new_children = []
section_found_target = False
for child in original_content:
filtered, found_target = prune_cst_for_testgen_code(
child, target_functions, helpers_of_helper_functions, prefix, remove_docstrings=remove_docstrings
)
if filtered:
new_children.append(filtered)
section_found_target |= found_target
if section_found_target or new_children:
found_any_target |= section_found_target
updates[section] = new_children
elif original_content is not None:
filtered, found_target = prune_cst_for_testgen_code(
filtered, found_target = prune_cst_for_context(
original_content,
target_functions,
helpers_of_helper_functions,
prefix,
remove_docstrings=remove_docstrings,
include_target_in_output=include_target_in_output,
include_init_dunder=include_init_dunder,
)
found_any_target |= found_target
if filtered:
updates[section] = filtered
if updates:
return (node.with_changes(**updates), found_any_target)

View file

@ -295,11 +295,18 @@ class DependencyCollector(cst.CSTVisitor):
return
if name in self.definitions and name != self.current_top_level_name:
# skip if we are referencing a class attribute and not a top-level definition
# Skip if this Name is the .attr part of an Attribute (e.g., 'x' in 'self.x')
# We only want to track the base/value of attribute access, not the attribute name itself
if self.class_depth > 0:
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
if parent is not None and isinstance(parent, cst.Attribute):
return
# Check if this Name is the .attr (property name), not the .value (base)
# If it's the .attr, skip it - attribute names aren't references to definitions
if parent.attr is node:
return
# If it's the .value (base), only skip if it's self/cls
if name in ("self", "cls"):
return
self.definitions[self.current_top_level_name].dependencies.add(name)
@ -553,16 +560,6 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
return code
def print_definitions(definitions: dict[str, UsageInfo]) -> None:
"""Print information about each definition without the complex node object, used for debugging."""
print(f"Found {len(definitions)} definitions:")
for name, info in sorted(definitions.items()):
print(f" - Name: {name}")
print(f" Used by qualified function: {info.used_by_qualified_function}")
print(f" Dependencies: {', '.join(sorted(info.dependencies)) if info.dependencies else 'None'}")
print()
def revert_unused_helper_functions(
project_root: Path, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str]
) -> None:
@ -637,43 +634,40 @@ def _analyze_imports_in_optimized_code(
func_name = helper.only_function_name
module_name = helper.file_path.stem
# Cache function lookup for this (module, func)
file_entry = helpers_by_file_and_func[module_name]
if func_name in file_entry:
file_entry[func_name].append(helper)
else:
file_entry[func_name] = [helper]
helpers_by_file_and_func[module_name].setdefault(func_name, []).append(helper)
helpers_by_file[module_name].append(helper)
# Optimize attribute lookups and method binding outside the loop
helpers_by_file_and_func_get = helpers_by_file_and_func.get
helpers_by_file_get = helpers_by_file.get
for node in ast.walk(optimized_ast):
if isinstance(node, ast.ImportFrom):
# Handle "from module import function" statements
module_name = node.module
if module_name:
file_entry = helpers_by_file_and_func_get(module_name, None)
file_entry = helpers_by_file_and_func.get(module_name)
if file_entry:
for alias in node.names:
imported_name = alias.asname if alias.asname else alias.name
original_name = alias.name
helpers = file_entry.get(original_name, None)
helpers = file_entry.get(original_name)
if helpers:
imported_set = imported_names_map[imported_name]
for helper in helpers:
imported_names_map[imported_name].add(helper.qualified_name)
imported_names_map[imported_name].add(helper.fully_qualified_name)
imported_set.add(helper.qualified_name)
imported_set.add(helper.fully_qualified_name)
elif isinstance(node, ast.Import):
# Handle "import module" statements
for alias in node.names:
imported_name = alias.asname if alias.asname else alias.name
module_name = alias.name
for helper in helpers_by_file_get(module_name, []):
# For "import module" statements, functions would be called as module.function
full_call = f"{imported_name}.{helper.only_function_name}"
imported_names_map[full_call].add(helper.qualified_name)
imported_names_map[full_call].add(helper.fully_qualified_name)
helpers = helpers_by_file.get(module_name)
if helpers:
imported_set = imported_names_map[f"{imported_name}.{{func}}"]
for helper in helpers:
# For "import module" statements, functions would be called as module.function
full_call = f"{imported_name}.{helper.only_function_name}"
full_call_set = imported_names_map[full_call]
full_call_set.add(helper.qualified_name)
full_call_set.add(helper.fully_qualified_name)
return dict(imported_names_map)
@ -753,27 +747,31 @@ def detect_unused_helper_functions(
called_name = node.func.id
called_function_names.add(called_name)
# Also add the qualified name if this is an imported function
if called_name in imported_names_map:
called_function_names.update(imported_names_map[called_name])
mapped_names = imported_names_map.get(called_name)
if mapped_names:
called_function_names.update(mapped_names)
elif isinstance(node.func, ast.Attribute):
# Method call: obj.method() or self.method() or module.function()
if isinstance(node.func.value, ast.Name):
if node.func.value.id == "self":
attr_name = node.func.attr
value_id = node.func.value.id
if value_id == "self":
# self.method_name() -> add both method_name and ClassName.method_name
called_function_names.add(node.func.attr)
called_function_names.add(attr_name)
# For class methods, also add the qualified name
if hasattr(function_to_optimize, "parents") and function_to_optimize.parents:
class_name = function_to_optimize.parents[0].name
called_function_names.add(f"{class_name}.{node.func.attr}")
called_function_names.add(f"{class_name}.{attr_name}")
else:
# obj.method() or module.function()
attr_name = node.func.attr
called_function_names.add(attr_name)
called_function_names.add(f"{node.func.value.id}.{attr_name}")
full_call = f"{value_id}.{attr_name}"
called_function_names.add(full_call)
# Check if this is a module.function call that maps to a helper
full_call = f"{node.func.value.id}.{attr_name}"
if full_call in imported_names_map:
called_function_names.update(imported_names_map[full_call])
mapped_names = imported_names_map.get(full_call)
if mapped_names:
called_function_names.update(mapped_names)
# Handle nested attribute access like obj.attr.method()
else:
called_function_names.add(node.func.attr)
@ -783,6 +781,7 @@ def detect_unused_helper_functions(
# Find helper functions that are no longer called
unused_helpers = []
entrypoint_file_path = function_to_optimize.file_path
for helper_function in code_context.helper_functions:
if helper_function.jedi_definition.type != "class":
# Check if the helper function is called using multiple name variants
@ -790,29 +789,30 @@ def detect_unused_helper_functions(
helper_simple_name = helper_function.only_function_name
helper_fully_qualified_name = helper_function.fully_qualified_name
# Create a set of all possible names this helper might be called by
possible_call_names = {helper_qualified_name, helper_simple_name, helper_fully_qualified_name}
# Check membership efficiently - exit early on first match
if (
helper_qualified_name in called_function_names
or helper_simple_name in called_function_names
or helper_fully_qualified_name in called_function_names
):
is_called = True
# For cross-file helpers, also consider module-based calls
if helper_function.file_path != function_to_optimize.file_path:
elif helper_function.file_path != entrypoint_file_path:
# Add potential module.function combinations
module_name = helper_function.file_path.stem
possible_call_names.add(f"{module_name}.{helper_simple_name}")
# Check if any of the possible names are in the called functions
is_called = bool(possible_call_names.intersection(called_function_names))
module_call = f"{module_name}.{helper_simple_name}"
is_called = module_call in called_function_names
else:
is_called = False
if not is_called:
unused_helpers.append(helper_function)
logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code")
logger.debug(f" Checked names: {possible_call_names}")
else:
logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code")
logger.debug(f" Called via: {possible_call_names.intersection(called_function_names)}")
ret_val = unused_helpers
except Exception as e:
logger.debug(f"Error detecting unused helper functions: {e}")
ret_val = []
return ret_val
return []
else:
return unused_helpers

View file

@ -161,6 +161,7 @@ class BestOptimization(BaseModel):
winning_replay_benchmarking_test_results: Optional[TestResults] = None
line_profiler_test_results: dict
async_throughput: Optional[int] = None
concurrency_metrics: Optional[ConcurrencyMetrics] = None
@dataclass(frozen=True)
@ -172,6 +173,14 @@ class BenchmarkKey:
return f"{self.module_path}::{self.function_name}"
@dataclass
class ConcurrencyMetrics:
sequential_time_ns: int
concurrent_time_ns: int
concurrency_factor: int
concurrency_ratio: float # sequential_time / concurrent_time
@dataclass
class BenchmarkDetail:
benchmark_name: str
@ -336,6 +345,7 @@ class OptimizedCandidateResult(BaseModel):
optimization_candidate_index: int
total_candidate_timing: int
async_throughput: Optional[int] = None
concurrency_metrics: Optional[ConcurrencyMetrics] = None
class GeneratedTests(BaseModel):
@ -557,6 +567,7 @@ class OriginalCodeBaseline(BaseModel):
runtime: int
coverage_results: Optional[CoverageData]
async_throughput: Optional[int] = None
concurrency_metrics: Optional[ConcurrencyMetrics] = None
class CoverageStatus(Enum):
@ -648,6 +659,7 @@ class TestingMode(enum.Enum):
BEHAVIOR = "behavior"
PERFORMANCE = "performance"
LINE_PROFILE = "line_profile"
CONCURRENCY = "concurrency"
# TODO this class is duplicated in codeflash_capture

View file

@ -100,7 +100,9 @@ from codeflash.models.models import (
)
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
from codeflash.result.critic import (
concurrency_gain,
coverage_critic,
get_acceptance_reason,
performance_gain,
quantity_of_tests_critic,
speedup_critic,
@ -112,7 +114,11 @@ from codeflash.verification.concolic_testing import generate_concolic_tests
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
from codeflash.verification.parse_test_output import calculate_function_throughput_from_test_results, parse_test_results
from codeflash.verification.parse_test_output import (
calculate_function_throughput_from_test_results,
parse_concurrency_metrics,
parse_test_results,
)
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_line_profile_tests
from codeflash.verification.verification_utils import get_test_file_path
from codeflash.verification.verifier import generate_tests
@ -125,6 +131,7 @@ if TYPE_CHECKING:
from codeflash.models.models import (
BenchmarkKey,
CodeStringsMarkdown,
ConcurrencyMetrics,
CoverageData,
FunctionCalledInTest,
FunctionSource,
@ -603,7 +610,9 @@ class FunctionOptimizer:
):
console.rule()
new_code_context = code_context
if self.is_numerical_code: # if the code is numerical in nature (uses numpy/tensorflow/math/pytorch/jax)
if (
self.is_numerical_code and not self.args.no_jit_opts
): # if the code is numerical in nature (uses numpy/tensorflow/math/pytorch/jax)
jit_compiled_opt_candidate = self.aiservice_client.get_jit_rewritten_code(
code_context.read_writable_code.markdown, self.function_trace_id
)
@ -632,7 +641,7 @@ class FunctionOptimizer:
read_writable_code=code_context.read_writable_code,
read_only_context_code=code_context.read_only_context_code,
run_experiment=should_run_experiment,
is_numerical_code=self.is_numerical_code,
is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts,
)
concurrent.futures.wait([future_tests, future_optimizations])
@ -735,6 +744,13 @@ class FunctionOptimizer:
tree.add(f"Optimized async throughput: {candidate_result.async_throughput} executions")
tree.add(f"Throughput improvement: {throughput_gain_value * 100:.1f}%")
tree.add(f"Throughput ratio: {throughput_gain_value + 1:.3f}X")
# Display concurrency metrics if available
if candidate_result.concurrency_metrics and original_code_baseline.concurrency_metrics:
orig_ratio = original_code_baseline.concurrency_metrics.concurrency_ratio
cand_ratio = candidate_result.concurrency_metrics.concurrency_ratio
conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
tree.add(f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)")
else:
tree.add("This candidate is faster than the original code. 🚀")
tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")
@ -753,6 +769,14 @@ class FunctionOptimizer:
)
tree.add(f"Async throughput: {candidate_result.async_throughput} executions")
tree.add(f"Throughput change: {throughput_gain_value * 100:.1f}%")
# Display concurrency metrics if available
if candidate_result.concurrency_metrics and original_code_baseline.concurrency_metrics:
orig_ratio = original_code_baseline.concurrency_metrics.concurrency_ratio
cand_ratio = candidate_result.concurrency_metrics.concurrency_ratio
conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
tree.add(f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)")
tree.add(
f"(Runtime for reference: {humanize_runtime(candidate_result.best_test_runtime)} over "
f"{candidate_result.max_loop_count} loop{'s' if candidate_result.max_loop_count > 1 else ''})"
@ -819,6 +843,7 @@ class FunctionOptimizer:
winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results,
async_throughput=candidate_result.async_throughput,
concurrency_metrics=candidate_result.concurrency_metrics,
)
return best_optimization, benchmark_tree
@ -863,6 +888,7 @@ class FunctionOptimizer:
winning_benchmarking_test_results=valid_opt.winning_benchmarking_test_results,
winning_replay_benchmarking_test_results=valid_opt.winning_replay_benchmarking_test_results,
async_throughput=valid_opt.async_throughput,
concurrency_metrics=valid_opt.concurrency_metrics,
)
valid_candidates_with_shorter_code.append(new_best_opt)
diff_lens_list.append(
@ -1014,6 +1040,8 @@ class FunctionOptimizer:
best_runtime_until_now=None,
original_async_throughput=original_code_baseline.async_throughput,
best_throughput_until_now=None,
original_concurrency_metrics=original_code_baseline.concurrency_metrics,
best_concurrency_ratio_until_now=None,
) and quantity_of_tests_critic(candidate_result)
tree = self.build_runtime_info_tree(
@ -1132,7 +1160,7 @@ class FunctionOptimizer:
)
if self.experiment_id
else None,
is_numerical_code=self.is_numerical_code,
is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts,
)
processor = CandidateProcessor(
@ -1776,6 +1804,14 @@ class FunctionOptimizer:
fto_benchmark_timings=self.function_benchmark_timings,
total_benchmark_timings=self.total_benchmark_timings,
)
acceptance_reason = get_acceptance_reason(
original_runtime_ns=original_code_baseline.runtime,
optimized_runtime_ns=best_optimization.runtime,
original_async_throughput=original_code_baseline.async_throughput,
optimized_async_throughput=best_optimization.async_throughput,
original_concurrency_metrics=original_code_baseline.concurrency_metrics,
optimized_concurrency_metrics=best_optimization.concurrency_metrics,
)
explanation = Explanation(
raw_explanation_message=best_optimization.candidate.explanation,
winning_behavior_test_results=best_optimization.winning_behavior_test_results,
@ -1787,6 +1823,9 @@ class FunctionOptimizer:
benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None,
original_async_throughput=original_code_baseline.async_throughput,
best_async_throughput=best_optimization.async_throughput,
original_concurrency_metrics=original_code_baseline.concurrency_metrics,
best_concurrency_metrics=best_optimization.concurrency_metrics,
acceptance_reason=acceptance_reason,
)
self.replace_function_and_helpers_with_optimized_code(
@ -1884,6 +1923,9 @@ class FunctionOptimizer:
original_throughput_str = None
optimized_throughput_str = None
throughput_improvement_str = None
original_concurrency_ratio_str = None
optimized_concurrency_ratio_str = None
concurrency_improvement_str = None
if (
self.function_to_optimize.is_async
@ -1898,6 +1940,14 @@ class FunctionOptimizer:
)
throughput_improvement_str = f"{throughput_improvement_value * 100:.1f}%"
if original_code_baseline.concurrency_metrics is not None and best_optimization.concurrency_metrics is not None:
original_concurrency_ratio_str = f"{original_code_baseline.concurrency_metrics.concurrency_ratio:.2f}x"
optimized_concurrency_ratio_str = f"{best_optimization.concurrency_metrics.concurrency_ratio:.2f}x"
conc_improvement_value = concurrency_gain(
original_code_baseline.concurrency_metrics, best_optimization.concurrency_metrics
)
concurrency_improvement_str = f"{conc_improvement_value * 100:.1f}%"
new_explanation_raw_str = self.aiservice_client.get_new_explanation(
source_code=code_context.read_writable_code.flat,
dependency_code=code_context.read_only_context_code,
@ -1915,6 +1965,10 @@ class FunctionOptimizer:
optimized_throughput=optimized_throughput_str,
throughput_improvement=throughput_improvement_str,
function_references=function_references,
acceptance_reason=explanation.acceptance_reason.value,
original_concurrency_ratio=original_concurrency_ratio_str,
optimized_concurrency_ratio=optimized_concurrency_ratio_str,
concurrency_improvement=concurrency_improvement_str,
)
new_explanation = Explanation(
raw_explanation_message=new_explanation_raw_str or explanation.raw_explanation_message,
@ -1927,6 +1981,9 @@ class FunctionOptimizer:
benchmark_details=explanation.benchmark_details,
original_async_throughput=explanation.original_async_throughput,
best_async_throughput=explanation.best_async_throughput,
original_concurrency_metrics=explanation.original_concurrency_metrics,
best_concurrency_metrics=explanation.best_concurrency_metrics,
acceptance_reason=explanation.acceptance_reason,
)
self.log_successful_optimization(new_explanation, generated_tests, exp_type)
@ -2155,12 +2212,22 @@ class FunctionOptimizer:
logger.debug(f"Total original code runtime (ns): {total_timing}")
async_throughput = None
concurrency_metrics = None
if self.function_to_optimize.is_async:
async_throughput = calculate_function_throughput_from_test_results(
benchmarking_results, self.function_to_optimize.function_name
)
logger.debug(f"Original async function throughput: {async_throughput} calls/second")
concurrency_metrics = self.run_concurrency_benchmark(
code_context=code_context, original_helper_code=original_helper_code, test_env=test_env
)
if concurrency_metrics:
logger.debug(
f"Original concurrency metrics: ratio={concurrency_metrics.concurrency_ratio:.2f}, "
f"seq={concurrency_metrics.sequential_time_ns}ns, conc={concurrency_metrics.concurrent_time_ns}ns"
)
if self.args.benchmark:
replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
@ -2175,6 +2242,7 @@ class FunctionOptimizer:
coverage_results=coverage_results,
line_profile_results=line_profile_results,
async_throughput=async_throughput,
concurrency_metrics=concurrency_metrics,
),
functions_to_remove,
)
@ -2341,12 +2409,23 @@ class FunctionOptimizer:
logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
candidate_async_throughput = None
candidate_concurrency_metrics = None
if self.function_to_optimize.is_async:
candidate_async_throughput = calculate_function_throughput_from_test_results(
candidate_benchmarking_results, self.function_to_optimize.function_name
)
logger.debug(f"Candidate async function throughput: {candidate_async_throughput} calls/second")
# Run concurrency benchmark for candidate
candidate_concurrency_metrics = self.run_concurrency_benchmark(
code_context=code_context, original_helper_code=candidate_helper_code, test_env=test_env
)
if candidate_concurrency_metrics:
logger.debug(
f"Candidate concurrency metrics: ratio={candidate_concurrency_metrics.concurrency_ratio:.2f}, "
f"seq={candidate_concurrency_metrics.sequential_time_ns}ns, conc={candidate_concurrency_metrics.concurrent_time_ns}ns"
)
if self.args.benchmark:
candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(
self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
@ -2367,6 +2446,7 @@ class FunctionOptimizer:
optimization_candidate_index=optimization_candidate_index,
total_candidate_timing=total_candidate_timing,
async_throughput=candidate_async_throughput,
concurrency_metrics=candidate_concurrency_metrics,
)
)
@ -2572,3 +2652,57 @@ class FunctionOptimizer:
f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
)
return line_profile_results
def run_concurrency_benchmark(
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], test_env: dict[str, str]
) -> ConcurrencyMetrics | None:
"""Run concurrency benchmark to measure sequential vs concurrent execution for async functions.
This benchmark detects blocking vs non-blocking async code by comparing:
- Sequential execution time (running N iterations one after another)
- Concurrent execution time (running N iterations in parallel with asyncio.gather)
Blocking code (like time.sleep) will have similar sequential and concurrent times.
Non-blocking code (like asyncio.sleep) will be much faster when run concurrently.
Returns
-------
ConcurrencyMetrics if the benchmark ran successfully, None otherwise.
"""
if not self.function_to_optimize.is_async:
return None
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
try:
# Add concurrency decorator to the source function
add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.CONCURRENCY
)
# Run the concurrency benchmark tests
concurrency_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE, # Use performance mode for running
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=5.0, # Short benchmark time
enable_coverage=False,
code_context=code_context,
pytest_min_loops=1,
pytest_max_loops=3,
)
except Exception as e:
logger.debug(f"Concurrency benchmark failed: {e}")
return None
finally:
# Restore original code
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
# Parse concurrency metrics from stdout
if concurrency_results and concurrency_results.perf_stdout:
return parse_concurrency_metrics(concurrency_results, self.function_to_optimize.function_name)
return None
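# A minimal, self-contained sketch of the measurement idea documented above
# (illustrative only: the real benchmark runs through the instrumented test
# harness, and the helper name and timings here are hypothetical):
#
#     import asyncio
#     import time
#
#     async def concurrency_ratio(make_coro, n: int = 10) -> float:
#         t0 = time.perf_counter_ns()
#         for _ in range(n):  # sequential: one iteration after another
#             await make_coro()
#         seq = time.perf_counter_ns() - t0
#         t0 = time.perf_counter_ns()
#         await asyncio.gather(*(make_coro() for _ in range(n)))  # concurrent
#         conc = time.perf_counter_ns() - t0
#         return seq / conc  # ~1.0 for blocking code, approaches n when non-blocking
#
#     # asyncio.sleep yields to the event loop, so the ratio approaches 10;
#     # swapping in time.sleep would pin the ratio near 1.0.
#     print(asyncio.run(concurrency_ratio(lambda: asyncio.sleep(0.01))))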

View file

@ -1,10 +1,12 @@
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING
from codeflash.code_utils import env_utils
from codeflash.code_utils.config_consts import (
COVERAGE_THRESHOLD,
MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD,
MIN_IMPROVEMENT_THRESHOLD,
MIN_TESTCASE_PASSED_THRESHOLD,
MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD,
@ -12,7 +14,14 @@ from codeflash.code_utils.config_consts import (
from codeflash.models import models
if TYPE_CHECKING:
from codeflash.models.models import CoverageData, OptimizedCandidateResult, OriginalCodeBaseline
from codeflash.models.models import ConcurrencyMetrics, CoverageData, OptimizedCandidateResult, OriginalCodeBaseline
class AcceptanceReason(Enum):
RUNTIME = "runtime"
THROUGHPUT = "throughput"
CONCURRENCY = "concurrency"
NONE = "none"
def performance_gain(*, original_runtime_ns: int, optimized_runtime_ns: int) -> float:
@ -36,6 +45,22 @@ def throughput_gain(*, original_throughput: int, optimized_throughput: int) -> f
return (optimized_throughput - original_throughput) / original_throughput
def concurrency_gain(original_metrics: ConcurrencyMetrics, optimized_metrics: ConcurrencyMetrics) -> float:
"""Calculate concurrency ratio improvement.
Returns the relative improvement in concurrency ratio.
Higher is better: the optimized code scales better under concurrent execution.
concurrency_ratio = sequential_time / concurrent_time
A ratio of 10 means concurrent execution is 10x faster than sequential.
"""
if original_metrics.concurrency_ratio == 0:
return 0.0
return (
optimized_metrics.concurrency_ratio - original_metrics.concurrency_ratio
) / original_metrics.concurrency_ratio
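# Worked example (hypothetical values): an original run with
# sequential_time_ns=1_000_000 and concurrent_time_ns=1_000_000 has a
# concurrency_ratio of 1.0 (fully blocking); an optimized run with
# concurrent_time_ns=125_000 has a ratio of 8.0. Then
# concurrency_gain(orig, opt) == (8.0 - 1.0) / 1.0 == 7.0, i.e. a +700% gain.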
def speedup_critic(
candidate_result: OptimizedCandidateResult,
original_code_runtime: int,
@ -44,10 +69,12 @@ def speedup_critic(
disable_gh_action_noise: bool = False,
original_async_throughput: int | None = None,
best_throughput_until_now: int | None = None,
original_concurrency_metrics: ConcurrencyMetrics | None = None,
best_concurrency_ratio_until_now: float | None = None,
) -> bool:
"""Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.
Evaluates both runtime performance and async throughput improvements.
Evaluates runtime performance, async throughput, and concurrency improvements.
For runtime performance:
- Ensures the optimization is actually faster than the original code, above the noise floor.
@ -58,6 +85,10 @@ def speedup_critic(
For async throughput (when available):
- Evaluates throughput improvements using MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD
- Throughput improvements complement runtime improvements for async functions
For concurrency (when available):
- Evaluates concurrency ratio improvements using MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD
- Concurrency improvements detect when blocking calls are replaced with non-blocking equivalents
"""
# Runtime performance evaluation
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
@ -86,14 +117,78 @@ def speedup_critic(
best_throughput_until_now is None or candidate_result.async_throughput > best_throughput_until_now
)
# Concurrency evaluation
concurrency_improved = False
concurrency_is_best = True
if original_concurrency_metrics is not None and candidate_result.concurrency_metrics is not None:
conc_gain = concurrency_gain(original_concurrency_metrics, candidate_result.concurrency_metrics)
concurrency_improved = conc_gain > MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD
concurrency_is_best = (
best_concurrency_ratio_until_now is None
or candidate_result.concurrency_metrics.concurrency_ratio > best_concurrency_ratio_until_now
)
# Accept if ANY of: runtime, throughput, or concurrency improves significantly
if original_async_throughput is not None and candidate_result.async_throughput is not None:
# When async metrics are available, accept if throughput, runtime, OR concurrency improves significantly
throughput_acceptance = throughput_improved and throughput_is_best
runtime_acceptance = runtime_improved and runtime_is_best
return throughput_acceptance or runtime_acceptance
concurrency_acceptance = concurrency_improved and concurrency_is_best
return throughput_acceptance or runtime_acceptance or concurrency_acceptance
return runtime_improved and runtime_is_best
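# Illustrative acceptance outcomes (hypothetical numbers):
#   - sync function, runtime 20% faster and best so far           -> accepted
#   - async function, runtime flat, throughput up 15%             -> accepted
#   - async function, runtime flat, concurrency ratio 1.0 -> 8.0  -> accepted
#   - async function, all three metrics below their thresholds    -> rejected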
def get_acceptance_reason(
original_runtime_ns: int,
optimized_runtime_ns: int,
*,
original_async_throughput: int | None = None,
optimized_async_throughput: int | None = None,
original_concurrency_metrics: ConcurrencyMetrics | None = None,
optimized_concurrency_metrics: ConcurrencyMetrics | None = None,
) -> AcceptanceReason:
"""Determine why an optimization was accepted.
Returns the primary reason for acceptance, with priority:
concurrency > throughput > runtime (for async code).
"""
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_runtime_ns < 10000 else MIN_IMPROVEMENT_THRESHOLD
if env_utils.is_ci():
noise_floor = noise_floor * 2
perf_gain = performance_gain(original_runtime_ns=original_runtime_ns, optimized_runtime_ns=optimized_runtime_ns)
runtime_improved = perf_gain > noise_floor
throughput_improved = False
if (
original_async_throughput is not None
and optimized_async_throughput is not None
and original_async_throughput > 0
):
throughput_gain_value = throughput_gain(
original_throughput=original_async_throughput, optimized_throughput=optimized_async_throughput
)
throughput_improved = throughput_gain_value > MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD
concurrency_improved = False
if original_concurrency_metrics is not None and optimized_concurrency_metrics is not None:
conc_gain = concurrency_gain(original_concurrency_metrics, optimized_concurrency_metrics)
concurrency_improved = conc_gain > MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD
# Return reason with priority: concurrency > throughput > runtime
if original_async_throughput is not None and optimized_async_throughput is not None:
if concurrency_improved:
return AcceptanceReason.CONCURRENCY
if throughput_improved:
return AcceptanceReason.THROUGHPUT
if runtime_improved:
return AcceptanceReason.RUNTIME
return AcceptanceReason.NONE
if runtime_improved:
return AcceptanceReason.RUNTIME
return AcceptanceReason.NONE
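# Example (hypothetical metrics): for an async function whose runtime and
# throughput are unchanged but whose concurrency ratio improves past
# MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD, this returns
# AcceptanceReason.CONCURRENCY, since concurrency outranks throughput and
# runtime in the priority order above.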
def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | OriginalCodeBaseline) -> bool:
test_results = candidate_result.behavior_test_results
report = test_results.get_test_pass_fail_report_by_type()

View file

@ -11,8 +11,8 @@ from rich.table import Table
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.models.models import BenchmarkDetail, TestResults
from codeflash.result.critic import throughput_gain
from codeflash.models.models import BenchmarkDetail, ConcurrencyMetrics, TestResults
from codeflash.result.critic import AcceptanceReason, concurrency_gain, throughput_gain
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
@ -27,31 +27,44 @@ class Explanation:
benchmark_details: Optional[list[BenchmarkDetail]] = None
original_async_throughput: Optional[int] = None
best_async_throughput: Optional[int] = None
original_concurrency_metrics: Optional[ConcurrencyMetrics] = None
best_concurrency_metrics: Optional[ConcurrencyMetrics] = None
acceptance_reason: AcceptanceReason = AcceptanceReason.RUNTIME
@property
def perf_improvement_line(self) -> str:
# speedup property already handles choosing between runtime and throughput
improvement_type = {
AcceptanceReason.RUNTIME: "runtime",
AcceptanceReason.THROUGHPUT: "throughput",
AcceptanceReason.CONCURRENCY: "concurrency",
AcceptanceReason.NONE: "",
}.get(self.acceptance_reason, "")
if improvement_type:
return f"{self.speedup_pct} {improvement_type} improvement ({self.speedup_x} faster)."
return f"{self.speedup_pct} improvement ({self.speedup_x} faster)."
@property
def speedup(self) -> float:
runtime_improvement = (self.original_runtime_ns / self.best_runtime_ns) - 1
# Use throughput improvement if we have async metrics and throughput is better
"""Returns the improvement value for the metric that caused acceptance."""
if (
self.original_async_throughput is not None
self.acceptance_reason == AcceptanceReason.CONCURRENCY
and self.original_concurrency_metrics
and self.best_concurrency_metrics
):
return concurrency_gain(self.original_concurrency_metrics, self.best_concurrency_metrics)
if (
self.acceptance_reason == AcceptanceReason.THROUGHPUT
and self.original_async_throughput is not None
and self.best_async_throughput is not None
and self.original_async_throughput > 0
):
throughput_improvement = throughput_gain(
return throughput_gain(
original_throughput=self.original_async_throughput, optimized_throughput=self.best_async_throughput
)
# Use throughput metrics if throughput improvement is better or runtime got worse
if throughput_improvement > runtime_improvement or runtime_improvement <= 0:
return throughput_improvement
return runtime_improvement
return (self.original_runtime_ns / self.best_runtime_ns) - 1
@property
def speedup_x(self) -> str:
@ -108,7 +121,22 @@ class Explanation:
console.print(table)
benchmark_info = cast("StringIO", console.file).getvalue() + "\n" # Cast for mypy
if self.original_async_throughput is not None and self.best_async_throughput is not None:
if (
self.acceptance_reason == AcceptanceReason.CONCURRENCY
and self.original_concurrency_metrics
and self.best_concurrency_metrics
):
orig_ratio = self.original_concurrency_metrics.concurrency_ratio
best_ratio = self.best_concurrency_metrics.concurrency_ratio
performance_description = (
f"Concurrency ratio improved from {orig_ratio:.2f}x to {best_ratio:.2f}x "
f"(concurrent execution now {best_ratio:.2f}x faster than sequential)\n\n"
)
elif (
self.acceptance_reason == AcceptanceReason.THROUGHPUT
and self.original_async_throughput is not None
and self.best_async_throughput is not None
):
performance_description = (
f"Throughput improved from {self.original_async_throughput} to {self.best_async_throughput} operations/second "
f"(runtime: {original_runtime_human}{best_runtime_human})\n\n"

View file

@ -138,6 +138,13 @@ def main(args: Namespace | None = None) -> ArgumentParser:
env["PYTHONPATH"] = f"{project_root_str}{os.pathsep}{pythonpath}"
else:
env["PYTHONPATH"] = project_root_str
# Disable JIT compilation to ensure tracing captures all function calls
env["NUMBA_DISABLE_JIT"] = str(1)
env["TORCHDYNAMO_DISABLE"] = str(1)
env["PYTORCH_JIT"] = str(0)
env["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
env["TF_ENABLE_ONEDNN_OPTS"] = str(0)
env["JAX_DISABLE_JIT"] = str(1)
processes.append(
subprocess.Popen(
[
@ -175,6 +182,13 @@ def main(args: Namespace | None = None) -> ArgumentParser:
env["PYTHONPATH"] = f"{project_root_str}{os.pathsep}{pythonpath}"
else:
env["PYTHONPATH"] = project_root_str
# Disable JIT compilation to ensure tracing captures all function calls
env["NUMBA_DISABLE_JIT"] = str(1)
env["TORCHDYNAMO_DISABLE"] = str(1)
env["PYTORCH_JIT"] = str(0)
env["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
env["TF_ENABLE_ONEDNN_OPTS"] = str(0)
env["JAX_DISABLE_JIT"] = str(1)
subprocess.run(
[

View file

@ -15,6 +15,8 @@ from typing import Callable
import dill as pickle
from dill import PicklingWarning
from codeflash.picklepatch.pickle_patcher import PicklePatcher
warnings.filterwarnings("ignore", category=PicklingWarning)
@ -148,18 +150,29 @@ def codeflash_capture(function_name: str, tmp_dir_path: str, tests_root: str, is
print(f"!######{test_stdout_tag}######!")
# Capture instance state after initialization
if hasattr(args[0], "__dict__"):
instance_state = args[
0
].__dict__ # self is always the first argument, this is ensured during instrumentation
# self is always the first argument, this is ensured during instrumentation
instance = args[0]
if hasattr(instance, "__dict__"):
instance_state = instance.__dict__
elif hasattr(instance, "__slots__"):
# For classes using __slots__, capture slot values
instance_state = {
slot: getattr(instance, slot, None) for slot in instance.__slots__ if hasattr(instance, slot)
}
else:
raise ValueError("Instance state could not be captured.")
# For C extension types or other special classes (e.g., Playwright's Page),
# capture all non-private, non-callable attributes
instance_state = {
attr: getattr(instance, attr)
for attr in dir(instance)
if not attr.startswith("_") and not callable(getattr(instance, attr, None))
}
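# For example, an instance of `class Point: __slots__ = ("x", "y")` with
# x=1 and y=2 would be captured as {"x": 1, "y": 2} via the __slots__
# branch above (illustrative class, not from the codebase).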
codeflash_cur.execute(
"CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)"
)
# Write to sqlite
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(instance_state)
pickled_return_value = pickle.dumps(exception) if exception else PicklePatcher.dumps(instance_state)
codeflash_cur.execute(
"INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(

View file

@ -24,6 +24,7 @@ HAS_TORCH = find_spec("torch") is not None
HAS_JAX = find_spec("jax") is not None
HAS_XARRAY = find_spec("xarray") is not None
HAS_TENSORFLOW = find_spec("tensorflow") is not None
HAS_NUMBA = find_spec("numba") is not None
# Pattern to match pytest temp directories: /tmp/pytest-of-<user>/pytest-<N>/
# These paths vary between test runs but are logically equivalent
@ -156,6 +157,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
range,
slice,
OrderedDict,
types.GenericAlias,
),
):
return orig == new
@ -255,6 +257,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
return False
return True
# Handle mappingproxy (read-only dict view, commonly seen as class.__dict__)
if isinstance(orig, types.MappingProxyType):
return comparator(dict(orig), dict(new), superset_obj)
# Handle dict view types (dict_keys, dict_values, dict_items)
# Use type name checking since these are not directly importable types
type_name = type(orig).__name__
@ -296,8 +302,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
# fails at "ufunc 'isfinite' not supported for the input types"
return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)])
if isinstance(orig, (np.floating, np.complex64, np.complex128)):
return np.isclose(orig, new)
if isinstance(orig, (np.floating, np.complexfloating)):
return np.isclose(orig, new, equal_nan=True)
if isinstance(orig, (np.integer, np.bool_, np.byte)):
return orig == new
@ -383,6 +389,42 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
if isinstance(orig, torch.device):
return orig == new
if HAS_NUMBA:
import numba # type: ignore # noqa: PGH003
from numba.core.dispatcher import Dispatcher # type: ignore # noqa: PGH003
from numba.typed import Dict as NumbaDict # type: ignore # noqa: PGH003
from numba.typed import List as NumbaList # type: ignore # noqa: PGH003
# Handle numba typed List
if isinstance(orig, NumbaList):
if len(orig) != len(new):
return False
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))
# Handle numba typed Dict
if isinstance(orig, NumbaDict):
if superset_obj:
# Allow new dict to have more keys, but all orig keys must exist with equal values
return all(key in new and comparator(orig[key], new[key], superset_obj) for key in orig)
if len(orig) != len(new):
return False
for key in orig:
if key not in new:
return False
if not comparator(orig[key], new[key], superset_obj):
return False
return True
# Handle numba type objects (e.g., numba.int64, numba.float64, numba.Array, etc.)
if isinstance(orig, numba.core.types.Type):
return orig == new
# Handle numba JIT-compiled functions (CPUDispatcher, etc.)
if isinstance(orig, Dispatcher):
# Compare by identity of the underlying Python function
# Two JIT functions are equal if they wrap the same Python function
return orig.py_func is new.py_func
if HAS_PYRSISTENT:
import pyrsistent # type: ignore # noqa: PGH003

View file

@ -19,6 +19,14 @@ reprlib_repr.maxstring = 1500
test_diff_repr = reprlib_repr.repr
def safe_repr(obj: object) -> str:
"""Safely get repr of an object, handling Mock objects with corrupted state."""
try:
return repr(obj)
except (AttributeError, TypeError, RecursionError) as e:
return f"<repr failed: {type(e).__name__}: {e}>"
def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]:
# This is meant to be only called with test results for the first loop index
if len(original_results) == 0 or len(candidate_results) == 0:
@ -77,8 +85,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
test_diffs.append(
TestDiff(
scope=TestDiffScope.RETURN_VALUE,
original_value=test_diff_repr(repr(original_test_result.return_value)),
candidate_value=test_diff_repr(repr(cdd_test_result.return_value)),
original_value=test_diff_repr(safe_repr(original_test_result.return_value)),
candidate_value=test_diff_repr(safe_repr(cdd_test_result.return_value)),
test_src_code=original_test_result.id.get_src_code(original_test_result.file_name),
candidate_pytest_error=cdd_pytest_error,
original_pass=original_test_result.did_pass,

View file

@ -20,7 +20,14 @@ from codeflash.code_utils.code_utils import (
module_name_from_file_path,
)
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType, VerificationType
from codeflash.models.models import (
ConcurrencyMetrics,
FunctionTestInvocation,
InvocationId,
TestResults,
TestType,
VerificationType,
)
from codeflash.verification.coverage_utils import CoverageUtils
if TYPE_CHECKING:
@ -64,6 +71,54 @@ def calculate_function_throughput_from_test_results(test_results: TestResults, f
return function_throughput
# Pattern for concurrency benchmark output:
# !@######CONC:module:class:test:function:loop_index:seq_time:conc_time:factor######@!
_concurrency_pattern = re.compile(r"!@######CONC:([^:]*):([^:]*):([^:]*):([^:]*):([^:]*):(\d+):(\d+):(\d+)######@!")
def parse_concurrency_metrics(test_results: TestResults, function_name: str) -> ConcurrencyMetrics | None:
"""Parse concurrency benchmark results from test output.
Format: !@######CONC:module:class:test:function:loop_index:seq_time:conc_time:factor######@!
Returns ConcurrencyMetrics with:
- sequential_time_ns: Total time for N sequential executions
- concurrent_time_ns: Total time for N concurrent executions
- concurrency_factor: N (number of concurrent executions)
- concurrency_ratio: sequential_time / concurrent_time (higher = better concurrency)
"""
if not test_results.perf_stdout:
return None
matches = _concurrency_pattern.findall(test_results.perf_stdout)
if not matches:
return None
# Aggregate metrics for the target function
total_seq, total_conc, factor, count = 0, 0, 0, 0
for match in matches:
# match[3] is function_name
if len(match) >= 8 and match[3] == function_name:
total_seq += int(match[5])
total_conc += int(match[6])
factor = int(match[7])
count += 1
if count == 0:
return None
avg_seq = total_seq / count
avg_conc = total_conc / count
ratio = avg_seq / avg_conc if avg_conc > 0 else 1.0
return ConcurrencyMetrics(
sequential_time_ns=int(avg_seq),
concurrent_time_ns=int(avg_conc),
concurrency_factor=factor,
concurrency_ratio=ratio,
)
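# Example (hypothetical stdout line):
#   !@######CONC:my_module:TestFetch:test_fetch:fetch_data:0:50000000:5200000:10######@!
# For function_name="fetch_data" this parses to sequential_time_ns=50_000_000,
# concurrent_time_ns=5_200_000, concurrency_factor=10, and
# concurrency_ratio ~= 9.62 (values are averaged if several lines match).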
def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> Path | None:
"""Resolve test file path from pytest's test class path.

View file

@ -1,2 +1,2 @@
# These version placeholders will be replaced by uv-dynamic-versioning during build.
__version__ = "0.19.1"
__version__ = "0.20.0"

View file

@ -1,191 +0,0 @@
---
title: "Flags Reference"
description: "Complete reference for all Codeflash CLI flags and options"
icon: "list"
sidebarTitle: "Flags Reference"
keywords: ["flags", "options", "arguments", "command line"]
---
# Flags Reference
Complete reference for all Codeflash CLI flags and command-line options.
---
## Main Command Flags
| Flag | Type | Description |
|------|------|-------------|
| `--file` | `PATH` | Optimize only this file |
| `--function` | `NAME` | Optimize only this function (requires `--file`) |
| `--all` | `[PATH]` | Optimize all functions. Optional path to start from |
| `--replay-test` | `PATH` | Path to replay test file(s) |
| `--benchmark` | flag | Enable benchmark mode |
| `--no-pr` | flag | Don't create PR, update locally |
| `--no-gen-tests` | flag | Don't generate tests |
| `--no-draft` | flag | Skip draft PRs |
| `--worktree` | flag | Use git worktree |
| `--staging-review` | flag | Upload to staging |
| `--verbose` / `-v` | flag | Verbose debug output |
| `--verify-setup` | flag | Run setup verification |
| `--version` | flag | Show version |
---
## Configuration Override Flags
Override settings from `pyproject.toml` via command line.
| Flag | Type | Description |
|------|------|-------------|
| `--config-file` | `PATH` | Path to pyproject.toml |
| `--module-root` | `PATH` | Python module root directory |
| `--tests-root` | `PATH` | Tests directory |
| `--benchmarks-root` | `PATH` | Benchmarks directory |
<Accordion title="Complete Examples">
<Tabs>
<Tab title="Linux/macOS">
```bash
# Override config file location
codeflash --file src/app.py --function main --config-file configs/pyproject.toml --no-pr
# Override module root
codeflash --file src/app.py --function main --module-root src --no-pr
# Override tests root
codeflash --file src/app.py --function main --tests-root tests/unit --no-pr
# Combine multiple overrides
codeflash --file src/app.py --function main \
--module-root src \
--tests-root tests \
--no-pr
```
</Tab>
<Tab title="Windows">
```powershell
# Override config file location
codeflash --file src\app.py --function main --config-file configs\pyproject.toml --no-pr
# Override module root
codeflash --file src\app.py --function main --module-root src --no-pr
# Override tests root
codeflash --file src\app.py --function main --tests-root tests\unit --no-pr
```
</Tab>
</Tabs>
</Accordion>
---
## Optimize Subcommand Flags
Flags specific to the `codeflash optimize` command.
| Flag | Type | Description |
|------|------|-------------|
| `--output` | `PATH` | Trace file output path (default: `codeflash.trace`) |
| `--timeout` | `INT` | Maximum trace time in seconds |
| `--max-function-count` | `INT` | Max times to trace a function (default: 100) |
| `--config-file-path` | `PATH` | Path to pyproject.toml |
| `--trace-only` | flag | Only trace, don't optimize |
<Info>
The `--output` flag specifies where to save the trace file. If not specified, it defaults to `codeflash.trace` in the current directory.
</Info>
---
## Behavior Flags
Control how Codeflash behaves during optimization.
| Flag | Description |
|------|-------------|
| `--no-pr` | Run locally without creating a pull request |
| `--no-gen-tests` | Use only existing tests, skip test generation |
| `--no-draft` | Skip optimization for draft PRs (CI mode) |
| `--worktree` | Use git worktree for isolated optimization |
| `--staging-review` | Upload optimizations to staging for review |
| `--verbose` / `-v` | Enable verbose debug logging |
<Accordion title="Complete Examples">
```bash
# Local optimization only
codeflash --file src/app.py --function main --no-pr
# Use only existing tests
codeflash --file src/app.py --function main --no-gen-tests --no-pr
# Enable verbose logging
codeflash --file src/app.py --function main --verbose --no-pr
# Use worktree for isolation
codeflash --file src/app.py --function main --worktree --no-pr
# Upload to staging
codeflash --all --staging-review --no-pr
```
</Accordion>
---
## Flag Combinations
Common flag combinations for different use cases:
### Local Development
```bash
# Optimize locally with verbose output
codeflash --file src/app.py --function main --no-pr --verbose
```
### CI/CD Pipeline
```bash
# Skip draft PRs and use existing tests only
codeflash --all --no-draft --no-gen-tests
```
### Debugging
```bash
# Trace only with custom output and timeout
codeflash optimize app.py --trace-only --output debug.trace --timeout 60
```
### Custom Configuration
```bash
# Override multiple config settings
codeflash --file src/app.py --function main \
--module-root src \
--tests-root tests/unit \
--benchmarks-root tests/benchmarks \
--no-pr
```
---
## Next Steps
<CardGroup cols={2}>
<Card
title="Optimization Commands"
icon="bullseye"
href="/cli-reference/optimization"
>
Learn how to use optimization commands
</Card>
<Card
title="Troubleshooting"
icon="wrench"
href="/cli-reference/troubleshooting"
>
Fix common issues
</Card>
</CardGroup>

View file

@ -1,208 +0,0 @@
---
title: "CLI Reference"
description: "Complete command-line reference for Codeflash CLI commands, flags, and options"
icon: "terminal"
sidebarTitle: "Overview"
keywords:
[
"CLI",
"command line",
"commands",
"flags",
"options",
"reference",
"terminal",
]
---
# Codeflash CLI Reference
Complete command-line reference for all Codeflash commands, flags, and options with practical examples you can run directly in your terminal.
<Info>
**Prerequisites** - Ensure Codeflash is installed in your Python environment
and you have a configured `pyproject.toml` in your project.
</Info>
---
## Quick Start
<Tabs>
<Tab title="Linux/macOS">
```bash
# Activate virtual environment (if using one)
source .venv/bin/activate
# Verify installation
codeflash --version
```
</Tab>
<Tab title="Windows">
```powershell
# Activate virtual environment (if using one)
.venv\Scripts\activate
# Verify installation
codeflash --version
```
</Tab>
</Tabs>
---
## Common Workflows
### 1. First-Time Setup
<Steps>
<Step title="Install Codeflash">
```bash
pip install codeflash
```
</Step>
<Step title="Initialize Project">
```bash
codeflash init
```
</Step>
<Step title="Verify Setup">
```bash
codeflash --verify-setup
```
</Step>
<Step title="Run First Optimization">
```bash
codeflash --file src/main.py --function my_function --no-pr
```
</Step>
</Steps>
---
### 2. Optimize a Workflow
<Steps>
<Step title="Trace Your Script">
```bash
codeflash optimize my_script.py --arg1 value1
```
</Step>
<Step title="Review Optimizations">
Check the generated PR or local changes for optimization suggestions.
</Step>
</Steps>
---
### 3. CI/CD Integration
<Steps>
<Step title="Set Up GitHub Actions">
```bash
codeflash init-actions
```
</Step>
<Step title="Merge the Workflow PR">
Review and merge the generated GitHub Actions workflow.
</Step>
<Step title="Automatic Optimization">
Codeflash will now optimize code in every PR automatically!
</Step>
</Steps>
---
## Help & Version
```bash
# Display version
codeflash --version
# Main help
codeflash --help
# Subcommand help
codeflash optimize --help
codeflash init --help
```
---
## Documentation Structure
This CLI reference is organized into the following sections:
<CardGroup cols={2}>
<Card
title="Setup Commands"
icon="wrench"
href="/cli-reference/setup"
>
Initialize projects, set up GitHub Actions, and verify installation
</Card>
<Card
title="Optimization Commands"
icon="bullseye"
href="/cli-reference/optimization"
>
Optimize single functions or entire codebases
</Card>
<Card
title="Tracing & Workflows"
icon="route"
href="/cli-reference/tracing"
>
Trace script execution and optimize based on real usage
</Card>
<Card
title="Flags Reference"
icon="list"
href="/cli-reference/flags"
>
Complete reference for all command-line flags
</Card>
<Card
title="Troubleshooting"
icon="wrench"
href="/cli-reference/troubleshooting"
>
Solutions for common CLI issues
</Card>
</CardGroup>
---
## Next Steps
<CardGroup cols={2}>
<Card
title="Optimize a Function"
icon="bullseye"
href="/optimizing-with-codeflash/one-function"
>
Learn how to optimize individual functions
</Card>
<Card
title="Trace & Optimize"
icon="route"
href="/optimizing-with-codeflash/trace-and-optimize"
>
Optimize entire workflows with tracing
</Card>
<Card
title="GitHub Actions"
icon="github"
href="/optimizing-with-codeflash/codeflash-github-actions"
>
Set up continuous optimization
</Card>
<Card
title="Configuration"
icon="gear"
href="/configuration"
>
Advanced configuration options
</Card>
</CardGroup>

View file

@ -1,172 +0,0 @@
---
title: "Optimization Commands"
description: "Optimize single functions or entire codebases with Codeflash CLI"
icon: "bullseye"
sidebarTitle: "Optimization Commands"
keywords: ["optimization", "function", "file", "all", "commands"]
---
# Optimization Commands
Commands for optimizing individual functions or entire codebases.
---
## Optimize a Single Function
Target a specific function in a file for optimization.
```bash
codeflash --file <path/to/file.py> --function <function_name>
```
<Accordion title="Complete Examples">
<Tabs>
<Tab title="Linux/macOS">
```bash
# Basic optimization (creates PR)
codeflash --file src/utils.py --function calculate_metrics
# Local optimization only (no PR)
codeflash --file src/utils.py --function calculate_metrics --no-pr
# With verbose output
codeflash --file src/utils.py --function calculate_metrics --no-pr --verbose
```
</Tab>
<Tab title="Windows">
```powershell
# Basic optimization (creates PR)
codeflash --file src\utils.py --function calculate_metrics
# Local optimization only (no PR)
codeflash --file src\utils.py --function calculate_metrics --no-pr
# With verbose output
codeflash --file src\utils.py --function calculate_metrics --no-pr --verbose
```
</Tab>
</Tabs>
</Accordion>
<Warning>
**Important**: The file must be within your configured `module-root`
directory. Files outside `module-root` will be ignored with a "Functions outside
module-root" message.
</Warning>
---
## Optimize All Functions
Optimize all functions in your entire codebase or a specific directory.
```bash
# Optimize entire codebase
codeflash --all
# Optimize specific directory
codeflash --all src/core/
```
<Accordion title="Complete Examples">
<Tabs>
<Tab title="Linux/macOS">
```bash
# Optimize all (creates PRs)
codeflash --all
# Optimize all locally (no PRs)
codeflash --all --no-pr
# Optimize specific directory
codeflash --all src/algorithms/ --no-pr
# Skip draft PRs in CI
codeflash --all --no-draft
```
</Tab>
<Tab title="Windows">
```powershell
# Optimize all (creates PRs)
codeflash --all
# Optimize all locally (no PRs)
codeflash --all --no-pr
# Optimize specific directory
codeflash --all src\algorithms\ --no-pr
# Skip draft PRs in CI
codeflash --all --no-draft
```
</Tab>
</Tabs>
</Accordion>
<Info>
When using `--all`, Codeflash will:
- Discover all optimizable functions in your codebase
- Create separate PRs for each function (or update locally with `--no-pr`)
- Process functions in batches to avoid overwhelming your repository
</Info>
---
## Benchmark Mode
Optimize code based on performance benchmarks using pytest-benchmark format.
```bash
codeflash --file <file.py> --benchmark --benchmarks-root <path>
```
<Accordion title="Complete Examples">
<Tabs>
<Tab title="Linux/macOS">
```bash
# With benchmarks-root flag
codeflash --file src/core.py --benchmark --benchmarks-root tests/benchmarks --no-pr
# If benchmarks-root is in pyproject.toml
codeflash --file src/core.py --benchmark --no-pr
```
</Tab>
<Tab title="Windows">
```powershell
# With benchmarks-root flag
codeflash --file src\core.py --benchmark --benchmarks-root tests\benchmarks --no-pr
# If benchmarks-root is in pyproject.toml
codeflash --file src\core.py --benchmark --no-pr
```
</Tab>
</Tabs>
</Accordion>
<Warning>
The `--benchmarks-root` directory must exist and be configured either via
`pyproject.toml` or the command-line flag.
</Warning>
---
## Next Steps
<CardGroup cols={2}>
<Card
title="Tracing & Workflows"
icon="route"
href="/cli-reference/tracing"
>
Learn about trace-based optimization
</Card>
<Card
title="Flags Reference"
icon="list"
href="/cli-reference/flags"
>
Complete flag reference
</Card>
</CardGroup>

View file

@ -1,125 +0,0 @@
---
title: "Setup Commands"
description: "Initialize projects, set up GitHub Actions, and verify Codeflash installation"
icon: "wrench"
sidebarTitle: "Setup Commands"
keywords: ["setup", "init", "installation", "github actions", "verify"]
---
# Setup Commands
Commands for initializing Codeflash in your project, setting up continuous optimization, and verifying your installation.
---
## `codeflash init`
Initialize Codeflash for your Python project. This creates the configuration in `pyproject.toml`.
<CodeGroup>
```bash Basic
codeflash init
```
```bash With Formatter Override
codeflash init --override-formatter-check
```
</CodeGroup>
<Tip>
The `init` command will guide you through an interactive setup process,
including API key configuration, module selection, and GitHub App
installation.
</Tip>
**What it does:**
- Prompts for your Python module directory (`module-root`)
- Prompts for your test directory (`tests-root`)
- Configures code formatter preferences
- Sets up telemetry preferences
- Optionally installs the Codeflash VS Code extension
- Optionally sets up GitHub Actions workflow
---
## `codeflash init-actions`
Set up GitHub Actions workflow for continuous optimization on every pull request.
```bash
codeflash init-actions
```
**What it does:**
- Creates a workflow file in `.github/workflows/`
- Opens a PR with the workflow configuration
- Requires the Codeflash GitHub App to be installed
<Warning>
This command requires the Codeflash GitHub App to be installed on your repository. If you haven't installed it, you'll be prompted with a link to do so.
</Warning>
---
## `codeflash vscode-install`
Install the Codeflash extension for VS Code, Cursor, or Windsurf.
```bash
codeflash vscode-install
```
**What it does:**
- Detects which editor you're using (VS Code, Cursor, or Windsurf)
- Downloads and installs the appropriate extension
- Works with both Marketplace and Open VSX sources
<Tip>
This command is also run automatically during `codeflash init` if you choose to install the extension.
</Tip>
---
## `codeflash --verify-setup`
Verify your Codeflash installation by running a sample optimization.
```bash
codeflash --verify-setup
```
**What it does:**
- Creates a temporary demo file
- Runs a sample optimization
- Verifies all components are working correctly
- Cleans up the demo file afterward
<Note>
This command takes about 3 minutes to complete. It's a great way to ensure everything is set up correctly before optimizing your actual code.
</Note>
---
## Next Steps
<CardGroup cols={2}>
<Card
title="Optimization Commands"
icon="bullseye"
href="/cli-reference/optimization"
>
Learn how to optimize functions
</Card>
<Card
title="Flags Reference"
icon="list"
href="/cli-reference/flags"
>
Complete flag reference
</Card>
</CardGroup>

View file

@ -1,213 +0,0 @@
---
title: "Tracing & Workflows"
description: "Trace script execution and optimize functions based on real-world usage"
icon: "route"
sidebarTitle: "Tracing & Workflows"
keywords: ["tracing", "optimize", "workflow", "replay test", "pytest"]
---
# Tracing & Workflows
Trace Python script execution and optimize functions based on real-world usage patterns.
---
## `codeflash optimize`
Trace a Python script's execution and optimize functions based on real-world usage.
```bash
codeflash optimize <script.py> [script_args]
```
<Accordion title="Complete Examples">
<Tabs>
<Tab title="Linux/macOS">
```bash
# Basic trace and optimize
codeflash optimize app.py
# With script arguments
codeflash optimize process.py --input data.csv --output results.json
# Custom trace output file
codeflash optimize app.py --output custom_trace.trace
# With timeout (30 seconds)
codeflash optimize long_running_script.py --timeout 30
# Limit function trace count
codeflash optimize app.py --max-function-count 50
# Specify config file
codeflash optimize app.py --config-file-path pyproject.toml
# Local only (no PR)
codeflash optimize app.py --no-pr
```
</Tab>
<Tab title="Windows">
```powershell
# Basic trace and optimize
codeflash optimize app.py
# With script arguments
codeflash optimize process.py --input data.csv --output results.json
# Custom trace output file
codeflash optimize app.py --output custom_trace.trace
# With timeout (30 seconds)
codeflash optimize long_running_script.py --timeout 30
# Limit function trace count
codeflash optimize app.py --max-function-count 50
# Specify config file
codeflash optimize app.py --config-file-path pyproject.toml
# Local only (no PR)
codeflash optimize app.py --no-pr
```
</Tab>
</Tabs>
</Accordion>
**How it works:**
1. Runs your script with the provided arguments
2. Traces all function calls during execution
3. Identifies which functions are called and how often
4. Generates replay tests based on actual usage
5. Optimizes the traced functions
---
## Trace with pytest
Optimize functions called during pytest test execution.
<Tabs>
<Tab title="Linux/macOS">
```bash
# Trace pytest tests
codeflash optimize -m pytest tests/
# Trace specific test file
codeflash optimize -m pytest tests/test_core.py
# With pytest arguments
codeflash optimize -m pytest tests/ -v --tb=short
```
</Tab>
<Tab title="Windows">
```powershell
# Trace pytest tests
codeflash optimize -m pytest tests\
# Trace specific test file
codeflash optimize -m pytest tests\test_core.py
# With pytest arguments
codeflash optimize -m pytest tests\ -v --tb=short
```
</Tab>
</Tabs>
<Tip>
Tracing pytest tests is a great way to optimize functions that are heavily exercised by your test suite, and it verifies that optimizations work correctly against your existing tests.
</Tip>
---
## Trace Only (Generate Replay Tests)
Create trace files and replay tests without running optimization.
<Tabs>
<Tab title="Linux/macOS">
```bash
# Trace only - generates replay test
codeflash optimize app.py --output trace_file.trace --trace-only
# Then optimize using the replay test
codeflash --replay-test tests/test_app_py__replay_test_0.py --no-pr
```
</Tab>
<Tab title="Windows">
```powershell
# Trace only - generates replay test
codeflash optimize app.py --output trace_file.trace --trace-only
# Then optimize using the replay test
codeflash --replay-test tests\test_app_py__replay_test_0.py --no-pr
```
</Tab>
</Tabs>
<Note>
**Replay test naming**: Files are named based on the traced script path. For
`src/app.py`, the replay test will be named like
`test_srcapp_py__replay_test_0.py`.
</Note>
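If you need to predict the filename in a script, the sketch below reproduces the documented example. The exact naming rule is internal to Codeflash, so treat this helper as an approximation:
```python
from pathlib import Path


def replay_test_name(script: str, index: int = 0) -> str:
    # Flatten the path ("src/app.py" -> "srcapp_py") and add the
    # replay-test prefix/suffix. Approximation of Codeflash's internal rule.
    flat = Path(script).as_posix().replace("/", "").replace(".", "_")
    return f"test_{flat}__replay_test_{index}.py"


assert replay_test_name("src/app.py") == "test_srcapp_py__replay_test_0.py"
```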
**Use cases for trace-only:**
- Generate replay tests for later optimization
- Debug tracing issues without running full optimization
- Create reusable test cases from script execution
---
## Replay Test Optimization
Optimize functions using previously generated replay tests.
```bash
codeflash --replay-test <path/to/replay_test.py>
```
<Accordion title="Complete Examples">
<Tabs>
<Tab title="Linux/macOS">
```bash
# Optimize using replay test
codeflash --replay-test tests/test_app_py__replay_test_0.py --no-pr
# Multiple replay tests
codeflash --replay-test tests/test_*.py --no-pr
```
</Tab>
<Tab title="Windows">
```powershell
# Optimize using replay test
codeflash --replay-test tests\test_app_py__replay_test_0.py --no-pr
# Multiple replay tests (use Get-ChildItem for globbing)
codeflash --replay-test (Get-ChildItem tests\test_*.py) --no-pr
```
</Tab>
</Tabs>
</Accordion>
---
## Next Steps
<CardGroup cols={2}>
<Card
title="Optimization Commands"
icon="bullseye"
href="/cli-reference/optimization"
>
Learn about function optimization
</Card>
<Card
title="Flags Reference"
icon="list"
href="/cli-reference/flags"
>
Complete flag reference
</Card>
</CardGroup>

View file

@ -1,157 +0,0 @@
---
title: "CLI Troubleshooting"
description: "Solutions for common Codeflash CLI issues and errors"
icon: "wrench"
sidebarTitle: "Troubleshooting"
keywords: ["troubleshooting", "errors", "issues", "problems", "debugging"]
---
# CLI Troubleshooting
Solutions for common issues when using the Codeflash CLI.
---
## Common Issues
<AccordionGroup>
<Accordion title="'Functions outside module-root' Error">
**Problem**: Function not found because file is outside `module-root`.
**Solution**: Ensure your file is within the `module-root` directory specified in `pyproject.toml`.
```bash
# Check your module-root
grep "module-root" pyproject.toml
# Use the correct path (e.g., if module-root is "src")
codeflash --file src/myfile.py --function my_function --no-pr
```
</Accordion>
<Accordion title="'benchmarks-root must be specified' Error">
**Problem**: Using `--benchmark` without specifying benchmarks directory.
**Solution**: Either add `benchmarks-root` to `pyproject.toml` or use the flag:
```bash
codeflash --file src/app.py --benchmark --benchmarks-root tests/benchmarks --no-pr
```
</Accordion>
<Accordion title="Replay Test File Not Found">
**Problem**: Replay test filename doesn't match expected path.
**Solution**: Replay tests include the module path in their name. Check the actual filename:
```bash
# Linux/macOS
ls tests/test_*replay*.py
# Windows
dir tests\test_*replay*.py
```
<Note>
Replay test files are named based on the traced script path. For `src/app.py`,
the replay test will be named like `test_srcapp_py__replay_test_0.py`.
</Note>
</Accordion>
<Accordion title="GitHub App Required">
**Problem**: PR creation fails due to missing GitHub App.
**Solution**: Install the Codeflash GitHub App or use `--no-pr` for local optimization:
```bash
# Local optimization
codeflash --file src/app.py --function main --no-pr
# Or install the GitHub App
# https://github.com/apps/codeflash-ai/installations/select_target
```
</Accordion>
<Accordion title="Module Not Found Errors">
**Problem**: Codeflash can't find your Python modules.
**Solution**:
1. Verify `module-root` is correctly set in `pyproject.toml`
2. Ensure you're running from the project root
3. Check that your Python environment has all dependencies installed
```bash
# Verify module-root
grep module-root pyproject.toml
# Check Python path
python -c "import sys; print(sys.path)"
```
</Accordion>
<Accordion title="Test Generation Fails">
**Problem**: Codeflash can't generate tests for your function.
**Solution**:
1. Ensure your function has a return statement
2. Check that the function is not a property or class method with special decorators
3. Use `--no-gen-tests` to skip test generation and use existing tests only
```bash
codeflash --file src/app.py --function main --no-gen-tests --no-pr
```
</Accordion>
<Accordion title="Optimization Timeout">
**Problem**: Optimization takes too long or times out.
**Solution**:
1. Use `--verbose` to see what's happening
2. For tracing, use `--timeout` to limit trace duration
3. For large functions, consider breaking them down
```bash
# Limit trace time
codeflash optimize app.py --timeout 30
# See detailed progress
codeflash --file src/app.py --function main --verbose --no-pr
```
</Accordion>
</AccordionGroup>
---
## Getting Help
If you're still experiencing issues:
1. **Check the logs**: Use the `--verbose` flag to see detailed output
2. **Verify setup**: Run `codeflash --verify-setup` to check your installation
3. **Check configuration**: Ensure `pyproject.toml` is correctly configured
4. **View help**: Run `codeflash --help` or `codeflash <command> --help`
---
## Next Steps
<CardGroup cols={2}>
<Card
title="Setup Commands"
icon="wrench"
href="/cli-reference/setup"
>
Review setup and initialization
</Card>
<Card
title="Flags Reference"
icon="list"
href="/cli-reference/flags"
>
Complete flag reference
</Card>
</CardGroup>

View file

@ -18,33 +18,13 @@
{
"tab": "Documentation",
"groups": [
{
"group": "🚀 Quickstart",
"pages": ["getting-started/local-installation"]
},
{
"group": "🏠 Overview",
"pages": ["index"]
},
{
"group": "📖 Codeflash CLI",
"pages": [
"cli-reference/index",
"cli-reference/setup",
"cli-reference/optimization",
"cli-reference/tracing",
"cli-reference/flags",
"cli-reference/troubleshooting"
]
},
{
"group": "🛠 IDE Extension",
"pages": [
"editor-plugins/vscode/index",
"editor-plugins/vscode/features",
"editor-plugins/vscode/configuration",
"editor-plugins/vscode/troubleshooting"
]
"group": "🚀 Quickstart",
"pages": ["getting-started/local-installation"]
},
{
"group": "⚡ Optimizing with Codeflash",
@ -62,6 +42,15 @@
"optimizing-with-codeflash/review-optimizations"
]
},
{
"group": "🛠 IDE Extension",
"pages": [
"editor-plugins/vscode/index",
"editor-plugins/vscode/features",
"editor-plugins/vscode/configuration",
"editor-plugins/vscode/troubleshooting"
]
},
{
"group": "🧠 Core Concepts",
"pages": [

View file

@ -146,7 +146,4 @@ When configuration issues are detected, the extension displays clear error messa
<Card title="Project Configuration" icon="file-code" href="/configuration">
Complete pyproject.toml reference
</Card>
<Card title="Codeflash CLI" icon="terminal" href="/cli-reference/index">
Command-line options
</Card>
</CardGroup>

View file

@ -204,7 +204,6 @@ The extension works alongside the Codeflash CLI. You can:
- **Use extension for interactive work** — Optimize individual functions as you code
- **Mix both** — The extension picks up CLI results when you return to the editor
For CLI documentation, see the [Codeflash CLI](/cli-reference/index).
---
@ -220,9 +219,6 @@ For CLI documentation, see the [Codeflash CLI](/cli-reference/index).
<Card title="Troubleshooting" icon="wrench" href="/editor-plugins/vscode/troubleshooting">
Fix common issues
</Card>
<Card title="Codeflash CLI" icon="terminal" href="/cli-reference/index">
Command-line interface docs
</Card>
</CardGroup>
---

View file

@ -208,8 +208,5 @@ If you're still experiencing issues:
<Card title="Configuration" icon="gear" href="/editor-plugins/vscode/configuration">
Customize extension settings
</Card>
<Card title="Codeflash CLI" icon="terminal" href="/cli-reference/index">
Command-line interface docs
</Card>
</CardGroup>

View file

@ -8,6 +8,7 @@ def run_test(expected_improvement_pct: int) -> bool:
config = TestConfig(
file_path="main.py",
min_improvement_x=0.1,
expected_acceptance_reason="concurrency",
coverage_expectations=[
CoverageExpectation(
function_name="retry_with_backoff",

View file

@ -37,6 +37,7 @@ class TestConfig:
benchmarks_root: Optional[pathlib.Path] = None
use_worktree: bool = False
no_gen_tests: bool = False
expected_acceptance_reason: Optional[str] = None # "runtime", "throughput", "concurrency"
def clear_directory(directory_path: str | pathlib.Path) -> None:
@ -176,7 +177,7 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
logging.error("Failed to find performance improvement message")
return False
improvement_match = re.search(r"📈 ([\d,]+)% improvement", stdout)
improvement_match = re.search(r"📈 ([\d,]+)% (?:(\w+) )?improvement", stdout)
if not improvement_match:
logging.error("Could not find improvement percentage in output")
return False
@ -193,6 +194,15 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
logging.error(f"Performance improvement rate {improvement_x}x not above {config.min_improvement_x}x")
return False
if config.expected_acceptance_reason is not None:
actual_reason = improvement_match.group(2)
if not actual_reason:
logging.error("Could not find acceptance reason type in output")
return False
if actual_reason != config.expected_acceptance_reason:
logging.error(f"Expected acceptance reason '{config.expected_acceptance_reason}', got '{actual_reason}'")
return False
if config.expected_unit_tests_count is not None:
# Match the global test discovery message from optimizer.py which counts test invocations
# Format: "Discovered X existing unit tests and Y replay tests in Z.Zs at /path/to/tests"
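Concretely, the widened pattern makes the reason word before "improvement" optional, so older runtime-only output still matches. A quick self-contained check — the sample log lines below are illustrative of the expected output format, not captured from a real run:
```python
import re

PATTERN = re.compile(r"📈 ([\d,]+)% (?:(\w+) )?improvement")

# With an acceptance-reason word (e.g. concurrency/throughput), group(2) captures it.
m = PATTERN.search("📈 900% concurrency improvement")
assert m.group(1) == "900" and m.group(2) == "concurrency"

# Plain runtime improvements omit the reason word, so group(2) is None.
m = PATTERN.search("📈 25% improvement")
assert m.group(1) == "25" and m.group(2) is None
```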

View file

@ -0,0 +1,304 @@
from __future__ import annotations
import asyncio
import os
import sys
import time
import pytest
from codeflash.code_utils.codeflash_wrap_decorator import codeflash_concurrency_async
from codeflash.models.models import ConcurrencyMetrics, TestResults
from codeflash.verification.parse_test_output import parse_concurrency_metrics
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
class TestConcurrencyAsyncDecorator:
"""Integration tests for codeflash_concurrency_async decorator."""
@pytest.fixture
def concurrency_env_setup(self, request):
"""Set up environment variables for concurrency testing."""
original_env = {}
test_env = {
"CODEFLASH_LOOP_INDEX": "1",
"CODEFLASH_TEST_MODULE": __name__,
"CODEFLASH_TEST_CLASS": "TestConcurrencyAsyncDecorator",
"CODEFLASH_TEST_FUNCTION": request.node.name,
"CODEFLASH_CONCURRENCY_FACTOR": "5", # Use smaller factor for faster tests
}
for key, value in test_env.items():
original_env[key] = os.environ.get(key)
os.environ[key] = value
yield test_env
for key, original_value in original_env.items():
if original_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = original_value
@pytest.mark.asyncio
async def test_concurrency_decorator_nonblocking_function(self, concurrency_env_setup, capsys):
"""Test that non-blocking async functions show high concurrency ratio."""
@codeflash_concurrency_async
async def nonblocking_sleep(duration: float) -> str:
await asyncio.sleep(duration)
return "done"
result = await nonblocking_sleep(0.01)
assert result == "done"
captured = capsys.readouterr()
output = captured.out
# Verify the output format
assert "!@######CONC:" in output
assert "######@!" in output
# Parse the output manually to verify format
lines = [line for line in output.strip().split("\n") if "!@######CONC:" in line]
assert len(lines) == 1
line = lines[0]
# Format: !@######CONC:{test_module}:{test_class}:{test_function}:{function_name}:{loop_index}:{seq_time}:{conc_time}:{factor}######@!
assert "nonblocking_sleep" in line
assert ":5######@!" in line # concurrency factor
# Extract timing values
parts = line.replace("!@######CONC:", "").replace("######@!", "").split(":")
# parts should be: [test_module, test_class, test_function, function_name, loop_index, seq_time, conc_time, factor]
assert len(parts) == 8
seq_time = int(parts[5])
conc_time = int(parts[6])
factor = int(parts[7])
assert seq_time > 0
assert conc_time > 0
assert factor == 5
# For non-blocking async, concurrent time should be much less than sequential
# Sequential runs 5 iterations of 10ms = ~50ms
# Concurrent runs 5 iterations in parallel = ~10ms
# So ratio should be around 5 (with some overhead tolerance)
ratio = seq_time / conc_time if conc_time > 0 else 1.0
assert ratio > 2.0, f"Non-blocking function should have ratio > 2.0, got {ratio}"
@pytest.mark.asyncio
async def test_concurrency_decorator_blocking_function(self, concurrency_env_setup, capsys):
"""Test that blocking functions show low concurrency ratio (~1.0)."""
@codeflash_concurrency_async
async def blocking_sleep(duration: float) -> str:
time.sleep(duration) # Blocking sleep
return "done"
result = await blocking_sleep(0.005) # 5ms blocking
assert result == "done"
captured = capsys.readouterr()
output = captured.out
assert "!@######CONC:" in output
lines = [line for line in output.strip().split("\n") if "!@######CONC:" in line]
assert len(lines) == 1
line = lines[0]
parts = line.replace("!@######CONC:", "").replace("######@!", "").split(":")
assert len(parts) == 8
seq_time = int(parts[5])
conc_time = int(parts[6])
# For blocking code, sequential and concurrent times should be similar
# Because time.sleep blocks the entire event loop
ratio = seq_time / conc_time if conc_time > 0 else 1.0
# Blocking code should have ratio close to 1.0 (within reasonable tolerance)
assert ratio < 2.0, f"Blocking function should have ratio < 2.0, got {ratio}"
@pytest.mark.asyncio
async def test_concurrency_decorator_with_computation(self, concurrency_env_setup, capsys):
"""Test concurrency with CPU-bound computation."""
@codeflash_concurrency_async
async def compute_intensive(n: int) -> int:
# CPU-bound work; it blocks the event loop, so concurrent runs gain nothing
total = 0
for i in range(n):
total += i * i
return total
result = await compute_intensive(10000)
assert result == sum(i * i for i in range(10000))
captured = capsys.readouterr()
output = captured.out
assert "!@######CONC:" in output
assert "compute_intensive" in output
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
class TestParseConcurrencyMetrics:
"""Integration tests for parse_concurrency_metrics function."""
def test_parse_concurrency_metrics_from_real_output(self):
"""Test parsing concurrency metrics from simulated stdout."""
# Simulate stdout from codeflash_concurrency_async decorator
perf_stdout = """Some other output
!@######CONC:test_module:TestClass:test_func:my_async_func:1:50000000:10000000:5######@!
More output here
"""
test_results = TestResults(
test_results=[],
perf_stdout=perf_stdout,
)
metrics = parse_concurrency_metrics(test_results, "my_async_func")
assert metrics is not None
assert isinstance(metrics, ConcurrencyMetrics)
assert metrics.sequential_time_ns == 50000000
assert metrics.concurrent_time_ns == 10000000
assert metrics.concurrency_factor == 5
assert metrics.concurrency_ratio == 5.0 # 50M / 10M = 5.0
def test_parse_concurrency_metrics_multiple_entries(self):
"""Test parsing when multiple concurrency entries exist."""
perf_stdout = """!@######CONC:test_module:TestClass:test_func:target_func:1:40000000:10000000:5######@!
!@######CONC:test_module:TestClass:test_func:target_func:2:60000000:10000000:5######@!
!@######CONC:test_module:TestClass:test_func:other_func:1:30000000:15000000:5######@!
"""
test_results = TestResults(
test_results=[],
perf_stdout=perf_stdout,
)
metrics = parse_concurrency_metrics(test_results, "target_func")
assert metrics is not None
# Should average the two entries for target_func
# (40M + 60M) / 2 = 50M seq, (10M + 10M) / 2 = 10M conc
assert metrics.sequential_time_ns == 50000000
assert metrics.concurrent_time_ns == 10000000
assert metrics.concurrency_ratio == 5.0
def test_parse_concurrency_metrics_no_match(self):
"""Test parsing when function name doesn't match."""
perf_stdout = """!@######CONC:test_module:TestClass:test_func:other_func:1:50000000:10000000:5######@!
"""
test_results = TestResults(
test_results=[],
perf_stdout=perf_stdout,
)
metrics = parse_concurrency_metrics(test_results, "nonexistent_func")
assert metrics is None
def test_parse_concurrency_metrics_empty_stdout(self):
"""Test parsing with empty stdout."""
test_results = TestResults(
test_results=[],
perf_stdout="",
)
metrics = parse_concurrency_metrics(test_results, "any_func")
assert metrics is None
def test_parse_concurrency_metrics_none_stdout(self):
"""Test parsing with None stdout."""
test_results = TestResults(
test_results=[],
perf_stdout=None,
)
metrics = parse_concurrency_metrics(test_results, "any_func")
assert metrics is None
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
class TestConcurrencyRatioComparison:
"""Test comparing blocking vs non-blocking concurrency ratios."""
@pytest.fixture
def comparison_env_setup(self, request):
"""Set up environment variables for comparison testing."""
original_env = {}
test_env = {
"CODEFLASH_LOOP_INDEX": "1",
"CODEFLASH_TEST_MODULE": __name__,
"CODEFLASH_TEST_CLASS": "TestConcurrencyRatioComparison",
"CODEFLASH_TEST_FUNCTION": request.node.name,
"CODEFLASH_CONCURRENCY_FACTOR": "10",
}
for key, value in test_env.items():
original_env[key] = os.environ.get(key)
os.environ[key] = value
yield test_env
for key, original_value in original_env.items():
if original_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = original_value
@pytest.mark.asyncio
async def test_blocking_vs_nonblocking_comparison(self, comparison_env_setup, capsys):
"""Compare concurrency ratios between blocking and non-blocking implementations."""
@codeflash_concurrency_async
async def blocking_impl() -> str:
time.sleep(0.002) # 2ms blocking
return "blocking"
@codeflash_concurrency_async
async def nonblocking_impl() -> str:
await asyncio.sleep(0.002) # 2ms non-blocking
return "nonblocking"
# Run blocking version
await blocking_impl()
blocking_output = capsys.readouterr().out
# Run non-blocking version
await nonblocking_impl()
nonblocking_output = capsys.readouterr().out
# Parse blocking metrics
blocking_line = [line for line in blocking_output.split("\n") if "!@######CONC:" in line][0]
blocking_parts = blocking_line.replace("!@######CONC:", "").replace("######@!", "").split(":")
blocking_seq = int(blocking_parts[5])
blocking_conc = int(blocking_parts[6])
blocking_ratio = blocking_seq / blocking_conc if blocking_conc > 0 else 1.0
# Parse non-blocking metrics
nonblocking_line = [line for line in nonblocking_output.split("\n") if "!@######CONC:" in line][0]
nonblocking_parts = nonblocking_line.replace("!@######CONC:", "").replace("######@!", "").split(":")
nonblocking_seq = int(nonblocking_parts[5])
nonblocking_conc = int(nonblocking_parts[6])
nonblocking_ratio = nonblocking_seq / nonblocking_conc if nonblocking_conc > 0 else 1.0
# Non-blocking should have significantly higher concurrency ratio
assert nonblocking_ratio > blocking_ratio, (
f"Non-blocking ratio ({nonblocking_ratio:.2f}) should be greater than "
f"blocking ratio ({blocking_ratio:.2f})"
)
# The difference should be substantial (non-blocking should be at least 2x better)
ratio_improvement = nonblocking_ratio / blocking_ratio if blocking_ratio > 0 else 0
assert ratio_improvement > 2.0, (
f"Non-blocking should show >2x improvement in concurrency ratio, got {ratio_improvement:.2f}x"
)
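The assertions above pin down the measurement model: run the wrapped coroutine `factor` times sequentially, then `factor` times under `asyncio.gather`, and report both durations plus the factor in the CONC marker. A simplified sketch inferred from these tests — not the real `codeflash_concurrency_async`, which also reads the `CODEFLASH_*` environment variables and handles test identification:
```python
import asyncio
import time


def concurrency_sketch(factor: int = 5):
    """Simplified model: time `factor` sequential awaits vs. `factor` gathered awaits."""

    def wrap(fn):
        async def inner(*args, **kwargs):
            start = time.perf_counter_ns()
            for _ in range(factor):  # sequential pass
                result = await fn(*args, **kwargs)
            seq_ns = time.perf_counter_ns() - start

            start = time.perf_counter_ns()
            # Concurrent pass: blocking bodies serialize here, non-blocking ones overlap.
            await asyncio.gather(*(fn(*args, **kwargs) for _ in range(factor)))
            conc_ns = time.perf_counter_ns() - start

            # Same marker format the tests parse; the module/class/test fields are placeholders.
            print(f"!@######CONC:mod:Cls:test:{fn.__name__}:1:{seq_ns}:{conc_ns}:{factor}######@!")
            return result

        return inner

    return wrap
```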

File diff suppressed because it is too large Load diff

View file

@ -2119,7 +2119,6 @@ print("Hello world")
expected_code = """import numpy as np
a = 6
if 2<3:
a=4
else:

View file

@ -1602,7 +1602,94 @@ def calculate_portfolio_metrics(
# now the test should match and no diffs should be found
assert len(diffs) == 0
assert matched
finally:
test_path.unlink(missing_ok=True)
fto_file_path.unlink(missing_ok=True)
fto_file_path.unlink(missing_ok=True)
def test_codeflash_capture_with_slots_class() -> None:
"""Test that codeflash_capture works with classes that use __slots__ instead of __dict__."""
test_code = """
from code_to_optimize.tests.pytest.sample_code import SlotsClass
import unittest
def test_slots_class():
obj = SlotsClass(10, "test")
assert obj.x == 10
assert obj.y == "test"
"""
test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
tmp_dir_path = get_run_tmp_file(Path("test_return_values"))
sample_code = f"""
from codeflash.verification.codeflash_capture import codeflash_capture
class SlotsClass:
__slots__ = ('x', 'y')
@codeflash_capture(function_name="SlotsClass.__init__", tmp_dir_path="{tmp_dir_path.as_posix()}", tests_root="{test_dir.as_posix()}")
def __init__(self, x, y):
self.x = x
self.y = y
"""
test_file_name = "test_slots_class_temp.py"
test_path = test_dir / test_file_name
test_path_perf = test_dir / "test_slots_class_temp_perf.py"
tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
project_root_path = (Path(__file__).parent / "..").resolve()
sample_code_path = test_dir / "sample_code.py"
try:
with test_path.open("w") as f:
f.write(test_code)
with sample_code_path.open("w") as f:
f.write(sample_code)
test_env = os.environ.copy()
test_env["CODEFLASH_TEST_ITERATION"] = "0"
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_type = TestType.EXISTING_UNIT_TEST
test_config = TestConfig(
tests_root=tests_root,
tests_project_rootdir=project_root_path,
project_root_path=project_root_path,
test_framework="pytest",
pytest_cmd="pytest",
)
fto = FunctionToOptimize(
function_name="__init__",
file_path=sample_code_path,
parents=[FunctionParent(name="SlotsClass", type="ClassDef")],
)
func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
func_optimizer.test_files = TestFiles(
test_files=[
TestFile(
instrumented_behavior_file_path=test_path,
test_type=test_type,
original_file_path=test_path,
benchmarking_file_path=test_path_perf,
)
]
)
test_results, coverage_data = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=func_optimizer.test_files,
optimization_iteration=0,
pytest_min_loops=1,
pytest_max_loops=1,
testing_time=0.1,
)
# Test should pass and capture the slots values
assert len(test_results) == 1
assert test_results[0].did_pass
# The return value should contain the slot values
assert test_results[0].return_value[0]["x"] == 10
assert test_results[0].return_value[0]["y"] == "test"
finally:
test_path.unlink(missing_ok=True)
sample_code_path.unlink(missing_ok=True)

View file

@ -17,7 +17,14 @@ import pytest
from codeflash.either import Failure, Success
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType
from codeflash.verification.comparator import comparator, _extract_exception_from_message, _get_wrapped_exception
from codeflash.verification.comparator import (
PYTEST_TEMP_PATH_PATTERN,
_extract_exception_from_message,
_get_wrapped_exception,
_is_temp_path,
_normalize_temp_path,
comparator,
)
from codeflash.verification.equivalence import compare_test_results
@ -2309,6 +2316,77 @@ def test_dict_views() -> None:
assert not comparator(d.items(), [("a", 1), ("b", 2)])
def test_mappingproxy() -> None:
"""Test comparator support for types.MappingProxyType (read-only dict view)."""
import types
# Basic equality
mp1 = types.MappingProxyType({"a": 1, "b": 2, "c": 3})
mp2 = types.MappingProxyType({"a": 1, "b": 2, "c": 3})
assert comparator(mp1, mp2)
# Different values
mp3 = types.MappingProxyType({"a": 1, "b": 2, "c": 4})
assert not comparator(mp1, mp3)
# Different keys
mp4 = types.MappingProxyType({"a": 1, "b": 2, "d": 3})
assert not comparator(mp1, mp4)
# Different length
mp5 = types.MappingProxyType({"a": 1, "b": 2})
assert not comparator(mp1, mp5)
# Order doesn't matter (like dict)
mp6 = types.MappingProxyType({"c": 3, "a": 1, "b": 2})
assert comparator(mp1, mp6)
# Empty mappingproxy
empty1 = types.MappingProxyType({})
empty2 = types.MappingProxyType({})
assert comparator(empty1, empty2)
# Nested values
nested1 = types.MappingProxyType({"a": [1, 2, 3], "b": {"x": 1}})
nested2 = types.MappingProxyType({"a": [1, 2, 3], "b": {"x": 1}})
nested3 = types.MappingProxyType({"a": [1, 2, 4], "b": {"x": 1}})
assert comparator(nested1, nested2)
assert not comparator(nested1, nested3)
# mappingproxy is not equal to dict (different types)
d = {"a": 1, "b": 2}
mp = types.MappingProxyType({"a": 1, "b": 2})
assert not comparator(mp, d)
assert not comparator(d, mp)
# Verify class __dict__ is indeed a mappingproxy
class MyClass:
x = 1
y = 2
assert isinstance(MyClass.__dict__, types.MappingProxyType)
def test_mappingproxy_superset() -> None:
"""Test comparator superset_obj support for mappingproxy."""
import types
mp1 = types.MappingProxyType({"a": 1, "b": 2})
mp2 = types.MappingProxyType({"a": 1, "b": 2, "c": 3})
# mp2 is a superset of mp1
assert comparator(mp1, mp2, superset_obj=True)
# mp1 is not a superset of mp2
assert not comparator(mp2, mp1, superset_obj=True)
# Same mappingproxy with superset_obj=True
assert comparator(mp1, mp1, superset_obj=True)
# Different values even with superset
mp3 = types.MappingProxyType({"a": 1, "b": 99, "c": 3})
assert not comparator(mp1, mp3, superset_obj=True)
def test_tensorflow_tensor() -> None:
"""Test comparator support for TensorFlow Tensor objects."""
try:
@ -2911,16 +2989,378 @@ def test_numpy_dtypes() -> None:
assert not comparator(dtypes.Int32DType(), np.dtype('float32'))
def test_numpy_extended_precision_types() -> None:
"""Test comparator for numpy extended precision types like clongdouble."""
try:
import numpy as np
except ImportError:
pytest.skip("numpy not available")
# Test np.clongdouble (extended precision complex)
c1 = np.clongdouble(1 + 2j)
c2 = np.clongdouble(1 + 2j)
c3 = np.clongdouble(1 + 3j)
assert comparator(c1, c2)
assert not comparator(c1, c3)
# Test np.longdouble (extended precision float)
l1 = np.longdouble(1.5)
l2 = np.longdouble(1.5)
l3 = np.longdouble(2.5)
assert comparator(l1, l2)
assert not comparator(l1, l3)
# Test NaN handling for extended precision complex
nan_c1 = np.clongdouble(complex(np.nan, 2))
nan_c2 = np.clongdouble(complex(np.nan, 2))
assert comparator(nan_c1, nan_c2)
# Test NaN handling for extended precision float
nan_l1 = np.longdouble(np.nan)
nan_l2 = np.longdouble(np.nan)
assert comparator(nan_l1, nan_l2)
def test_numpy_typing_types() -> None:
"""Test comparator for numpy.typing types like NDArray type aliases."""
try:
import numpy as np
import numpy.typing as npt
except ImportError:
pytest.skip("numpy or numpy.typing not available")
# Test NDArray type alias comparisons
arr_type1 = npt.NDArray[np.float64]
arr_type2 = npt.NDArray[np.float64]
arr_type3 = npt.NDArray[np.int64]
assert comparator(arr_type1, arr_type2)
assert not comparator(arr_type1, arr_type3)
# Test NBitBase (if it can be instantiated)
try:
nbit1 = npt.NBitBase()
nbit2 = npt.NBitBase()
# NBitBase instances with empty __dict__ should compare as equal
assert comparator(nbit1, nbit2)
# Also test with superset_obj=True
assert comparator(nbit1, nbit2, superset_obj=True)
except TypeError:
# NBitBase may not be instantiable in all numpy versions
pass
def test_numpy_typing_superset_obj() -> None:
"""Test comparator with superset_obj=True for numpy types."""
try:
import numpy as np
import numpy.typing as npt
except ImportError:
pytest.skip("numpy or numpy.typing not available")
# Test numpy arrays with object dtype containing dicts (superset scenario)
a1 = np.array([{'a': 1}], dtype=object)
a2 = np.array([{'a': 1, 'b': 2}], dtype=object) # superset
assert comparator(a1, a2, superset_obj=True)
assert not comparator(a1, a2, superset_obj=False)
# Test extended precision types with superset_obj=True
c1 = np.clongdouble(1 + 2j)
c2 = np.clongdouble(1 + 2j)
assert comparator(c1, c2, superset_obj=True)
l1 = np.longdouble(1.5)
l2 = np.longdouble(1.5)
assert comparator(l1, l2, superset_obj=True)
# Test NDArray type alias with superset_obj=True
arr_type1 = npt.NDArray[np.float64]
arr_type2 = npt.NDArray[np.float64]
assert comparator(arr_type1, arr_type2, superset_obj=True)
# Test numpy structured arrays (np.void) with superset_obj=True
dt = np.dtype([('name', 'S10'), ('age', np.int32)])
a_struct = np.array([('Alice', 25)], dtype=dt)
b_struct = np.array([('Alice', 25)], dtype=dt)
assert comparator(a_struct[0], b_struct[0], superset_obj=True)
# Test numpy random generators with superset_obj=True
rng1 = np.random.default_rng(seed=42)
rng2 = np.random.default_rng(seed=42)
assert comparator(rng1, rng2, superset_obj=True)
rs1 = np.random.RandomState(seed=42)
rs2 = np.random.RandomState(seed=42)
assert comparator(rs1, rs2, superset_obj=True)
def test_numba_typed_list() -> None:
"""Test comparator for numba.typed.List."""
try:
import numba
from numba.typed import List as NumbaList
except ImportError:
pytest.skip("numba not available")
# Test equal lists
a = NumbaList([1, 2, 3])
b = NumbaList([1, 2, 3])
assert comparator(a, b)
# Test different values
c = NumbaList([1, 2, 4])
assert not comparator(a, c)
# Test different lengths
d = NumbaList([1, 2, 3, 4])
assert not comparator(a, d)
# Test empty lists
e = NumbaList.empty_list(item_type=numba.int64)
f = NumbaList.empty_list(item_type=numba.int64)
assert comparator(e, f)
# Test nested values (floats)
g = NumbaList([1.0, 2.0, 3.0])
h = NumbaList([1.0, 2.0, 3.0])
assert comparator(g, h)
i = NumbaList([1.0, 2.0, 4.0])
assert not comparator(g, i)
def test_numba_typed_dict() -> None:
"""Test comparator for numba.typed.Dict."""
try:
import numba
from numba.typed import Dict as NumbaDict
except ImportError:
pytest.skip("numba not available")
# Test equal dicts
a = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
a["x"] = 1
a["y"] = 2
b = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
b["x"] = 1
b["y"] = 2
assert comparator(a, b)
# Test different values
c = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
c["x"] = 1
c["y"] = 3
assert not comparator(a, c)
# Test different keys
d = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
d["x"] = 1
d["z"] = 2
assert not comparator(a, d)
# Test different lengths
e = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
e["x"] = 1
assert not comparator(a, e)
# Test empty dicts
f = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
g = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
assert comparator(f, g)
def test_numba_types() -> None:
"""Test comparator for numba type objects."""
try:
import numba
from numba import types
except ImportError:
pytest.skip("numba not available")
# Test basic numeric types from numba module
assert comparator(numba.int64, numba.int64)
assert comparator(numba.float64, numba.float64)
assert comparator(numba.int32, numba.int32)
assert comparator(numba.float32, numba.float32)
# Test basic numeric types from numba.types module
assert comparator(types.int64, types.int64)
assert comparator(types.float64, types.float64)
assert comparator(types.int8, types.int8)
assert comparator(types.int16, types.int16)
assert comparator(types.uint8, types.uint8)
assert comparator(types.uint16, types.uint16)
assert comparator(types.uint32, types.uint32)
assert comparator(types.uint64, types.uint64)
assert comparator(types.complex64, types.complex64)
assert comparator(types.complex128, types.complex128)
# Test different types
assert not comparator(numba.int64, numba.float64)
assert not comparator(numba.int32, numba.int64)
assert not comparator(numba.float32, numba.float64)
assert not comparator(types.int8, types.int16)
assert not comparator(types.uint32, types.int32)
assert not comparator(types.complex64, types.complex128)
# Test boolean type
assert comparator(numba.boolean, numba.boolean)
assert comparator(types.boolean, types.boolean)
assert not comparator(numba.boolean, numba.int64)
# Test special types
assert comparator(types.none, types.none)
assert comparator(types.void, types.void)
assert comparator(types.pyobject, types.pyobject)
assert comparator(types.unicode_type, types.unicode_type)
# Note: types.none and types.void are the same object in numba
assert comparator(types.none, types.void)
assert not comparator(types.unicode_type, types.pyobject)
assert not comparator(types.none, types.int64)
# Test array types
arr_type1 = types.Array(numba.float64, 1, 'C')
arr_type2 = types.Array(numba.float64, 1, 'C')
arr_type3 = types.Array(numba.float64, 2, 'C')
arr_type4 = types.Array(numba.int64, 1, 'C')
arr_type5 = types.Array(numba.float64, 1, 'F') # Fortran order
assert comparator(arr_type1, arr_type2)
assert not comparator(arr_type1, arr_type3) # different ndim
assert not comparator(arr_type1, arr_type4) # different dtype
assert not comparator(arr_type1, arr_type5) # different layout
# Test tuple types
tuple_type1 = types.UniTuple(types.int64, 3)
tuple_type2 = types.UniTuple(types.int64, 3)
tuple_type3 = types.UniTuple(types.int64, 4)
tuple_type4 = types.UniTuple(types.float64, 3)
assert comparator(tuple_type1, tuple_type2)
assert not comparator(tuple_type1, tuple_type3) # different count
assert not comparator(tuple_type1, tuple_type4) # different dtype
# Test heterogeneous tuple types
hetero_tuple1 = types.Tuple([types.int64, types.float64])
hetero_tuple2 = types.Tuple([types.int64, types.float64])
hetero_tuple3 = types.Tuple([types.int64, types.int64])
assert comparator(hetero_tuple1, hetero_tuple2)
assert not comparator(hetero_tuple1, hetero_tuple3)
# Test ListType and DictType
list_type1 = types.ListType(types.int64)
list_type2 = types.ListType(types.int64)
list_type3 = types.ListType(types.float64)
assert comparator(list_type1, list_type2)
assert not comparator(list_type1, list_type3)
dict_type1 = types.DictType(types.unicode_type, types.int64)
dict_type2 = types.DictType(types.unicode_type, types.int64)
dict_type3 = types.DictType(types.unicode_type, types.float64)
dict_type4 = types.DictType(types.int64, types.int64)
assert comparator(dict_type1, dict_type2)
assert not comparator(dict_type1, dict_type3) # different value type
assert not comparator(dict_type1, dict_type4) # different key type
def test_numba_jit_functions() -> None:
"""Test comparator for numba JIT-compiled functions."""
try:
from numba import jit
except ImportError:
pytest.skip("numba not available")
@jit(nopython=True)
def add(x, y):
return x + y
@jit(nopython=True)
def add2(x, y):
return x + y
@jit(nopython=True)
def multiply(x, y):
return x * y
# Compile the functions by calling them
add(1, 2)
add2(1, 2)
multiply(1, 2)
# Same function should compare equal to itself
assert comparator(add, add)
# Different functions (even with same code) should not compare equal
# since they are distinct function objects
assert not comparator(add, add2)
# Different functions with different code should not compare equal
assert not comparator(add, multiply)
def test_numba_superset_obj() -> None:
"""Test comparator for numba types with superset_obj=True."""
try:
import numba
from numba.typed import Dict as NumbaDict
from numba.typed import List as NumbaList
except ImportError:
pytest.skip("numba not available")
# Test NumbaDict with superset_obj=True
orig_dict = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
orig_dict["x"] = 1
orig_dict["y"] = 2
# New dict with same keys - should pass
new_dict_same = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
new_dict_same["x"] = 1
new_dict_same["y"] = 2
assert comparator(orig_dict, new_dict_same, superset_obj=True)
# New dict with extra keys - should pass with superset_obj=True
new_dict_superset = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
new_dict_superset["x"] = 1
new_dict_superset["y"] = 2
new_dict_superset["z"] = 3
assert comparator(orig_dict, new_dict_superset, superset_obj=True)
# But should fail with superset_obj=False
assert not comparator(orig_dict, new_dict_superset, superset_obj=False)
# New dict missing keys - should fail even with superset_obj=True
new_dict_subset = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
new_dict_subset["x"] = 1
assert not comparator(orig_dict, new_dict_subset, superset_obj=True)
# New dict with different values - should fail
new_dict_diff = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
new_dict_diff["x"] = 1
new_dict_diff["y"] = 99
assert not comparator(orig_dict, new_dict_diff, superset_obj=True)
# Test NumbaList with superset_obj=True (lists don't support superset semantics)
orig_list = NumbaList([1, 2, 3])
new_list_same = NumbaList([1, 2, 3])
new_list_longer = NumbaList([1, 2, 3, 4])
assert comparator(orig_list, new_list_same, superset_obj=True)
# Lists must have same length regardless of superset_obj
assert not comparator(orig_list, new_list_longer, superset_obj=True)
# Test empty dict with superset_obj=True
empty_orig = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
non_empty_new = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
non_empty_new["a"] = 1
# Empty orig should match any superset
assert comparator(empty_orig, non_empty_new, superset_obj=True)
assert not comparator(empty_orig, non_empty_new, superset_obj=False)
# =============================================================================
# Tests for pytest temp path normalization (lines 28-69 in comparator.py)
# =============================================================================
from codeflash.verification.comparator import (
PYTEST_TEMP_PATH_PATTERN,
_is_temp_path,
_normalize_temp_path,
)
class TestIsTempPath:
"""Tests for the _is_temp_path() function."""

View file

@ -5,6 +5,7 @@ from unittest.mock import Mock
from codeflash.code_utils.env_utils import get_pr_number
from codeflash.models.models import (
CodeOptimizationContext,
ConcurrencyMetrics,
CoverageData,
CoverageStatus,
FunctionCoverage,
@ -15,12 +16,14 @@ from codeflash.models.models import (
TestType,
)
from codeflash.result.critic import (
concurrency_gain,
coverage_critic,
performance_gain,
quantity_of_tests_critic,
speedup_critic,
throughput_gain,
)
from codeflash.verification.parse_test_output import parse_concurrency_metrics
def test_performance_gain() -> None:
@ -569,3 +572,238 @@ def test_speedup_critic_with_async_throughput() -> None:
best_throughput_until_now=None,
disable_gh_action_noise=True
)
def test_concurrency_gain() -> None:
"""Test concurrency_gain calculation."""
# Test basic concurrency improvement (blocking -> non-blocking)
original = ConcurrencyMetrics(
sequential_time_ns=10_000_000, # 10ms
concurrent_time_ns=10_000_000, # 10ms (no speedup - blocking)
concurrency_factor=10,
concurrency_ratio=1.0, # sequential/concurrent = 1.0
)
optimized = ConcurrencyMetrics(
sequential_time_ns=10_000_000, # 10ms
concurrent_time_ns=1_000_000, # 1ms (10x speedup - non-blocking)
concurrency_factor=10,
concurrency_ratio=10.0, # sequential/concurrent = 10.0
)
# 900% improvement: (10 - 1) / 1 = 9.0
assert concurrency_gain(original, optimized) == 9.0
# Test no improvement
same = ConcurrencyMetrics(
sequential_time_ns=10_000_000,
concurrent_time_ns=10_000_000,
concurrency_factor=10,
concurrency_ratio=1.0,
)
assert concurrency_gain(original, same) == 0.0
# Test slight improvement
slightly_better = ConcurrencyMetrics(
sequential_time_ns=10_000_000,
concurrent_time_ns=8_000_000,
concurrency_factor=10,
concurrency_ratio=1.25,
)
# 25% improvement: (1.25 - 1.0) / 1.0 = 0.25
assert concurrency_gain(original, slightly_better) == 0.25
# Test zero original ratio (edge case)
zero_ratio = ConcurrencyMetrics(
sequential_time_ns=0,
concurrent_time_ns=1_000_000,
concurrency_factor=10,
concurrency_ratio=0.0,
)
assert concurrency_gain(zero_ratio, optimized) == 0.0
def test_speedup_critic_with_concurrency_metrics() -> None:
"""Test speedup_critic with concurrency metrics evaluation."""
original_code_runtime = 10000 # 10 microseconds
original_async_throughput = 100
# Original concurrency metrics (blocking code - ratio ~= 1.0)
original_concurrency = ConcurrencyMetrics(
sequential_time_ns=10_000_000,
concurrent_time_ns=10_000_000,
concurrency_factor=10,
concurrency_ratio=1.0,
)
# Test case 1: Concurrency improves significantly (blocking -> non-blocking)
candidate_result = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=10000, # Same runtime
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=10000,
async_throughput=100, # Same throughput
concurrency_metrics=ConcurrencyMetrics(
sequential_time_ns=10_000_000,
concurrent_time_ns=1_000_000, # 10x faster concurrent execution
concurrency_factor=10,
concurrency_ratio=10.0, # 900% improvement
),
)
# Should pass due to concurrency improvement even though runtime/throughput unchanged
assert speedup_critic(
candidate_result=candidate_result,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
original_concurrency_metrics=original_concurrency,
best_concurrency_ratio_until_now=None,
disable_gh_action_noise=True,
)
# Test case 2: No concurrency improvement (should fall back to other metrics)
candidate_result_no_conc = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=8000, # 20% runtime improvement
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=8000,
async_throughput=100,
concurrency_metrics=ConcurrencyMetrics(
sequential_time_ns=10_000_000,
concurrent_time_ns=10_000_000,
concurrency_factor=10,
concurrency_ratio=1.0, # No improvement
),
)
# Should pass due to runtime improvement
assert speedup_critic(
candidate_result=candidate_result_no_conc,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
original_concurrency_metrics=original_concurrency,
best_concurrency_ratio_until_now=None,
disable_gh_action_noise=True,
)
# Test case 3: Concurrency below threshold (20% required)
candidate_result_below_threshold = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=10000, # Same runtime
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=10000,
async_throughput=100, # Same throughput
concurrency_metrics=ConcurrencyMetrics(
sequential_time_ns=10_000_000,
concurrent_time_ns=9_000_000, # Only 11% improvement
concurrency_factor=10,
concurrency_ratio=1.11,
),
)
# Should fail - no metric improves enough
assert not speedup_critic(
candidate_result=candidate_result_below_threshold,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
original_concurrency_metrics=original_concurrency,
best_concurrency_ratio_until_now=None,
disable_gh_action_noise=True,
)
# Test case 4: best_concurrency_ratio_until_now comparison
candidate_result_good = OptimizedCandidateResult(
max_loop_count=5,
best_test_runtime=10000,
behavior_test_results=TestResults(),
benchmarking_test_results=TestResults(),
optimization_candidate_index=0,
total_candidate_timing=10000,
async_throughput=100,
concurrency_metrics=ConcurrencyMetrics(
sequential_time_ns=10_000_000,
concurrent_time_ns=2_000_000,
concurrency_factor=10,
concurrency_ratio=5.0,
),
)
# Should fail when there's a better concurrency ratio already
assert not speedup_critic(
candidate_result=candidate_result_good,
original_code_runtime=original_code_runtime,
best_runtime_until_now=None,
original_async_throughput=original_async_throughput,
best_throughput_until_now=None,
original_concurrency_metrics=original_concurrency,
best_concurrency_ratio_until_now=10.0, # Better ratio already exists
disable_gh_action_noise=True,
)
def test_concurrency_ratio_display_formatting() -> None:
orig_ratio = 0.05
cand_ratio = 0.15
conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
display_string = f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)"
assert display_string == "Concurrency ratio: 0.05x → 0.15x (+200.0%)"
orig_ratio = 1.0
cand_ratio = 10.0
conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
display_string = f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)"
assert display_string == "Concurrency ratio: 1.00x → 10.00x (+900.0%)"
orig_ratio = 0.01
cand_ratio = 0.03
conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
display_string = f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)"
assert display_string == "Concurrency ratio: 0.01x → 0.03x (+200.0%)"
def test_parse_concurrency_metrics() -> None:
"""Test parse_concurrency_metrics function."""
# Test with valid concurrency output
stdout = (
"!@######CONC:test_module:TestClass:test_func:my_function:0:10000000:1000000:10######@!\n"
"!@######CONC:test_module:TestClass:test_func:my_function:1:10000000:1000000:10######@!\n"
)
test_results = TestResults(perf_stdout=stdout)
metrics = parse_concurrency_metrics(test_results, "my_function")
assert metrics is not None
assert metrics.sequential_time_ns == 10_000_000 # Average of both matches
assert metrics.concurrent_time_ns == 1_000_000
assert metrics.concurrency_factor == 10
assert metrics.concurrency_ratio == 10.0 # 10000000 / 1000000
# Test with no matching function
metrics_wrong_func = parse_concurrency_metrics(test_results, "other_function")
assert metrics_wrong_func is None
# Test with empty stdout
empty_results = TestResults(perf_stdout="")
metrics_empty = parse_concurrency_metrics(empty_results, "my_function")
assert metrics_empty is None
# Test with None stdout
none_results = TestResults(perf_stdout=None)
metrics_none = parse_concurrency_metrics(none_results, "my_function")
assert metrics_none is None
# Test with no class name
stdout_no_class = "!@######CONC:test_module::test_func:my_function:0:5000000:2500000:10######@!\n"
test_results_no_class = TestResults(perf_stdout=stdout_no_class)
metrics_no_class = parse_concurrency_metrics(test_results_no_class, "my_function")
assert metrics_no_class is not None
assert metrics_no_class.concurrency_ratio == 2.0 # 5000000 / 2500000

View file

@ -218,8 +218,28 @@ def test_no_targets_found() -> None:
def target(self):
pass
"""
result = parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
expected = dedent("""
class MyClass:
def method(self):
pass
class Inner:
def target(self):
pass
""")
assert result.strip() == expected.strip()
def test_no_targets_found_raises_for_nonexistent() -> None:
"""Test that ValueError is raised when the target function doesn't exist at all."""
code = """
class MyClass:
def method(self):
pass
"""
with pytest.raises(ValueError, match="No target functions found in the provided code"):
parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
parse_code_and_prune_cst(dedent(code),CodeContextType.READ_WRITABLE, {"NonExistent.target"})
def test_module_var() -> None:

View file

@ -131,6 +131,46 @@ async def async_function(x: int, y: int) -> int:
assert modified_code.strip() == expected_decorated_code.strip()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_decorator_application_concurrency_mode(temp_dir):
"""Test that CONCURRENCY mode applies the codeflash_concurrency_async decorator."""
async_function_code = '''
import asyncio
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
expected_decorated_code = '''
import asyncio
from codeflash.code_utils.codeflash_wrap_decorator import \\
codeflash_concurrency_async
@codeflash_concurrency_async
async def async_function(x: int, y: int) -> int:
"""Simple async function for testing."""
await asyncio.sleep(0.01)
return x * y
'''
test_file = temp_dir / "test_async.py"
test_file.write_text(async_function_code)
func = FunctionToOptimize(
function_name="async_function", file_path=test_file, parents=[], is_async=True
)
decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.CONCURRENCY)
assert decorator_added
modified_code = test_file.read_text()
assert modified_code.strip() == expected_decorated_code.strip()
@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_class_method_decorator_application(temp_dir):
async_class_code = '''

View file

@ -124,7 +124,7 @@ def _get_string_usage(text: str) -> Usage:
helper_file.unlink(missing_ok=True)
main_file.unlink(missing_ok=True)
expected_helper = """import re
from collections.abc import Sequence

View file

@ -481,3 +481,86 @@ def unused_function():
qualified_functions = {"get_platform_info", "get_loop_result"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
assert result.strip() == expected.strip()
def test_enum_attribute_access_dependency() -> None:
"""Test that enum/class attribute access like MessageKind.VALUE is tracked as a dependency."""
code = """
from enum import Enum
class MessageKind(Enum):
VALUE = "value"
OTHER = "other"
class UnusedEnum(Enum):
UNUSED = "unused"
UNUSED_VAR = 123
def process_message(kind):
match kind:
case MessageKind.VALUE:
return "got value"
case MessageKind.OTHER:
return "got other"
return "unknown"
"""
expected = """
from enum import Enum
class MessageKind(Enum):
VALUE = "value"
OTHER = "other"
class UnusedEnum(Enum):
UNUSED = "unused"
def process_message(kind):
match kind:
case MessageKind.VALUE:
return "got value"
case MessageKind.OTHER:
return "got other"
return "unknown"
"""
qualified_functions = {"process_message"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
# MessageKind should be preserved because process_message uses MessageKind.VALUE
assert "class MessageKind" in result
# UNUSED_VAR should be removed
assert "UNUSED_VAR" not in result
assert result.strip() == expected.strip()
def test_attribute_access_does_not_track_attr_name() -> None:
"""Test that self.x attribute access doesn't track 'x' as a dependency on module-level x."""
code = """
x = "module_level_x"
UNUSED_VAR = "unused"
class MyClass:
def __init__(self):
self.x = 1 # This 'x' is an attribute, not a reference to module-level 'x'
def get_x(self):
return self.x # This 'x' is also an attribute access
"""
expected = """
class MyClass:
def __init__(self):
self.x = 1 # This 'x' is an attribute, not a reference to module-level 'x'
def get_x(self):
return self.x # This 'x' is also an attribute access
"""
qualified_functions = {"MyClass.get_x", "MyClass.__init__"}
result = remove_unused_definitions_by_function_names(code, qualified_functions)
# Module-level x should NOT be kept (self.x doesn't reference it)
assert 'x = "module_level_x"' not in result
# UNUSED_VAR should also be removed
assert "UNUSED_VAR" not in result
assert result.strip() == expected.strip()

11
uv.lock
View file

@ -925,7 +925,7 @@ name = "exceptiongroup"
version = "1.3.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
{ name = "typing-extensions", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
wheels = [
@ -5473,11 +5473,14 @@ wheels = [
[[package]]
name = "wheel"
version = "0.45.1"
version = "0.46.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8a/98/2d9906746cdc6a6ef809ae6338005b3f21bb568bea3165cfc6a243fdc25c/wheel-0.45.1.tar.gz", hash = "sha256:661e1abd9198507b1409a20c02106d9670b2576e916d58f520316666abca6729", size = 107545, upload-time = "2024-11-23T00:18:23.513Z" }
dependencies = [
{ name = "packaging" },
]
sdist = { url = "https://files.pythonhosted.org/packages/9b/92/3a64fa9639b8e290fe8630d8067a66f7c5510845c6d73686ad880c9b04d9/wheel-0.46.2.tar.gz", hash = "sha256:3d79e48fde9847618a5a181f3cc35764c349c752e2fe911e65fa17faab9809b0", size = 60274, upload-time = "2026-01-21T23:55:25.838Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/0b/2c/87f3254fd8ffd29e4c02732eee68a83a1d3c346ae39bc6822dcbcb697f2b/wheel-0.45.1-py3-none-any.whl", hash = "sha256:708e7481cc80179af0e556bbf0cc00b8444c7321e2700b8d8580231d13017248", size = 72494, upload-time = "2024-11-23T00:18:21.207Z" },
{ url = "https://files.pythonhosted.org/packages/13/2c/5e079cefe955ae58e5a052fe037c850ce493eb7269dedeb960237e78fb0f/wheel-0.46.2-py3-none-any.whl", hash = "sha256:33ae60725d69eaa249bc1982e739943c23b34b58d51f1cb6253453773aca6e65", size = 29971, upload-time = "2026-01-21T23:55:24.447Z" },
]
[[package]]