Mirror of https://github.com/codeflash-ai/codeflash.git, synced 2026-05-04 18:25:17 +00:00.

commit 2dcfba6949: Merge branch 'main' into jit-docs
46 changed files with 4160 additions and 1657 deletions
.github/workflows/mypy.yml

@@ -22,11 +22,10 @@ jobs:
      - name: Install uv
        uses: astral-sh/setup-uv@v6
        with:
          version: "0.5.30"

      - name: sync uv
        run: |
          uv venv --seed
          uv sync
.gitignore

@@ -259,3 +259,6 @@ WARP.MD
 .mcp.json
 .tessl/
 tessl.json
+
+# Tessl auto-generates AGENTS.md on install; ignore to avoid cluttering git status
+AGENTS.md
@@ -88,6 +88,10 @@ else:
- Commit message body should be concise (1-2 sentences max)
- PR titles should also use conventional format

<!-- Section below is auto-generated by `tessl install` - do not edit manually -->

# Agent Rules <!-- tessl-managed -->

@.tessl/RULES.md follow the [instructions](.tessl/RULES.md)

@AGENTS.md
@@ -3,7 +3,7 @@ Business Source License 1.1
 Parameters
 
 Licensor: CodeFlash Inc.
-Licensed Work: Codeflash Client version 0.19.x
+Licensed Work: Codeflash Client version 0.20.x
 The Licensed Work is (c) 2024 CodeFlash Inc.
 
 Additional Use Grant: None. Production use of the Licensed Work is only permitted
 
@@ -13,7 +13,7 @@ Additional Use Grant: None. Production use of the Licensed Work is only permitte
 Platform. Please visit codeflash.ai for further
 information.
 
-Change Date: 2029-12-21
+Change Date: 2030-01-26
 
 Change License: MIT
@@ -224,7 +224,7 @@ class AiServiceClient:
         logger.info("!lsp|Rewriting as a JIT function…")
         console.rule()
         try:
-            response = self.make_ai_service_request("/rewrite_jit", payload=payload, timeout=60)
+            response = self.make_ai_service_request("/rewrite_jit", payload=payload, timeout=self.timeout)
         except requests.exceptions.RequestException as e:
             logger.exception(f"Error generating jit rewritten candidate: {e}")
             ph("cli-jit-rewrite-error-caught", {"error": str(e)})

@@ -460,6 +460,10 @@ class AiServiceClient:
         optimized_throughput: str | None = None,
         throughput_improvement: str | None = None,
         function_references: str | None = None,
+        acceptance_reason: str | None = None,
+        original_concurrency_ratio: str | None = None,
+        optimized_concurrency_ratio: str | None = None,
+        concurrency_improvement: str | None = None,
         codeflash_version: str = codeflash_version,
     ) -> str:
         """Optimize the given python code for performance by making a request to the Django endpoint.

@@ -480,8 +484,12 @@ class AiServiceClient:
         - original_throughput: str | None - throughput for the baseline code (operations per second)
         - optimized_throughput: str | None - throughput for the optimized code (operations per second)
         - throughput_improvement: str | None - throughput improvement percentage
-        - current codeflash version
         - function_references: str | None - where the function is called in the codebase
+        - acceptance_reason: str | None - why the optimization was accepted (runtime, throughput, or concurrency)
+        - original_concurrency_ratio: str | None - concurrency ratio for the baseline code
+        - optimized_concurrency_ratio: str | None - concurrency ratio for the optimized code
+        - concurrency_improvement: str | None - concurrency improvement percentage
+        - codeflash_version: str - current codeflash version
 
         Returns
         -------

@@ -505,6 +513,10 @@ class AiServiceClient:
             "optimized_throughput": optimized_throughput,
             "throughput_improvement": throughput_improvement,
             "function_references": function_references,
+            "acceptance_reason": acceptance_reason,
+            "original_concurrency_ratio": original_concurrency_ratio,
+            "optimized_concurrency_ratio": optimized_concurrency_ratio,
+            "concurrency_improvement": concurrency_improvement,
             "codeflash_version": codeflash_version,
             "call_sequence": self.get_next_sequence(),
         }
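Note: the JIT rewrite request now uses the client-wide timeout instead of a hardcoded 60 seconds. A minimal sketch of the pattern (ApiClient and its fields are hypothetical stand-ins, not the real class in this file):

    import requests

    class ApiClient:
        def __init__(self, base_url: str, timeout: float = 600.0) -> None:
            # One timeout knob for every endpoint, instead of per-call literals
            self.base_url = base_url
            self.timeout = timeout

        def post(self, endpoint: str, payload: dict) -> requests.Response:
            return requests.post(f"{self.base_url}{endpoint}", json=payload, timeout=self.timeout)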
@@ -84,6 +84,9 @@ def parse_args() -> Namespace:
     parser.add_argument(
         "--no-gen-tests", action="store_true", help="Do not generate tests, use only existing tests for optimization."
     )
+    parser.add_argument(
+        "--no-jit-opts", action="store_true", help="Do not generate JIT-compiled optimizations for numerical code."
+    )
     parser.add_argument("--staging-review", action="store_true", help="Upload optimizations to staging for review")
     parser.add_argument(
         "--verify-setup",
@@ -25,12 +25,117 @@ if TYPE_CHECKING:
     from codeflash.models.models import FunctionSource
 
 
+class GlobalFunctionCollector(cst.CSTVisitor):
+    """Collects all module-level function definitions (not inside classes or other functions)."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.functions: dict[str, cst.FunctionDef] = {}
+        self.function_order: list[str] = []
+        self.scope_depth = 0
+
+    def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
+        if self.scope_depth == 0:
+            # Module-level function
+            name = node.name.value
+            self.functions[name] = node
+            if name not in self.function_order:
+                self.function_order.append(name)
+        self.scope_depth += 1
+        return True
+
+    def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:  # noqa: ARG002
+        self.scope_depth -= 1
+
+    def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:  # noqa: ARG002
+        self.scope_depth += 1
+        return True
+
+    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:  # noqa: ARG002
+        self.scope_depth -= 1
+
+
+class GlobalFunctionTransformer(cst.CSTTransformer):
+    """Transforms/adds module-level functions from the new file to the original file."""
+
+    def __init__(self, new_functions: dict[str, cst.FunctionDef], new_function_order: list[str]) -> None:
+        super().__init__()
+        self.new_functions = new_functions
+        self.new_function_order = new_function_order
+        self.processed_functions: set[str] = set()
+        self.scope_depth = 0
+
+    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:  # noqa: ARG002
+        self.scope_depth += 1
+
+    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
+        self.scope_depth -= 1
+        if self.scope_depth > 0:
+            return updated_node
+
+        # Check if this is a module-level function we need to replace
+        name = original_node.name.value
+        if name in self.new_functions:
+            self.processed_functions.add(name)
+            return self.new_functions[name]
+        return updated_node
+
+    def visit_ClassDef(self, node: cst.ClassDef) -> None:  # noqa: ARG002
+        self.scope_depth += 1
+
+    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:  # noqa: ARG002
+        self.scope_depth -= 1
+        return updated_node
+
+    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:  # noqa: ARG002
+        # Add any new functions that weren't in the original file
+        new_statements = list(updated_node.body)
+
+        functions_to_append = [
+            self.new_functions[name]
+            for name in self.new_function_order
+            if name not in self.processed_functions and name in self.new_functions
+        ]
+
+        if functions_to_append:
+            # Find the position of the last function or class definition
+            insert_index = find_insertion_index_after_imports(updated_node)
+            for i, stmt in enumerate(new_statements):
+                if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)):
+                    insert_index = i + 1
+
+            # Add empty line before each new function
+            function_nodes = []
+            for func in functions_to_append:
+                func_with_empty_line = func.with_changes(leading_lines=[cst.EmptyLine(), *func.leading_lines])
+                function_nodes.append(func_with_empty_line)
+
+            new_statements = list(chain(new_statements[:insert_index], function_nodes, new_statements[insert_index:]))
+
+        return updated_node.with_changes(body=new_statements)
+
+
+def collect_referenced_names(node: cst.CSTNode) -> set[str]:
+    """Collect all names referenced in a CST node using recursive traversal."""
+    names: set[str] = set()
+
+    def _collect(n: cst.CSTNode) -> None:
+        if isinstance(n, cst.Name):
+            names.add(n.value)
+        # Recursively process all children
+        for child in n.children:
+            _collect(child)
+
+    _collect(node)
+    return names
+
+
 class GlobalAssignmentCollector(cst.CSTVisitor):
     """Collects all global assignment statements."""
 
     def __init__(self) -> None:
         super().__init__()
-        self.assignments: dict[str, cst.Assign] = {}
+        self.assignments: dict[str, cst.Assign | cst.AnnAssign] = {}
         self.assignment_order: list[str] = []
         # Track scope depth to identify global assignments
         self.scope_depth = 0
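How the new collector/transformer pair composes, as a minimal sketch (assumes libcst is installed and that GlobalFunctionCollector, GlobalFunctionTransformer, and find_insertion_index_after_imports from this file are in scope):

    import libcst as cst

    src_module = cst.parse_module(
        "def helper():\n    return 1\n\ndef shared():\n    return 'new'\n"
    )
    dst_module = cst.parse_module("def shared():\n    return 'old'\n")

    collector = GlobalFunctionCollector()
    src_module.visit(collector)

    transformer = GlobalFunctionTransformer(collector.functions, collector.function_order)
    merged = dst_module.visit(transformer)
    print(merged.code)  # 'shared' is replaced in place; 'helper' is appended after it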
@@ -72,6 +177,21 @@ class GlobalAssignmentCollector(cst.CSTVisitor):
                 self.assignment_order.append(name)
         return True
 
+    def visit_AnnAssign(self, node: cst.AnnAssign) -> Optional[bool]:
+        # Handle annotated assignments like: _CACHE: Dict[str, int] = {}
+        # Only process module-level annotated assignments with a value
+        if (
+            self.scope_depth == 0
+            and self.if_else_depth == 0
+            and isinstance(node.target, cst.Name)
+            and node.value is not None
+        ):
+            name = node.target.value
+            self.assignments[name] = node
+            if name not in self.assignment_order:
+                self.assignment_order.append(name)
+        return True
+
 
 def find_insertion_index_after_imports(node: cst.Module) -> int:
     """Find the position of the last import statement in the top-level of the module."""
@@ -103,7 +223,7 @@ def find_insertion_index_after_imports(node: cst.Module) -> int:
 class GlobalAssignmentTransformer(cst.CSTTransformer):
     """Transforms global assignments in the original file with those from the new file."""
 
-    def __init__(self, new_assignments: dict[str, cst.Assign], new_assignment_order: list[str]) -> None:
+    def __init__(self, new_assignments: dict[str, cst.Assign | cst.AnnAssign], new_assignment_order: list[str]) -> None:
         super().__init__()
         self.new_assignments = new_assignments
         self.new_assignment_order = new_assignment_order
@@ -150,38 +270,120 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
 
         return updated_node
 
+    def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.CSTNode:
+        if self.scope_depth > 0 or self.if_else_depth > 0:
+            return updated_node
+
+        # Check if this is a global annotated assignment we need to replace
+        if isinstance(original_node.target, cst.Name):
+            name = original_node.target.value
+            if name in self.new_assignments:
+                self.processed_assignments.add(name)
+                return self.new_assignments[name]
+
+        return updated_node
+
     def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:  # noqa: ARG002
         # Add any new assignments that weren't in the original file
         new_statements = list(updated_node.body)
 
         # Find assignments to append
         assignments_to_append = [
-            self.new_assignments[name]
+            (name, self.new_assignments[name])
             for name in self.new_assignment_order
             if name not in self.processed_assignments and name in self.new_assignments
         ]
 
-        if assignments_to_append:
-            # after last top-level imports
+        if not assignments_to_append:
+            return updated_node.with_changes(body=new_statements)
+
+        # Collect all class and function names defined in the module
+        # These are the names that assignments might reference
+        module_defined_names: set[str] = set()
+        for stmt in new_statements:
+            if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
+                module_defined_names.add(stmt.name.value)
+
+        # Partition assignments: those that reference module definitions go at the end,
+        # those that don't can go right after imports
+        assignments_after_imports: list[tuple[str, cst.Assign | cst.AnnAssign]] = []
+        assignments_after_definitions: list[tuple[str, cst.Assign | cst.AnnAssign]] = []
+
+        for name, assignment in assignments_to_append:
+            # Get the value being assigned
+            if isinstance(assignment, (cst.Assign, cst.AnnAssign)) and assignment.value is not None:
+                value_node = assignment.value
+            else:
+                # No value to analyze, safe to place after imports
+                assignments_after_imports.append((name, assignment))
+                continue
+
+            # Collect names referenced in the assignment value
+            referenced_names = collect_referenced_names(value_node)
+
+            # Check if any referenced names are module-level definitions
+            if referenced_names & module_defined_names:
+                # This assignment references a class/function, place it after definitions
+                assignments_after_definitions.append((name, assignment))
+            else:
+                # Safe to place right after imports
+                assignments_after_imports.append((name, assignment))
+
+        # Insert assignments that don't depend on module definitions right after imports
+        if assignments_after_imports:
             insert_index = find_insertion_index_after_imports(updated_node)
             assignment_lines = [
                 cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
+                for _, assignment in assignments_after_imports
+            ]
+            new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
+
+        # Insert assignments that depend on module definitions after all class/function definitions
+        if assignments_after_definitions:
+            # Find the position after the last function or class definition
+            insert_index = find_insertion_index_after_imports(cst.Module(body=new_statements))
+            for i, stmt in enumerate(new_statements):
+                if isinstance(stmt, (cst.FunctionDef, cst.ClassDef)):
+                    insert_index = i + 1
+
+            assignment_lines = [
+                cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
-                for assignment in assignments_to_append
+                for _, assignment in assignments_after_definitions
             ]
 
             new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
 
+            # Add a blank line after the last assignment if needed
+            after_index = insert_index + len(assignment_lines)
+            if after_index < len(new_statements):
+                next_stmt = new_statements[after_index]
+                # If there's no empty line, add one
+                has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines)
+                if not has_empty:
+                    new_statements[after_index] = next_stmt.with_changes(
+                        leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines]
+                    )
         return updated_node.with_changes(body=new_statements)
+
+
+class GlobalStatementTransformer(cst.CSTTransformer):
+    """Transformer that appends global statements at the end of the module.
+
+    This ensures that global statements (like function calls at module level) are placed
+    after all functions, classes, and assignments they might reference, preventing NameError
+    at module load time.
+
+    This transformer should be run LAST after GlobalFunctionTransformer and
+    GlobalAssignmentTransformer have already added their content.
+    """
+
+    def __init__(self, global_statements: list[cst.SimpleStatementLine]) -> None:
+        super().__init__()
+        self.global_statements = global_statements
+
+    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:  # noqa: ARG002
+        if not self.global_statements:
+            return updated_node
+
+        new_statements = list(updated_node.body)
+
+        # Add empty line before each statement for readability
+        statement_lines = [
+            stmt.with_changes(leading_lines=[cst.EmptyLine(), *stmt.leading_lines]) for stmt in self.global_statements
+        ]
+
+        # Append statements at the end of the module
+        # This ensures they come after all functions, classes, and assignments
+        new_statements.extend(statement_lines)
+
+        return updated_node.with_changes(body=new_statements)
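The placement decision above hinges on collect_referenced_names. An illustrative check (assumes libcst and the helper defined earlier in this file):

    import libcst as cst

    value = cst.parse_expression("Factory(config, DEFAULT_SIZE)")
    print(collect_referenced_names(value))  # {'Factory', 'config', 'DEFAULT_SIZE'}
    # If 'Factory' is a class defined in the module, an assignment using this value
    # is placed after all definitions; otherwise it goes right after the imports.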
@@ -213,8 +415,8 @@ class GlobalStatementCollector(cst.CSTVisitor):
     def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
         if not self.in_function_or_class:
             for statement in node.body:
-                # Skip imports
-                if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign)):
+                # Skip imports and assignments (both regular and annotated)
+                if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign, cst.AnnAssign)):
                     self.global_statements.append(node)
                     break
@@ -309,40 +511,6 @@ class DottedImportCollector(cst.CSTVisitor):
         self._collect_imports_from_block(node.body)
 
 
-class ImportInserter(cst.CSTTransformer):
-    """Transformer that inserts global statements after the last import."""
-
-    def __init__(self, global_statements: list[cst.SimpleStatementLine], last_import_line: int) -> None:
-        super().__init__()
-        self.global_statements = global_statements
-        self.last_import_line = last_import_line
-        self.current_line = 0
-        self.inserted = False
-
-    def leave_SimpleStatementLine(
-        self,
-        original_node: cst.SimpleStatementLine,  # noqa: ARG002
-        updated_node: cst.SimpleStatementLine,
-    ) -> cst.Module:
-        self.current_line += 1
-
-        # If we're right after the last import and haven't inserted yet
-        if self.current_line == self.last_import_line and not self.inserted:
-            self.inserted = True
-            return cst.Module(body=[updated_node, *self.global_statements])
-
-        return cst.Module(body=[updated_node])
-
-    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:  # noqa: ARG002
-        # If there were no imports, add at the beginning of the module
-        if self.last_import_line == 0 and not self.inserted:
-            updated_body = list(updated_node.body)
-            for stmt in reversed(self.global_statements):
-                updated_body.insert(0, stmt)
-            return updated_node.with_changes(body=updated_body)
-        return updated_node
-
-
 def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]:
     """Extract global statements from source code."""
     module = cst.parse_module(source_code)
@@ -394,34 +562,58 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
             continue
         unique_global_statements.append(stmt)
 
-    mod_dst_code = dst_module_code
-    # Insert unique global statements if any
-    if unique_global_statements:
-        last_import_line = find_last_import_line(dst_module_code)
-        transformer = ImportInserter(unique_global_statements, last_import_line)
-        # Use visit inplace, don't parse again
-        modified_module = dst_module.visit(transformer)
-        mod_dst_code = modified_module.code
-        # Parse the code after insertion
-        original_module = cst.parse_module(mod_dst_code)
-    else:
-        # No new statements to insert, reuse already-parsed dst_module
-        original_module = dst_module
+    # Reuse already-parsed dst_module
+    original_module = dst_module
 
-    # Parse the src_module_code once only (already done above: src_module)
     # Collect assignments from the new file
-    new_collector = GlobalAssignmentCollector()
-    src_module.visit(new_collector)
-    # Only create transformer if there are assignments to insert/transform
-    if not new_collector.assignments:  # nothing to transform
-        return mod_dst_code
+    new_assignment_collector = GlobalAssignmentCollector()
+    src_module.visit(new_assignment_collector)
 
-    # Transform the original destination module
-    transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
-    transformed_module = original_module.visit(transformer)
+    # Collect module-level functions from both source and destination
+    src_function_collector = GlobalFunctionCollector()
+    src_module.visit(src_function_collector)
 
-    return transformed_module.code
+    dst_function_collector = GlobalFunctionCollector()
+    original_module.visit(dst_function_collector)
+
+    # Filter out functions that already exist in the destination (only add truly new functions)
+    new_functions = {
+        name: func
+        for name, func in src_function_collector.functions.items()
+        if name not in dst_function_collector.functions
+    }
+    new_function_order = [name for name in src_function_collector.function_order if name in new_functions]
+
+    # If there are no assignments, no new functions, and no global statements, return unchanged
+    if not new_assignment_collector.assignments and not new_functions and not unique_global_statements:
+        return dst_module_code
+
+    # The order of transformations matters:
+    # 1. Functions first - so assignments and statements can reference them
+    # 2. Assignments second - so they come after functions but before statements
+    # 3. Global statements last - so they can reference both functions and assignments
+
+    # Transform functions if any
+    if new_functions:
+        function_transformer = GlobalFunctionTransformer(new_functions, new_function_order)
+        original_module = original_module.visit(function_transformer)
+
+    # Transform assignments if any
+    if new_assignment_collector.assignments:
+        transformer = GlobalAssignmentTransformer(
+            new_assignment_collector.assignments, new_assignment_collector.assignment_order
+        )
+        original_module = original_module.visit(transformer)
+
+    # Insert global statements (like function calls at module level) LAST,
+    # after all functions and assignments are added, to ensure they can reference any
+    # functions or variables defined in the module
+    if unique_global_statements:
+        statement_transformer = GlobalStatementTransformer(unique_global_statements)
+        original_module = original_module.visit(statement_transformer)
+
+    return original_module.code
 
 
 def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
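End-to-end, the rewritten add_global_assignments merges functions first, then assignments, then global statements. A hedged usage sketch (assumes the function and its collaborators from this file are importable):

    src = (
        "import math\n\n"
        "FACTOR = 2\n\n"
        "def scale(x):\n"
        "    return x * FACTOR\n"
    )
    dst = "import math\n\ndef existing(y):\n    return math.sqrt(y)\n"

    merged = add_global_assignments(src, dst)
    # 'scale' is appended after 'existing', while FACTOR lands right after the
    # imports because its value (2) references no module-level definition.
    print(merged)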
@@ -4,6 +4,7 @@ import asyncio
 import gc
+import os
 import sqlite3
 import time
 from enum import Enum
 from functools import wraps
 from pathlib import Path
@@ -165,3 +166,45 @@ def codeflash_performance_async(func: F) -> F:
         return return_value
 
     return async_wrapper
 
 
+def codeflash_concurrency_async(func: F) -> F:
+    """Measures concurrent vs sequential execution performance for async functions."""
+
+    @wraps(func)
+    async def async_wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
+        function_name = func.__name__
+        concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10"))
+
+        test_module_name = os.environ.get("CODEFLASH_TEST_MODULE", "")
+        test_class_name = os.environ.get("CODEFLASH_TEST_CLASS", "")
+        test_function = os.environ.get("CODEFLASH_TEST_FUNCTION", "")
+        loop_index = os.environ.get("CODEFLASH_LOOP_INDEX", "0")
+
+        # Phase 1: Sequential execution timing
+        gc.disable()
+        try:
+            seq_start = time.perf_counter_ns()
+            for _ in range(concurrency_factor):
+                result = await func(*args, **kwargs)
+            sequential_time = time.perf_counter_ns() - seq_start
+        finally:
+            gc.enable()
+
+        # Phase 2: Concurrent execution timing
+        gc.disable()
+        try:
+            conc_start = time.perf_counter_ns()
+            tasks = [func(*args, **kwargs) for _ in range(concurrency_factor)]
+            await asyncio.gather(*tasks)
+            concurrent_time = time.perf_counter_ns() - conc_start
+        finally:
+            gc.enable()
+
+        # Output parseable metrics
+        tag = f"{test_module_name}:{test_class_name}:{test_function}:{function_name}:{loop_index}"
+        print(f"!@######CONC:{tag}:{sequential_time}:{concurrent_time}:{concurrency_factor}######@!")
+
+        return result
+
+    return async_wrapper
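What the new decorator measures, as an illustrative sketch (assumes codeflash_concurrency_async is importable and the CODEFLASH_* environment variables are unset, so defaults apply):

    import asyncio

    @codeflash_concurrency_async
    async def fetch(delay: float = 0.01) -> str:
        await asyncio.sleep(delay)
        return "done"

    print(asyncio.run(fetch()))
    # Besides returning "done", the wrapper prints one parseable line, roughly:
    #   !@######CONC::::fetch:0:<sequential_ns>:<concurrent_ns>:10######@!
    # A concurrency-friendly function should show concurrent_ns well below sequential_ns.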
@@ -10,12 +10,20 @@ import sentry_sdk
 
 from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_temp_dir
 
+# Known CrossHair limitations that produce invalid Python syntax in generated tests:
+# - "<locals>" - higher-order functions returning nested functions
+# - " object at 0x" - objects with default __repr__
+# - "<list_iterator" - iterator objects
+CROSSHAIR_KNOWN_LIMITATION_PATTERNS = ("<locals>", " object at 0x", "<list_iterator")
+
 
 def is_valid_concolic_test(test_code: str, project_root: Optional[str] = None) -> bool:
     try:
         ast.parse(test_code)
     except SyntaxError:
-        sentry_sdk.capture_message(f"CrossHair generated test with syntax error:\n{test_code}")
+        is_known_limitation = any(pattern in test_code for pattern in CROSSHAIR_KNOWN_LIMITATION_PATTERNS)
+        if not is_known_limitation:
+            sentry_sdk.capture_message(f"CrossHair generated test with syntax error:\n{test_code}")
         return False
 
     temp_path = (codeflash_temp_dir / f"concolic_test_{uuid.uuid4().hex}.py").resolve()
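The filter in action, as a small illustration (assumes CROSSHAIR_KNOWN_LIMITATION_PATTERNS from above):

    # A syntactically invalid test that trips a known CrossHair limitation:
    bad = "assert roundtrip(<function outer.<locals>.inner at 0x7f3a2c0d5e50>) == 42"
    known = any(pattern in bad for pattern in CROSSHAIR_KNOWN_LIMITATION_PATTERNS)
    print(known)  # True -> the syntax error is expected, so no Sentry report is sent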
@@ -10,6 +10,8 @@ INDIVIDUAL_TESTCASE_TIMEOUT = 15
 MAX_FUNCTION_TEST_SECONDS = 60
 MIN_IMPROVEMENT_THRESHOLD = 0.05
 MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10  # 10% minimum improvement for async throughput
+MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD = 0.20  # 20% concurrency ratio improvement required
+CONCURRENCY_FACTOR = 10  # Number of concurrent executions for concurrency benchmark
 MAX_TEST_FUNCTION_RUNS = 50
 MAX_CUMULATIVE_TEST_RUNTIME_NANOSECONDS = 100e6  # 100ms
 TOTAL_LOOPING_TIME = 10.0  # 10 second candidate benchmarking budget
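A hypothetical gate built on the new constants; the actual acceptance logic lives elsewhere in the codebase, so treat the ratio math here as an assumption:

    def concurrency_gate(original_ratio: float, optimized_ratio: float) -> bool:
        # Assumed semantics: ratio = sequential_time / concurrent_time, so higher
        # is better, and the candidate must improve the ratio by at least 20%.
        if original_ratio <= 0:
            return False
        improvement = (optimized_ratio - original_ratio) / original_ratio
        return improvement >= MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD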
@@ -1439,9 +1439,12 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
         self.added_decorator = False
 
         # Choose decorator based on mode
-        self.decorator_name = (
-            "codeflash_behavior_async" if mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
-        )
+        if mode == TestingMode.BEHAVIOR:
+            self.decorator_name = "codeflash_behavior_async"
+        elif mode == TestingMode.CONCURRENCY:
+            self.decorator_name = "codeflash_concurrency_async"
+        else:
+            self.decorator_name = "codeflash_performance_async"
 
     def visit_ClassDef(self, node: cst.ClassDef) -> None:
         # Track when we enter a class

@@ -1484,12 +1487,14 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
                 "codeflash_trace_async",
                 "codeflash_behavior_async",
                 "codeflash_performance_async",
+                "codeflash_concurrency_async",
             }
         if isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name):
             return decorator_node.func.value in {
                 "codeflash_trace_async",
                 "codeflash_behavior_async",
                 "codeflash_performance_async",
+                "codeflash_concurrency_async",
             }
         return False

@@ -1501,6 +1506,14 @@ class AsyncDecoratorImportAdder(cst.CSTTransformer):
         self.mode = mode
         self.has_import = False
 
+    def _get_decorator_name(self) -> str:
+        """Get the decorator name based on the testing mode."""
+        if self.mode == TestingMode.BEHAVIOR:
+            return "codeflash_behavior_async"
+        if self.mode == TestingMode.CONCURRENCY:
+            return "codeflash_concurrency_async"
+        return "codeflash_performance_async"
+
     def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
         # Check if the async decorator import is already present
         if (

@@ -1512,9 +1525,7 @@ class AsyncDecoratorImportAdder(cst.CSTTransformer):
             and node.module.attr.value == "codeflash_wrap_decorator"
             and not isinstance(node.names, cst.ImportStar)
         ):
-            decorator_name = (
-                "codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
-            )
+            decorator_name = self._get_decorator_name()
             for import_alias in node.names:
                 if import_alias.name.value == decorator_name:
                     self.has_import = True

@@ -1525,9 +1536,7 @@ class AsyncDecoratorImportAdder(cst.CSTTransformer):
             return updated_node
 
         # Choose import based on mode
-        decorator_name = (
-            "codeflash_behavior_async" if self.mode == TestingMode.BEHAVIOR else "codeflash_performance_async"
-        )
+        decorator_name = self._get_decorator_name()
 
         # Parse the import statement into a CST node
         import_node = cst.parse_statement(f"from codeflash.code_utils.codeflash_wrap_decorator import {decorator_name}")
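A quick sanity check of the three-way mode mapping (assumes TestingMode exposes BEHAVIOR, CONCURRENCY, and PERFORMANCE members, and that the import adder's constructor takes the mode as shown above):

    for mode, expected in [
        (TestingMode.BEHAVIOR, "codeflash_behavior_async"),
        (TestingMode.CONCURRENCY, "codeflash_concurrency_async"),
        (TestingMode.PERFORMANCE, "codeflash_performance_async"),
    ]:
        assert AsyncDecoratorImportAdder(mode)._get_decorator_name() == expected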
@@ -5,6 +5,7 @@ import hashlib
 import os
 from collections import defaultdict
 from itertools import chain
 from pathlib import Path
 from typing import TYPE_CHECKING, cast
 
 import libcst as cst

@@ -16,6 +17,7 @@ from codeflash.code_utils.config_consts import OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
 from codeflash.context.unused_definition_remover import (
     collect_top_level_defs_with_usages,
     extract_names_from_targets,
     get_section_names,
     remove_unused_definitions_by_function_names,
 )
 from codeflash.discovery.functions_to_optimize import FunctionToOptimize  # noqa: TC001
@@ -29,14 +31,44 @@ from codeflash.models.models import (
 from codeflash.optimization.function_context import belongs_to_function_qualified
 
 if TYPE_CHECKING:
     from pathlib import Path
 
     from jedi.api.classes import Name
     from libcst import CSTNode
 
     from codeflash.context.unused_definition_remover import UsageInfo
 
 
+def build_testgen_context(
+    helpers_of_fto_dict: dict[Path, set[FunctionSource]],
+    helpers_of_helpers_dict: dict[Path, set[FunctionSource]],
+    project_root_path: Path,
+    remove_docstrings: bool,  # noqa: FBT001
+    include_imported_classes: bool,  # noqa: FBT001
+) -> CodeStringsMarkdown:
+    """Build testgen context with optional imported class definitions and external base inits."""
+    testgen_context = extract_code_markdown_context_from_files(
+        helpers_of_fto_dict,
+        helpers_of_helpers_dict,
+        project_root_path,
+        remove_docstrings=remove_docstrings,
+        code_context_type=CodeContextType.TESTGEN,
+    )
+
+    if include_imported_classes:
+        imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
+        if imported_class_context.code_strings:
+            testgen_context = CodeStringsMarkdown(
+                code_strings=testgen_context.code_strings + imported_class_context.code_strings
+            )
+
+        external_base_inits = get_external_base_class_inits(testgen_context, project_root_path)
+        if external_base_inits.code_strings:
+            testgen_context = CodeStringsMarkdown(
+                code_strings=testgen_context.code_strings + external_base_inits.code_strings
+            )
+
+    return testgen_context
+
+
 def get_code_optimization_context(
     function_to_optimize: FunctionToOptimize,
     project_root_path: Path,
@@ -120,55 +152,37 @@ def get_code_optimization_context(
         logger.debug("Code context has exceeded token limit, removing read-only code")
         read_only_context_code = ""
 
-    # Extract code context for testgen
-    testgen_context = extract_code_markdown_context_from_files(
+    # Extract code context for testgen with progressive fallback for token limits
+    # Try in order: full context -> remove docstrings -> remove imported classes
+    testgen_context = build_testgen_context(
         helpers_of_fto_dict,
         helpers_of_helpers_dict,
         project_root_path,
         remove_docstrings=False,
-        code_context_type=CodeContextType.TESTGEN,
+        include_imported_classes=True,
     )
 
-    # Extract class definitions for imported types from project modules
-    # This helps the LLM understand class constructors and structure
-    imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
-    if imported_class_context.code_strings:
-        # Merge imported class definitions into testgen context
-        testgen_context = CodeStringsMarkdown(
-            code_strings=testgen_context.code_strings + imported_class_context.code_strings
-        )
-
-    testgen_markdown_code = testgen_context.markdown
-    testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
-    if testgen_code_token_length > testgen_token_limit:
-        # First try removing docstrings
-        testgen_context = extract_code_markdown_context_from_files(
+    if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
+        logger.debug("Testgen context exceeded token limit, removing docstrings")
+        testgen_context = build_testgen_context(
             helpers_of_fto_dict,
             helpers_of_helpers_dict,
             project_root_path,
             remove_docstrings=True,
-            code_context_type=CodeContextType.TESTGEN,
+            include_imported_classes=True,
         )
-        # Re-extract imported classes (they may still fit)
-        imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
-        if imported_class_context.code_strings:
-            testgen_context = CodeStringsMarkdown(
-                code_strings=testgen_context.code_strings + imported_class_context.code_strings
-            )
-        testgen_markdown_code = testgen_context.markdown
-        testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
-        if testgen_code_token_length > testgen_token_limit:
-            # If still over limit, try without imported class definitions
-            testgen_context = extract_code_markdown_context_from_files(
+
+    if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
+        logger.debug("Testgen context still exceeded token limit, removing imported class definitions")
+        testgen_context = build_testgen_context(
             helpers_of_fto_dict,
             helpers_of_helpers_dict,
             project_root_path,
             remove_docstrings=True,
-            code_context_type=CodeContextType.TESTGEN,
+            include_imported_classes=False,
         )
-        testgen_markdown_code = testgen_context.markdown
-        testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
-        if testgen_code_token_length > testgen_token_limit:
+
+    if encoded_tokens_len(testgen_context.markdown) > testgen_token_limit:
         raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
     code_hash_context = hashing_code_context.markdown
     code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
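The token-limit handling above follows a progressive-degradation ladder. A generic restatement under assumed names (fit_within_budget is hypothetical, not part of this diff):

    def fit_within_budget(build, token_budget, attempts, measure):
        """Try successively cheaper build configurations until one fits the budget."""
        for kwargs in attempts:
            context = build(**kwargs)
            if measure(context) <= token_budget:
                return context
        raise ValueError("context exceeds token budget even after all fallbacks")

    # The diff's ladder, expressed as attempts:
    #   {"remove_docstrings": False, "include_imported_classes": True}
    #   {"remove_docstrings": True,  "include_imported_classes": True}
    #   {"remove_docstrings": True,  "include_imported_classes": False}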
@@ -184,114 +198,6 @@ def get_code_optimization_context(
     )
 
 
-def extract_code_string_context_from_files(
-    helpers_of_fto: dict[Path, set[FunctionSource]],
-    helpers_of_helpers: dict[Path, set[FunctionSource]],
-    project_root_path: Path,
-    remove_docstrings: bool = False,  # noqa: FBT001, FBT002
-    code_context_type: CodeContextType = CodeContextType.READ_ONLY,
-) -> CodeString:
-    """Extract code context from files containing target functions and their helpers.
-
-    This function processes two sets of files:
-    1. Files containing the function to optimize (fto) and their first-degree helpers
-    2. Files containing only helpers of helpers (with no overlap with the first set).
-
-    For each file, it extracts relevant code based on the specified context type, adds necessary
-    imports, and combines them.
-
-    Args:
-    ----
-        helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers
-        helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions
-        project_root_path: Root path of the project
-        remove_docstrings: Whether to remove docstrings from the extracted code
-        code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN)
-
-    Returns:
-    -------
-        CodeString containing the extracted code context with necessary imports
-
-    """  # noqa: D205
-    # Rearrange to remove overlaps, so we only access each file path once
-    helpers_of_helpers_no_overlap = defaultdict(set)
-    for file_path, function_sources in helpers_of_helpers.items():
-        if file_path in helpers_of_fto:
-            # Remove duplicates within the same file path, in case a helper of helper is also a helper of fto
-            helpers_of_helpers[file_path] -= helpers_of_fto[file_path]
-        else:
-            helpers_of_helpers_no_overlap[file_path] = function_sources
-
-    final_code_string_context = ""
-
-    # Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files
-    for file_path, function_sources in helpers_of_fto.items():
-        try:
-            original_code = file_path.read_text("utf8")
-        except Exception as e:
-            logger.exception(f"Error while parsing {file_path}: {e}")
-            continue
-        try:
-            qualified_function_names = {func.qualified_name for func in function_sources}
-            helpers_of_helpers_qualified_names = {
-                func.qualified_name for func in helpers_of_helpers.get(file_path, set())
-            }
-            code_without_unused_defs = remove_unused_definitions_by_function_names(
-                original_code, qualified_function_names | helpers_of_helpers_qualified_names
-            )
-            code_context = parse_code_and_prune_cst(
-                code_without_unused_defs,
-                code_context_type,
-                qualified_function_names,
-                helpers_of_helpers_qualified_names,
-                remove_docstrings,
-            )
-        except ValueError as e:
-            logger.debug(f"Error while getting read-only code: {e}")
-            continue
-        if code_context.strip():
-            final_code_string_context += f"\n{code_context}"
-            final_code_string_context = add_needed_imports_from_module(
-                src_module_code=original_code,
-                dst_module_code=final_code_string_context,
-                src_path=file_path,
-                dst_path=file_path,
-                project_root=project_root_path,
-                helper_functions=list(helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())),
-            )
-    if code_context_type == CodeContextType.READ_WRITABLE:
-        return CodeString(code=final_code_string_context)
-    # Extract code from file paths containing helpers of helpers
-    for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items():
-        try:
-            original_code = file_path.read_text("utf8")
-        except Exception as e:
-            logger.exception(f"Error while parsing {file_path}: {e}")
-            continue
-        try:
-            qualified_helper_function_names = {func.qualified_name for func in helper_function_sources}
-            code_without_unused_defs = remove_unused_definitions_by_function_names(
-                original_code, qualified_helper_function_names
-            )
-            code_context = parse_code_and_prune_cst(
-                code_without_unused_defs, code_context_type, set(), qualified_helper_function_names, remove_docstrings
-            )
-        except ValueError as e:
-            logger.debug(f"Error while getting read-only code: {e}")
-            continue
-
-        if code_context.strip():
-            final_code_string_context += f"\n{code_context}"
-            final_code_string_context = add_needed_imports_from_module(
-                src_module_code=original_code,
-                dst_module_code=final_code_string_context,
-                src_path=file_path,
-                dst_path=file_path,
-                project_root=project_root_path,
-                helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())),
-            )
-    return CodeString(code=final_code_string_context)
-
-
 def extract_code_markdown_context_from_files(
     helpers_of_fto: dict[Path, set[FunctionSource]],
     helpers_of_helpers: dict[Path, set[FunctionSource]],
@@ -526,6 +432,10 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
     the LLM understand the actual class structure (constructors, methods, inheritance)
     rather than just seeing import statements.
 
+    Also recursively extracts base classes when a class inherits from another class
+    in the same module, ensuring the full inheritance chain is available for
+    understanding constructor signatures.
+
     Args:
         code_context: The already extracted code context containing imports
        project_root_path: Root path of the project
@@ -568,6 +478,68 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
 
     class_code_strings: list[CodeString] = []
 
+    module_cache: dict[Path, tuple[str, ast.Module]] = {}
+
+    def get_module_source_and_tree(module_path: Path) -> tuple[str, ast.Module] | None:
+        if module_path in module_cache:
+            return module_cache[module_path]
+        try:
+            module_source = module_path.read_text(encoding="utf-8")
+            module_tree = ast.parse(module_source)
+        except Exception:
+            return None
+        else:
+            module_cache[module_path] = (module_source, module_tree)
+            return module_source, module_tree
+
+    def extract_class_and_bases(
+        class_name: str, module_path: Path, module_source: str, module_tree: ast.Module
+    ) -> None:
+        """Extract a class and its base classes recursively from the same module."""
+        # Skip if already extracted
+        if (module_path, class_name) in extracted_classes:
+            return
+
+        # Find the class definition in the module
+        class_node = None
+        for node in ast.walk(module_tree):
+            if isinstance(node, ast.ClassDef) and node.name == class_name:
+                class_node = node
+                break
+
+        if class_node is None:
+            return
+
+        # First, recursively extract base classes from the same module
+        for base in class_node.bases:
+            base_name = None
+            if isinstance(base, ast.Name):
+                base_name = base.id
+            elif isinstance(base, ast.Attribute):
+                # For module.ClassName, we skip (cross-module inheritance)
+                continue
+
+            if base_name and base_name not in existing_definitions:
+                # Check if base class is defined in the same module
+                extract_class_and_bases(base_name, module_path, module_source, module_tree)
+
+        # Now extract this class (after its bases, so base classes appear first)
+        if (module_path, class_name) in extracted_classes:
+            return  # Already added by another path
+
+        lines = module_source.split("\n")
+        start_line = class_node.lineno
+        if class_node.decorator_list:
+            start_line = min(d.lineno for d in class_node.decorator_list)
+        class_source = "\n".join(lines[start_line - 1 : class_node.end_lineno])
+
+        # Extract imports for the class
+        class_imports = extract_imports_for_class(module_tree, class_node, module_source)
+        full_source = class_imports + "\n\n" + class_source if class_imports else class_source
+
+        class_code_strings.append(CodeString(code=full_source, file_path=module_path))
+        extracted_classes.add((module_path, class_name))
+
     for name, module_name in imported_names.items():
         # Skip if already defined in context
         if name in existing_definitions:
@@ -593,28 +565,14 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
             if path_belongs_to_site_packages(module_path):
                 continue
 
-            # Skip if we've already extracted this class
-            if (module_path, name) in extracted_classes:
+            # Get module source and tree
+            result = get_module_source_and_tree(module_path)
+            if result is None:
                 continue
+            module_source, module_tree = result
 
-            # Parse the module to find the class definition
-            module_source = module_path.read_text(encoding="utf-8")
-            module_tree = ast.parse(module_source)
-
-            for node in ast.walk(module_tree):
-                if isinstance(node, ast.ClassDef) and node.name == name:
-                    # Extract the class source code
-                    lines = module_source.split("\n")
-                    class_source = "\n".join(lines[node.lineno - 1 : node.end_lineno])
-
-                    # Also extract any necessary imports for the class (base classes, type hints)
-                    class_imports = _extract_imports_for_class(module_tree, node, module_source)
-
-                    full_source = class_imports + "\n\n" + class_source if class_imports else class_source
-
-                    class_code_strings.append(CodeString(code=full_source, file_path=module_path))
-                    extracted_classes.add((module_path, name))
-                    break
+            # Extract the class and its base classes
+            extract_class_and_bases(name, module_path, module_source, module_tree)
 
         except Exception:
             logger.debug(f"Error extracting class definition for {name} from {module_name}")
@@ -623,10 +581,111 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
     return CodeStringsMarkdown(code_strings=class_code_strings)
 
 
-def _extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
+def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown:
+    """Extract __init__ methods from external library base classes.
+
+    Scans the code context for classes that inherit from external libraries and extracts
+    just their __init__ methods. This helps the LLM understand constructor signatures
+    for mocking or instantiation.
+    """
+    import importlib
+    import inspect
+    import textwrap
+
+    all_code = "\n".join(cs.code for cs in code_context.code_strings)
+
+    try:
+        tree = ast.parse(all_code)
+    except SyntaxError:
+        return CodeStringsMarkdown(code_strings=[])
+
+    imported_names: dict[str, str] = {}
+    external_bases: list[tuple[str, str]] = []
+    for node in ast.walk(tree):
+        if isinstance(node, ast.ImportFrom) and node.module:
+            for alias in node.names:
+                if alias.name != "*":
+                    imported_name = alias.asname if alias.asname else alias.name
+                    imported_names[imported_name] = node.module
+        elif isinstance(node, ast.ClassDef):
+            for base in node.bases:
+                base_name = None
+                if isinstance(base, ast.Name):
+                    base_name = base.id
+                elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name):
+                    base_name = base.attr
+
+                if base_name and base_name in imported_names:
+                    module_name = imported_names[base_name]
+                    if not _is_project_module(module_name, project_root_path):
+                        external_bases.append((base_name, module_name))
+
+    if not external_bases:
+        return CodeStringsMarkdown(code_strings=[])
+
+    code_strings: list[CodeString] = []
+    extracted: set[tuple[str, str]] = set()
+
+    for base_name, module_name in external_bases:
+        if (module_name, base_name) in extracted:
+            continue
+
+        try:
+            module = importlib.import_module(module_name)
+            base_class = getattr(module, base_name, None)
+            if base_class is None:
+                continue
+
+            init_method = getattr(base_class, "__init__", None)
+            if init_method is None:
+                continue
+
+            try:
+                init_source = inspect.getsource(init_method)
+                init_source = textwrap.dedent(init_source)
+                class_file = Path(inspect.getfile(base_class))
+                parts = class_file.parts
+                if "site-packages" in parts:
+                    idx = parts.index("site-packages")
+                    class_file = Path(*parts[idx + 1 :])
+            except (OSError, TypeError):
+                continue
+
+            class_source = f"class {base_name}:\n" + textwrap.indent(init_source, "    ")
+            code_strings.append(CodeString(code=class_source, file_path=class_file))
+            extracted.add((module_name, base_name))
+
+        except (ImportError, ModuleNotFoundError, AttributeError):
+            logger.debug(f"Failed to extract __init__ for {module_name}.{base_name}")
+            continue
+
+    return CodeStringsMarkdown(code_strings=code_strings)
+
+
+def _is_project_module(module_name: str, project_root_path: Path) -> bool:
+    """Check if a module is part of the project (not external/stdlib)."""
+    import importlib.util
+
+    try:
+        spec = importlib.util.find_spec(module_name)
+    except (ImportError, ModuleNotFoundError, ValueError):
+        return False
+    else:
+        if spec is None or spec.origin is None:
+            return False
+        module_path = Path(spec.origin)
+        # Check if the module is in site-packages (external dependency)
+        # This must be checked first because .venv/site-packages is under project root
+        if path_belongs_to_site_packages(module_path):
+            return False
+        # Check if the module is within the project root
+        return str(module_path).startswith(str(project_root_path) + os.sep)
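A self-contained illustration of the extractor above; UserDict stands in for an external base class, and CodeString/CodeStringsMarkdown are the models this file already imports:

    from pathlib import Path

    snippet = "from collections import UserDict\n\nclass Registry(UserDict):\n    pass\n"
    ctx = CodeStringsMarkdown(code_strings=[CodeString(code=snippet, file_path=Path("registry.py"))])

    inits = get_external_base_class_inits(ctx, Path("/my/project"))
    # 'collections' resolves outside site-packages and outside the project root,
    # so UserDict counts as external and inits holds a stub class whose body is
    # the source of UserDict.__init__ (roughly: def __init__(self, dict=None, /, **kwargs): ...).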
+def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
     """Extract import statements needed for a class definition.
 
-    This extracts imports for base classes and commonly used type annotations.
+    This extracts imports for base classes, decorators, and type annotations.
     """
     needed_names: set[str] = set()
@ -638,35 +697,139 @@ def _extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef
|
|||
# For things like abc.ABC, we need the module name
|
||||
needed_names.add(base.value.id)
|
||||
|
||||
# Get decorator names (e.g., dataclass, field)
|
||||
for decorator in class_node.decorator_list:
|
||||
if isinstance(decorator, ast.Name):
|
||||
needed_names.add(decorator.id)
|
||||
elif isinstance(decorator, ast.Call):
|
||||
if isinstance(decorator.func, ast.Name):
|
||||
needed_names.add(decorator.func.id)
|
||||
elif isinstance(decorator.func, ast.Attribute) and isinstance(decorator.func.value, ast.Name):
|
||||
needed_names.add(decorator.func.value.id)
|
||||
|
||||
# Get type annotation names from class body (for dataclass fields)
|
||||
for item in ast.walk(class_node):
|
||||
if isinstance(item, ast.AnnAssign) and item.annotation:
|
||||
collect_names_from_annotation(item.annotation, needed_names)
|
||||
# Also check for field() calls which are common in dataclasses
|
||||
if isinstance(item, ast.Call) and isinstance(item.func, ast.Name):
|
||||
needed_names.add(item.func.id)
|
||||
|
||||
# Find imports that provide these names
|
||||
import_lines: list[str] = []
|
||||
source_lines = module_source.split("\n")
|
||||
added_imports: set[int] = set() # Track line numbers to avoid duplicates
|
||||
|
||||
for node in module_tree.body:
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name.split(".")[0]
|
||||
if name in needed_names:
|
||||
if name in needed_names and node.lineno not in added_imports:
|
||||
import_lines.append(source_lines[node.lineno - 1])
|
||||
added_imports.add(node.lineno)
|
||||
break
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
if name in needed_names:
|
||||
if name in needed_names and node.lineno not in added_imports:
|
||||
import_lines.append(source_lines[node.lineno - 1])
|
||||
added_imports.add(node.lineno)
|
||||
break
|
||||
|
||||
return "\n".join(import_lines)
|
||||
|
||||
|
||||
def collect_names_from_annotation(node: ast.expr, names: set[str]) -> None:
|
||||
"""Recursively collect type annotation names from an AST node."""
|
||||
if isinstance(node, ast.Name):
|
||||
names.add(node.id)
|
||||
elif isinstance(node, ast.Subscript):
|
||||
collect_names_from_annotation(node.value, names)
|
||||
collect_names_from_annotation(node.slice, names)
|
||||
elif isinstance(node, ast.Tuple):
|
||||
for elt in node.elts:
|
||||
collect_names_from_annotation(elt, names)
|
||||
elif isinstance(node, ast.BinOp): # For Union types with | syntax
|
||||
collect_names_from_annotation(node.left, names)
|
||||
collect_names_from_annotation(node.right, names)
|
||||
elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
|
||||
names.add(node.value.id)
|
||||
|
||||
|
||||
def is_dunder_method(name: str) -> bool:
|
||||
return len(name) > 4 and name.isascii() and name.startswith("__") and name.endswith("__")
|
||||
|
||||
|
||||
def get_section_names(node: cst.CSTNode) -> list[str]:
|
||||
"""Returns the section attribute names (e.g., body, orelse) for a given node if they exist.""" # noqa: D401
|
||||
possible_sections = ["body", "orelse", "finalbody", "handlers"]
|
||||
return [sec for sec in possible_sections if hasattr(node, sec)]
|
||||
class UsedNameCollector(cst.CSTVisitor):
|
||||
"""Collects all base names referenced in code (for import preservation)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.used_names: set[str] = set()
|
||||
self.defined_names: set[str] = set()
|
||||
|
||||
def visit_Name(self, node: cst.Name) -> None:
|
||||
self.used_names.add(node.value)
|
||||
|
||||
def visit_Attribute(self, node: cst.Attribute) -> bool | None:
|
||||
base = node.value
|
||||
while isinstance(base, cst.Attribute):
|
||||
base = base.value
|
||||
if isinstance(base, cst.Name):
|
||||
self.used_names.add(base.value)
|
||||
return True
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
|
||||
self.defined_names.add(node.name.value)
|
||||
return True
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
|
||||
self.defined_names.add(node.name.value)
|
||||
return True
|
||||
|
||||
def visit_Assign(self, node: cst.Assign) -> bool | None:
|
||||
for target in node.targets:
|
||||
names = extract_names_from_targets(target.target)
|
||||
self.defined_names.update(names)
|
||||
return True
|
||||
|
||||
def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None:
|
||||
names = extract_names_from_targets(node.target)
|
||||
self.defined_names.update(names)
|
||||
return True
|
||||
|
||||
def get_external_names(self) -> set[str]:
|
||||
return self.used_names - self.defined_names - {"self", "cls"}
|
||||
|
||||
|
||||
def get_imported_names(import_node: cst.Import | cst.ImportFrom) -> set[str]:
    """Extract the names made available by an import statement."""
    names: set[str] = set()
    if isinstance(import_node, cst.Import):
        if isinstance(import_node.names, cst.ImportStar):
            return {"*"}
        for alias in import_node.names:
            if isinstance(alias, cst.ImportAlias):
                if alias.asname and isinstance(alias.asname.name, cst.Name):
                    names.add(alias.asname.name.value)
                elif isinstance(alias.name, cst.Name):
                    names.add(alias.name.value)
                elif isinstance(alias.name, cst.Attribute):
                    # import foo.bar -> accessible as "foo"
                    base: cst.BaseExpression = alias.name
                    while isinstance(base, cst.Attribute):
                        base = base.value
                    if isinstance(base, cst.Name):
                        names.add(base.value)
    elif isinstance(import_node, cst.ImportFrom):
        if isinstance(import_node.names, cst.ImportStar):
            return {"*"}
        for alias in import_node.names:
            if isinstance(alias, cst.ImportAlias):
                if alias.asname and isinstance(alias.asname.name, cst.Name):
                    names.add(alias.asname.name.value)
                elif isinstance(alias.name, cst.Name):
                    names.add(alias.name.value)
    return names
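
Behavior sketch for the two import forms (assumes libcst is available):

import libcst as cst

imp = cst.parse_statement("import foo.bar").body[0]
print(get_imported_names(imp))       # {"foo"} - a dotted import binds its base name

imp_from = cst.parse_statement("from x import y as z").body[0]
print(get_imported_names(imp_from))  # {"z"} - the alias, not the original name
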
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:

@@ -693,12 +856,22 @@ def parse_code_and_prune_cst(
    if code_context_type == CodeContextType.READ_WRITABLE:
        filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions, defs_with_usages)
    elif code_context_type == CodeContextType.READ_ONLY:
        filtered_node, found_target = prune_cst_for_read_only_code(
            module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
        filtered_node, found_target = prune_cst_for_context(
            module,
            target_functions,
            helpers_of_helper_functions,
            remove_docstrings=remove_docstrings,
            include_target_in_output=False,
            include_init_dunder=False,
        )
    elif code_context_type == CodeContextType.TESTGEN:
        filtered_node, found_target = prune_cst_for_testgen_code(
            module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
        filtered_node, found_target = prune_cst_for_context(
            module,
            target_functions,
            helpers_of_helper_functions,
            remove_docstrings=remove_docstrings,
            include_target_in_output=True,
            include_init_dunder=True,
        )
    elif code_context_type == CodeContextType.HASHING:
        filtered_node, found_target = prune_cst_for_code_hashing(module, target_functions)

@@ -740,10 +913,29 @@ def prune_cst_for_read_writable_code(  # noqa: PLR0911
        # Do not recurse into nested classes
        if prefix:
            return None, False

        class_name = node.name.value

        # Assuming always an IndentedBlock
        if not isinstance(node.body, cst.IndentedBlock):
            raise ValueError("ClassDef body is not an IndentedBlock")  # noqa: TRY004
        class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
        class_prefix = f"{prefix}.{class_name}" if prefix else class_name

        # Check if this class contains any target functions
        has_target_functions = any(
            isinstance(stmt, cst.FunctionDef) and f"{class_prefix}.{stmt.name.value}" in target_functions
            for stmt in node.body.body
        )

        # If the class is used as a dependency (not containing target functions), keep it entirely
        # This handles cases like enums, dataclasses, and other types used by the target function
        if (
            not has_target_functions
            and class_name in defs_with_usages
            and defs_with_usages[class_name].used_by_qualified_function
        ):
            return node, True

        new_body = []
        found_target = False

@@ -903,17 +1095,29 @@ def prune_cst_for_code_hashing(  # noqa: PLR0911
        return (node.with_changes(**updates) if updates else node), True


def prune_cst_for_read_only_code(  # noqa: PLR0911
def prune_cst_for_context(  # noqa: PLR0911
    node: cst.CSTNode,
    target_functions: set[str],
    helpers_of_helper_functions: set[str],
    prefix: str = "",
    remove_docstrings: bool = False,  # noqa: FBT001, FBT002
    include_target_in_output: bool = False,  # noqa: FBT001, FBT002
    include_init_dunder: bool = False,  # noqa: FBT001, FBT002
) -> tuple[cst.CSTNode | None, bool]:
    """Recursively filter the node for read-only context.
    """Recursively filter the node for code context extraction.

    Returns
    -------
    Args:
        node: The CST node to filter
        target_functions: Set of qualified function names that are targets
        helpers_of_helper_functions: Set of helper function qualified names
        prefix: Current qualified name prefix (for class methods)
        remove_docstrings: Whether to remove docstrings from output
        include_target_in_output: If True, include target functions in output (testgen mode).
            If False, exclude target functions (read-only mode)
        include_init_dunder: If True, include __init__ in dunder methods (testgen mode).
            If False, exclude __init__ from dunder methods (read-only mode)

    Returns:
        (filtered_node, found_target):
            filtered_node: The modified CST node or None if it should be removed.
            found_target: True if a target function was found in this node's subtree.

@@ -924,17 +1128,28 @@ def prune_cst_for_read_only_code(  # noqa: PLR0911

    if isinstance(node, cst.FunctionDef):
        qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
        # If it's a target function, remove it but mark found_target = True

        # Check if it's a helper of helper function
        if qualified_name in helpers_of_helper_functions:
            if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
                return node.with_changes(body=remove_docstring_from_body(node.body)), True
            return node, True

        # Check if it's a target function
        if qualified_name in target_functions:
            if include_target_in_output:
                if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
                    return node.with_changes(body=remove_docstring_from_body(node.body)), True
                return node, True
            return None, True
        # Keep only dunder methods
        if is_dunder_method(node.name.value) and node.name.value != "__init__":

        # Check dunder methods
        # For read-only mode, exclude __init__; for testgen mode, include all dunders
        if is_dunder_method(node.name.value) and (include_init_dunder or node.name.value != "__init__"):
            if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
                new_body = remove_docstring_from_body(node.body)
                return node.with_changes(body=new_body), False
                return node.with_changes(body=remove_docstring_from_body(node.body)), False
            return node, False

        return None, False

    if isinstance(node, cst.ClassDef):

@@ -951,8 +1166,14 @@ def prune_cst_for_read_only_code(  # noqa: PLR0911
        found_in_class = False
        new_class_body: list[CSTNode] = []
        for stmt in node.body.body:
            filtered, found_target = prune_cst_for_read_only_code(
                stmt, target_functions, helpers_of_helper_functions, class_prefix, remove_docstrings=remove_docstrings
            filtered, found_target = prune_cst_for_context(
                stmt,
                target_functions,
                helpers_of_helper_functions,
                class_prefix,
                remove_docstrings=remove_docstrings,
                include_target_in_output=include_target_in_output,
                include_init_dunder=include_init_dunder,
            )
            found_in_class |= found_target
            if filtered:

@@ -981,8 +1202,14 @@ def prune_cst_for_read_only_code(  # noqa: PLR0911
            new_children = []
            section_found_target = False
            for child in original_content:
                filtered, found_target = prune_cst_for_read_only_code(
                    child, target_functions, helpers_of_helper_functions, prefix, remove_docstrings=remove_docstrings
                filtered, found_target = prune_cst_for_context(
                    child,
                    target_functions,
                    helpers_of_helper_functions,
                    prefix,
                    remove_docstrings=remove_docstrings,
                    include_target_in_output=include_target_in_output,
                    include_init_dunder=include_init_dunder,
                )
                if filtered:
                    new_children.append(filtered)

@@ -992,122 +1219,19 @@ def prune_cst_for_read_only_code(  # noqa: PLR0911
            found_any_target |= section_found_target
            updates[section] = new_children
        elif original_content is not None:
            filtered, found_target = prune_cst_for_read_only_code(
                original_content,
                target_functions,
                helpers_of_helper_functions,
                prefix,
                remove_docstrings=remove_docstrings,
            )
            found_any_target |= found_target
            if filtered:
                updates[section] = filtered
    if updates:
        return (node.with_changes(**updates), found_any_target)

    return None, False


def prune_cst_for_testgen_code(  # noqa: PLR0911
    node: cst.CSTNode,
    target_functions: set[str],
    helpers_of_helper_functions: set[str],
    prefix: str = "",
    remove_docstrings: bool = False,  # noqa: FBT001, FBT002
) -> tuple[cst.CSTNode | None, bool]:
    """Recursively filter the node for testgen context.

    Returns
    -------
    (filtered_node, found_target):
        filtered_node: The modified CST node or None if it should be removed.
        found_target: True if a target function was found in this node's subtree.

    """
    if isinstance(node, (cst.Import, cst.ImportFrom)):
        return None, False

    if isinstance(node, cst.FunctionDef):
        qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
        # If it's a target function, remove it but mark found_target = True
        if qualified_name in helpers_of_helper_functions or qualified_name in target_functions:
            if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
                new_body = remove_docstring_from_body(node.body)
                return node.with_changes(body=new_body), True
            return node, True
        # Keep all dunder methods
        if is_dunder_method(node.name.value):
            if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
                new_body = remove_docstring_from_body(node.body)
                return node.with_changes(body=new_body), False
            return node, False
        return None, False

    if isinstance(node, cst.ClassDef):
        # Do not recurse into nested classes
        if prefix:
            return None, False
        # Assuming always an IndentedBlock
        if not isinstance(node.body, cst.IndentedBlock):
            raise ValueError("ClassDef body is not an IndentedBlock")  # noqa: TRY004

        class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value

        # First pass: detect if there is a target function in the class
        found_in_class = False
        new_class_body: list[CSTNode] = []
        for stmt in node.body.body:
            filtered, found_target = prune_cst_for_testgen_code(
                stmt, target_functions, helpers_of_helper_functions, class_prefix, remove_docstrings=remove_docstrings
            )
            found_in_class |= found_target
            if filtered:
                new_class_body.append(filtered)

        if not found_in_class:
            return None, False

        if remove_docstrings:
            return node.with_changes(
                body=remove_docstring_from_body(node.body.with_changes(body=new_class_body))
            ) if new_class_body else None, True
        return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True

    # For other nodes, keep the node and recursively filter children
    section_names = get_section_names(node)
    if not section_names:
        return node, False

    updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
    found_any_target = False

    for section in section_names:
        original_content = getattr(node, section, None)
        if isinstance(original_content, (list, tuple)):
            new_children = []
            section_found_target = False
            for child in original_content:
                filtered, found_target = prune_cst_for_testgen_code(
                    child, target_functions, helpers_of_helper_functions, prefix, remove_docstrings=remove_docstrings
                )
                if filtered:
                    new_children.append(filtered)
                section_found_target |= found_target

        if section_found_target or new_children:
            found_any_target |= section_found_target
            updates[section] = new_children
        elif original_content is not None:
            filtered, found_target = prune_cst_for_testgen_code(
            filtered, found_target = prune_cst_for_context(
                original_content,
                target_functions,
                helpers_of_helper_functions,
                prefix,
                remove_docstrings=remove_docstrings,
                include_target_in_output=include_target_in_output,
                include_init_dunder=include_init_dunder,
            )
            found_any_target |= found_target
            if filtered:
                updates[section] = filtered

    if updates:
        return (node.with_changes(**updates), found_any_target)

@@ -295,11 +295,18 @@ class DependencyCollector(cst.CSTVisitor):
            return

        if name in self.definitions and name != self.current_top_level_name:
            # skip if we are referencing a class attribute and not a top-level definition
            # Skip if this Name is the .attr part of an Attribute (e.g., 'x' in 'self.x')
            # We only want to track the base/value of attribute access, not the attribute name itself
            if self.class_depth > 0:
                parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
                if parent is not None and isinstance(parent, cst.Attribute):
                    return
                    # Check if this Name is the .attr (property name), not the .value (base)
                    # If it's the .attr, skip it - attribute names aren't references to definitions
                    if parent.attr is node:
                        return
                    # If it's the .value (base), only skip if it's self/cls
                    if name in ("self", "cls"):
                        return
            self.definitions[self.current_top_level_name].dependencies.add(name)

@@ -553,16 +560,6 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
    return code


def print_definitions(definitions: dict[str, UsageInfo]) -> None:
    """Print information about each definition without the complex node object, used for debugging."""
    print(f"Found {len(definitions)} definitions:")
    for name, info in sorted(definitions.items()):
        print(f"  - Name: {name}")
        print(f"    Used by qualified function: {info.used_by_qualified_function}")
        print(f"    Dependencies: {', '.join(sorted(info.dependencies)) if info.dependencies else 'None'}")
        print()


def revert_unused_helper_functions(
    project_root: Path, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str]
) -> None:

@@ -637,43 +634,40 @@ def _analyze_imports_in_optimized_code(
        func_name = helper.only_function_name
        module_name = helper.file_path.stem
        # Cache function lookup for this (module, func)
        file_entry = helpers_by_file_and_func[module_name]
        if func_name in file_entry:
            file_entry[func_name].append(helper)
        else:
            file_entry[func_name] = [helper]
        helpers_by_file_and_func[module_name].setdefault(func_name, []).append(helper)
        helpers_by_file[module_name].append(helper)

    # Optimize attribute lookups and method binding outside the loop
    helpers_by_file_and_func_get = helpers_by_file_and_func.get
    helpers_by_file_get = helpers_by_file.get

    for node in ast.walk(optimized_ast):
        if isinstance(node, ast.ImportFrom):
            # Handle "from module import function" statements
            module_name = node.module
            if module_name:
                file_entry = helpers_by_file_and_func_get(module_name, None)
                file_entry = helpers_by_file_and_func.get(module_name)
                if file_entry:
                    for alias in node.names:
                        imported_name = alias.asname if alias.asname else alias.name
                        original_name = alias.name
                        helpers = file_entry.get(original_name, None)
                        helpers = file_entry.get(original_name)
                        if helpers:
                            imported_set = imported_names_map[imported_name]
                            for helper in helpers:
                                imported_names_map[imported_name].add(helper.qualified_name)
                                imported_names_map[imported_name].add(helper.fully_qualified_name)
                                imported_set.add(helper.qualified_name)
                                imported_set.add(helper.fully_qualified_name)

        elif isinstance(node, ast.Import):
            # Handle "import module" statements
            for alias in node.names:
                imported_name = alias.asname if alias.asname else alias.name
                module_name = alias.name
                for helper in helpers_by_file_get(module_name, []):
                    # For "import module" statements, functions would be called as module.function
                    full_call = f"{imported_name}.{helper.only_function_name}"
                    imported_names_map[full_call].add(helper.qualified_name)
                    imported_names_map[full_call].add(helper.fully_qualified_name)
                helpers = helpers_by_file.get(module_name)
                if helpers:
                    imported_set = imported_names_map[f"{imported_name}.{{func}}"]
                    for helper in helpers:
                        # For "import module" statements, functions would be called as module.function
                        full_call = f"{imported_name}.{helper.only_function_name}"
                        full_call_set = imported_names_map[full_call]
                        full_call_set.add(helper.qualified_name)
                        full_call_set.add(helper.fully_qualified_name)

    return dict(imported_names_map)

@@ -753,27 +747,31 @@ def detect_unused_helper_functions(
                called_name = node.func.id
                called_function_names.add(called_name)
                # Also add the qualified name if this is an imported function
                if called_name in imported_names_map:
                    called_function_names.update(imported_names_map[called_name])
                mapped_names = imported_names_map.get(called_name)
                if mapped_names:
                    called_function_names.update(mapped_names)
            elif isinstance(node.func, ast.Attribute):
                # Method call: obj.method() or self.method() or module.function()
                if isinstance(node.func.value, ast.Name):
                    if node.func.value.id == "self":
                    attr_name = node.func.attr
                    value_id = node.func.value.id
                    if value_id == "self":
                        # self.method_name() -> add both method_name and ClassName.method_name
                        called_function_names.add(node.func.attr)
                        called_function_names.add(attr_name)
                        # For class methods, also add the qualified name
                        # For class methods, also add the qualified name
                        if hasattr(function_to_optimize, "parents") and function_to_optimize.parents:
                            class_name = function_to_optimize.parents[0].name
                            called_function_names.add(f"{class_name}.{node.func.attr}")
                            called_function_names.add(f"{class_name}.{attr_name}")
                    else:
                        # obj.method() or module.function()
                        attr_name = node.func.attr
                        called_function_names.add(attr_name)
                        called_function_names.add(f"{node.func.value.id}.{attr_name}")
                        full_call = f"{value_id}.{attr_name}"
                        called_function_names.add(full_call)
                        # Check if this is a module.function call that maps to a helper
                        full_call = f"{node.func.value.id}.{attr_name}"
                        if full_call in imported_names_map:
                            called_function_names.update(imported_names_map[full_call])
                        mapped_names = imported_names_map.get(full_call)
                        if mapped_names:
                            called_function_names.update(mapped_names)
                # Handle nested attribute access like obj.attr.method()
                # Handle nested attribute access like obj.attr.method()
                else:
                    called_function_names.add(node.func.attr)

@@ -783,6 +781,7 @@ def detect_unused_helper_functions(

    # Find helper functions that are no longer called
    unused_helpers = []
    entrypoint_file_path = function_to_optimize.file_path
    for helper_function in code_context.helper_functions:
        if helper_function.jedi_definition.type != "class":
            # Check if the helper function is called using multiple name variants

@@ -790,29 +789,30 @@ def detect_unused_helper_functions(
            helper_simple_name = helper_function.only_function_name
            helper_fully_qualified_name = helper_function.fully_qualified_name

            # Create a set of all possible names this helper might be called by
            possible_call_names = {helper_qualified_name, helper_simple_name, helper_fully_qualified_name}

            # Check membership efficiently - exit early on first match
            if (
                helper_qualified_name in called_function_names
                or helper_simple_name in called_function_names
                or helper_fully_qualified_name in called_function_names
            ):
                is_called = True
            # For cross-file helpers, also consider module-based calls
            if helper_function.file_path != function_to_optimize.file_path:
            elif helper_function.file_path != entrypoint_file_path:
                # Add potential module.function combinations
                module_name = helper_function.file_path.stem
                possible_call_names.add(f"{module_name}.{helper_simple_name}")

            # Check if any of the possible names are in the called functions
            is_called = bool(possible_call_names.intersection(called_function_names))
                module_call = f"{module_name}.{helper_simple_name}"
                is_called = module_call in called_function_names
            else:
                is_called = False

            if not is_called:
                unused_helpers.append(helper_function)
                logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code")
                logger.debug(f"  Checked names: {possible_call_names}")
            else:
                logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code")
                logger.debug(f"  Called via: {possible_call_names.intersection(called_function_names)}")

    ret_val = unused_helpers

    except Exception as e:
        logger.debug(f"Error detecting unused helper functions: {e}")
        ret_val = []
    return ret_val
        return []
    else:
        return unused_helpers

@@ -161,6 +161,7 @@ class BestOptimization(BaseModel):
    winning_replay_benchmarking_test_results: Optional[TestResults] = None
    line_profiler_test_results: dict
    async_throughput: Optional[int] = None
    concurrency_metrics: Optional[ConcurrencyMetrics] = None


@dataclass(frozen=True)

@@ -172,6 +173,14 @@ class BenchmarkKey:
        return f"{self.module_path}::{self.function_name}"


@dataclass
class ConcurrencyMetrics:
    sequential_time_ns: int
    concurrent_time_ns: int
    concurrency_factor: int
    concurrency_ratio: float  # sequential_time / concurrent_time
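
To make the ratio concrete, a worked example with hypothetical numbers: 10 iterations of a coroutine that awaits asyncio.sleep(0.1) take about 1 s sequentially but about 0.1 s under asyncio.gather:

metrics = ConcurrencyMetrics(
    sequential_time_ns=1_000_000_000,  # ~1.0 s
    concurrent_time_ns=100_000_000,    # ~0.1 s
    concurrency_factor=10,
    concurrency_ratio=10.0,            # 1_000_000_000 / 100_000_000
)
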

@dataclass
class BenchmarkDetail:
    benchmark_name: str

@@ -336,6 +345,7 @@ class OptimizedCandidateResult(BaseModel):
    optimization_candidate_index: int
    total_candidate_timing: int
    async_throughput: Optional[int] = None
    concurrency_metrics: Optional[ConcurrencyMetrics] = None


class GeneratedTests(BaseModel):

@@ -557,6 +567,7 @@ class OriginalCodeBaseline(BaseModel):
    runtime: int
    coverage_results: Optional[CoverageData]
    async_throughput: Optional[int] = None
    concurrency_metrics: Optional[ConcurrencyMetrics] = None


class CoverageStatus(Enum):

@@ -648,6 +659,7 @@ class TestingMode(enum.Enum):
    BEHAVIOR = "behavior"
    PERFORMANCE = "performance"
    LINE_PROFILE = "line_profile"
    CONCURRENCY = "concurrency"


# TODO this class is duplicated in codeflash_capture

@@ -100,7 +100,9 @@ from codeflash.models.models import (
)
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
from codeflash.result.critic import (
    concurrency_gain,
    coverage_critic,
    get_acceptance_reason,
    performance_gain,
    quantity_of_tests_critic,
    speedup_critic,

@@ -112,7 +114,11 @@ from codeflash.verification.concolic_testing import generate_concolic_tests
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
from codeflash.verification.parse_test_output import calculate_function_throughput_from_test_results, parse_test_results
from codeflash.verification.parse_test_output import (
    calculate_function_throughput_from_test_results,
    parse_concurrency_metrics,
    parse_test_results,
)
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_line_profile_tests
from codeflash.verification.verification_utils import get_test_file_path
from codeflash.verification.verifier import generate_tests

@@ -125,6 +131,7 @@ if TYPE_CHECKING:
    from codeflash.models.models import (
        BenchmarkKey,
        CodeStringsMarkdown,
        ConcurrencyMetrics,
        CoverageData,
        FunctionCalledInTest,
        FunctionSource,

@@ -603,7 +610,9 @@ class FunctionOptimizer:
        ):
            console.rule()
            new_code_context = code_context
            if self.is_numerical_code:  # if the code is numerical in nature (uses numpy/tensorflow/math/pytorch/jax)
            if (
                self.is_numerical_code and not self.args.no_jit_opts
            ):  # if the code is numerical in nature (uses numpy/tensorflow/math/pytorch/jax)
                jit_compiled_opt_candidate = self.aiservice_client.get_jit_rewritten_code(
                    code_context.read_writable_code.markdown, self.function_trace_id
                )

@@ -632,7 +641,7 @@ class FunctionOptimizer:
                read_writable_code=code_context.read_writable_code,
                read_only_context_code=code_context.read_only_context_code,
                run_experiment=should_run_experiment,
                is_numerical_code=self.is_numerical_code,
                is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts,
            )

            concurrent.futures.wait([future_tests, future_optimizations])

@@ -735,6 +744,13 @@ class FunctionOptimizer:
                tree.add(f"Optimized async throughput: {candidate_result.async_throughput} executions")
                tree.add(f"Throughput improvement: {throughput_gain_value * 100:.1f}%")
                tree.add(f"Throughput ratio: {throughput_gain_value + 1:.3f}X")

                # Display concurrency metrics if available
                if candidate_result.concurrency_metrics and original_code_baseline.concurrency_metrics:
                    orig_ratio = original_code_baseline.concurrency_metrics.concurrency_ratio
                    cand_ratio = candidate_result.concurrency_metrics.concurrency_ratio
                    conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
                    tree.add(f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)")
            else:
                tree.add("This candidate is faster than the original code. 🚀")
                tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")

@@ -753,6 +769,14 @@ class FunctionOptimizer:
            )
            tree.add(f"Async throughput: {candidate_result.async_throughput} executions")
            tree.add(f"Throughput change: {throughput_gain_value * 100:.1f}%")

            # Display concurrency metrics if available
            if candidate_result.concurrency_metrics and original_code_baseline.concurrency_metrics:
                orig_ratio = original_code_baseline.concurrency_metrics.concurrency_ratio
                cand_ratio = candidate_result.concurrency_metrics.concurrency_ratio
                conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
                tree.add(f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)")

        tree.add(
            f"(Runtime for reference: {humanize_runtime(candidate_result.best_test_runtime)} over "
            f"{candidate_result.max_loop_count} loop{'s' if candidate_result.max_loop_count > 1 else ''})"

@@ -819,6 +843,7 @@ class FunctionOptimizer:
            winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
            winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results,
            async_throughput=candidate_result.async_throughput,
            concurrency_metrics=candidate_result.concurrency_metrics,
        )

        return best_optimization, benchmark_tree

@@ -863,6 +888,7 @@ class FunctionOptimizer:
                winning_benchmarking_test_results=valid_opt.winning_benchmarking_test_results,
                winning_replay_benchmarking_test_results=valid_opt.winning_replay_benchmarking_test_results,
                async_throughput=valid_opt.async_throughput,
                concurrency_metrics=valid_opt.concurrency_metrics,
            )
            valid_candidates_with_shorter_code.append(new_best_opt)
            diff_lens_list.append(

@@ -1014,6 +1040,8 @@ class FunctionOptimizer:
                best_runtime_until_now=None,
                original_async_throughput=original_code_baseline.async_throughput,
                best_throughput_until_now=None,
                original_concurrency_metrics=original_code_baseline.concurrency_metrics,
                best_concurrency_ratio_until_now=None,
            ) and quantity_of_tests_critic(candidate_result)

            tree = self.build_runtime_info_tree(

@@ -1132,7 +1160,7 @@ class FunctionOptimizer:
            )
            if self.experiment_id
            else None,
            is_numerical_code=self.is_numerical_code,
            is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts,
        )

        processor = CandidateProcessor(

@@ -1776,6 +1804,14 @@ class FunctionOptimizer:
                fto_benchmark_timings=self.function_benchmark_timings,
                total_benchmark_timings=self.total_benchmark_timings,
            )
        acceptance_reason = get_acceptance_reason(
            original_runtime_ns=original_code_baseline.runtime,
            optimized_runtime_ns=best_optimization.runtime,
            original_async_throughput=original_code_baseline.async_throughput,
            optimized_async_throughput=best_optimization.async_throughput,
            original_concurrency_metrics=original_code_baseline.concurrency_metrics,
            optimized_concurrency_metrics=best_optimization.concurrency_metrics,
        )
        explanation = Explanation(
            raw_explanation_message=best_optimization.candidate.explanation,
            winning_behavior_test_results=best_optimization.winning_behavior_test_results,

@@ -1787,6 +1823,9 @@ class FunctionOptimizer:
            benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None,
            original_async_throughput=original_code_baseline.async_throughput,
            best_async_throughput=best_optimization.async_throughput,
            original_concurrency_metrics=original_code_baseline.concurrency_metrics,
            best_concurrency_metrics=best_optimization.concurrency_metrics,
            acceptance_reason=acceptance_reason,
        )

        self.replace_function_and_helpers_with_optimized_code(

@@ -1884,6 +1923,9 @@ class FunctionOptimizer:
        original_throughput_str = None
        optimized_throughput_str = None
        throughput_improvement_str = None
        original_concurrency_ratio_str = None
        optimized_concurrency_ratio_str = None
        concurrency_improvement_str = None

        if (
            self.function_to_optimize.is_async

@@ -1898,6 +1940,14 @@ class FunctionOptimizer:
            )
            throughput_improvement_str = f"{throughput_improvement_value * 100:.1f}%"

        if original_code_baseline.concurrency_metrics is not None and best_optimization.concurrency_metrics is not None:
            original_concurrency_ratio_str = f"{original_code_baseline.concurrency_metrics.concurrency_ratio:.2f}x"
            optimized_concurrency_ratio_str = f"{best_optimization.concurrency_metrics.concurrency_ratio:.2f}x"
            conc_improvement_value = concurrency_gain(
                original_code_baseline.concurrency_metrics, best_optimization.concurrency_metrics
            )
            concurrency_improvement_str = f"{conc_improvement_value * 100:.1f}%"

        new_explanation_raw_str = self.aiservice_client.get_new_explanation(
            source_code=code_context.read_writable_code.flat,
            dependency_code=code_context.read_only_context_code,

@@ -1915,6 +1965,10 @@ class FunctionOptimizer:
            optimized_throughput=optimized_throughput_str,
            throughput_improvement=throughput_improvement_str,
            function_references=function_references,
            acceptance_reason=explanation.acceptance_reason.value,
            original_concurrency_ratio=original_concurrency_ratio_str,
            optimized_concurrency_ratio=optimized_concurrency_ratio_str,
            concurrency_improvement=concurrency_improvement_str,
        )
        new_explanation = Explanation(
            raw_explanation_message=new_explanation_raw_str or explanation.raw_explanation_message,

@@ -1927,6 +1981,9 @@ class FunctionOptimizer:
            benchmark_details=explanation.benchmark_details,
            original_async_throughput=explanation.original_async_throughput,
            best_async_throughput=explanation.best_async_throughput,
            original_concurrency_metrics=explanation.original_concurrency_metrics,
            best_concurrency_metrics=explanation.best_concurrency_metrics,
            acceptance_reason=explanation.acceptance_reason,
        )
        self.log_successful_optimization(new_explanation, generated_tests, exp_type)

@@ -2155,12 +2212,22 @@ class FunctionOptimizer:
        logger.debug(f"Total original code runtime (ns): {total_timing}")

        async_throughput = None
        concurrency_metrics = None
        if self.function_to_optimize.is_async:
            async_throughput = calculate_function_throughput_from_test_results(
                benchmarking_results, self.function_to_optimize.function_name
            )
            logger.debug(f"Original async function throughput: {async_throughput} calls/second")

            concurrency_metrics = self.run_concurrency_benchmark(
                code_context=code_context, original_helper_code=original_helper_code, test_env=test_env
            )
            if concurrency_metrics:
                logger.debug(
                    f"Original concurrency metrics: ratio={concurrency_metrics.concurrency_ratio:.2f}, "
                    f"seq={concurrency_metrics.sequential_time_ns}ns, conc={concurrency_metrics.concurrent_time_ns}ns"
                )

        if self.args.benchmark:
            replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
                self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root

@@ -2175,6 +2242,7 @@ class FunctionOptimizer:
                coverage_results=coverage_results,
                line_profile_results=line_profile_results,
                async_throughput=async_throughput,
                concurrency_metrics=concurrency_metrics,
            ),
            functions_to_remove,
        )

@@ -2341,12 +2409,23 @@ class FunctionOptimizer:
        logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")

        candidate_async_throughput = None
        candidate_concurrency_metrics = None
        if self.function_to_optimize.is_async:
            candidate_async_throughput = calculate_function_throughput_from_test_results(
                candidate_benchmarking_results, self.function_to_optimize.function_name
            )
            logger.debug(f"Candidate async function throughput: {candidate_async_throughput} calls/second")

            # Run concurrency benchmark for candidate
            candidate_concurrency_metrics = self.run_concurrency_benchmark(
                code_context=code_context, original_helper_code=candidate_helper_code, test_env=test_env
            )
            if candidate_concurrency_metrics:
                logger.debug(
                    f"Candidate concurrency metrics: ratio={candidate_concurrency_metrics.concurrency_ratio:.2f}, "
                    f"seq={candidate_concurrency_metrics.sequential_time_ns}ns, conc={candidate_concurrency_metrics.concurrent_time_ns}ns"
                )

        if self.args.benchmark:
            candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(
                self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root

@@ -2367,6 +2446,7 @@ class FunctionOptimizer:
                optimization_candidate_index=optimization_candidate_index,
                total_candidate_timing=total_candidate_timing,
                async_throughput=candidate_async_throughput,
                concurrency_metrics=candidate_concurrency_metrics,
            )
        )

@@ -2572,3 +2652,57 @@ class FunctionOptimizer:
                f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
            )
        return line_profile_results

    def run_concurrency_benchmark(
        self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], test_env: dict[str, str]
    ) -> ConcurrencyMetrics | None:
        """Run concurrency benchmark to measure sequential vs concurrent execution for async functions.

        This benchmark detects blocking vs non-blocking async code by comparing:
        - Sequential execution time (running N iterations one after another)
        - Concurrent execution time (running N iterations in parallel with asyncio.gather)

        Blocking code (like time.sleep) will have similar sequential and concurrent times.
        Non-blocking code (like asyncio.sleep) will be much faster when run concurrently.

        Returns:
            ConcurrencyMetrics if benchmark ran successfully, None otherwise.

        """
        if not self.function_to_optimize.is_async:
            return None

        from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function

        try:
            # Add concurrency decorator to the source function
            add_async_decorator_to_function(
                self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.CONCURRENCY
            )

            # Run the concurrency benchmark tests
            concurrency_results, _ = self.run_and_parse_tests(
                testing_type=TestingMode.PERFORMANCE,  # Use performance mode for running
                test_env=test_env,
                test_files=self.test_files,
                optimization_iteration=0,
                testing_time=5.0,  # Short benchmark time
                enable_coverage=False,
                code_context=code_context,
                pytest_min_loops=1,
                pytest_max_loops=3,
            )
        except Exception as e:
            logger.debug(f"Concurrency benchmark failed: {e}")
            return None
        finally:
            # Restore original code
            self.write_code_and_helpers(
                self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
            )

        # Parse concurrency metrics from stdout
        if concurrency_results and concurrency_results.perf_stdout:
            return parse_concurrency_metrics(concurrency_results, self.function_to_optimize.function_name)

        return None
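
The measurement idea behind this method can be sketched standalone (a minimal illustration of the sequential-vs-gather comparison, not the instrumented decorator Codeflash actually injects):

import asyncio
import time


async def measure_concurrency_ratio(coro_factory, n: int = 10) -> float:
    # Sequential: n awaits, one after another
    start = time.perf_counter_ns()
    for _ in range(n):
        await coro_factory()
    sequential_ns = time.perf_counter_ns() - start

    # Concurrent: the same n awaits under asyncio.gather
    start = time.perf_counter_ns()
    await asyncio.gather(*(coro_factory() for _ in range(n)))
    concurrent_ns = time.perf_counter_ns() - start

    # Blocking bodies (time.sleep) yield a ratio near 1;
    # non-blocking bodies (asyncio.sleep) approach n
    return sequential_ns / concurrent_ns


print(asyncio.run(measure_concurrency_ratio(lambda: asyncio.sleep(0.05))))  # ~10
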
@@ -1,10 +1,12 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING

from codeflash.code_utils import env_utils
from codeflash.code_utils.config_consts import (
    COVERAGE_THRESHOLD,
    MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD,
    MIN_IMPROVEMENT_THRESHOLD,
    MIN_TESTCASE_PASSED_THRESHOLD,
    MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD,

@@ -12,7 +14,14 @@ from codeflash.code_utils.config_consts import (
from codeflash.models import models

if TYPE_CHECKING:
    from codeflash.models.models import CoverageData, OptimizedCandidateResult, OriginalCodeBaseline
    from codeflash.models.models import ConcurrencyMetrics, CoverageData, OptimizedCandidateResult, OriginalCodeBaseline


class AcceptanceReason(Enum):
    RUNTIME = "runtime"
    THROUGHPUT = "throughput"
    CONCURRENCY = "concurrency"
    NONE = "none"


def performance_gain(*, original_runtime_ns: int, optimized_runtime_ns: int) -> float:

@@ -36,6 +45,22 @@ def throughput_gain(*, original_throughput: int, optimized_throughput: int) -> f
    return (optimized_throughput - original_throughput) / original_throughput


def concurrency_gain(original_metrics: ConcurrencyMetrics, optimized_metrics: ConcurrencyMetrics) -> float:
    """Calculate concurrency ratio improvement.

    Returns the relative improvement in concurrency ratio.
    Higher is better - means the optimized code scales better with concurrent execution.

    concurrency_ratio = sequential_time / concurrent_time
    A ratio of 10 means concurrent execution is 10x faster than sequential.
    """
    if original_metrics.concurrency_ratio == 0:
        return 0.0
    return (
        optimized_metrics.concurrency_ratio - original_metrics.concurrency_ratio
    ) / original_metrics.concurrency_ratio
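
A worked example with hypothetical metrics: going from a nearly blocking ratio of 1.05 to 8.4 is a gain of (8.4 - 1.05) / 1.05 = 7.0, i.e. +700%:

original = ConcurrencyMetrics(1_050_000, 1_000_000, 10, 1.05)
optimized = ConcurrencyMetrics(8_400_000, 1_000_000, 10, 8.4)
assert abs(concurrency_gain(original, optimized) - 7.0) < 1e-9
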

def speedup_critic(
    candidate_result: OptimizedCandidateResult,
    original_code_runtime: int,

@@ -44,10 +69,12 @@ def speedup_critic(
    disable_gh_action_noise: bool = False,
    original_async_throughput: int | None = None,
    best_throughput_until_now: int | None = None,
    original_concurrency_metrics: ConcurrencyMetrics | None = None,
    best_concurrency_ratio_until_now: float | None = None,
) -> bool:
    """Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.

    Evaluates both runtime performance and async throughput improvements.
    Evaluates runtime performance, async throughput, and concurrency improvements.

    For runtime performance:
    - Ensures the optimization is actually faster than the original code, above the noise floor.

@@ -58,6 +85,10 @@ def speedup_critic(
    For async throughput (when available):
    - Evaluates throughput improvements using MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD
    - Throughput improvements complement runtime improvements for async functions

    For concurrency (when available):
    - Evaluates concurrency ratio improvements using MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD
    - Concurrency improvements detect when blocking calls are replaced with non-blocking equivalents
    """
    # Runtime performance evaluation
    noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD

@@ -86,14 +117,78 @@ def speedup_critic(
        best_throughput_until_now is None or candidate_result.async_throughput > best_throughput_until_now
    )

    # Concurrency evaluation
    concurrency_improved = False
    concurrency_is_best = True
    if original_concurrency_metrics is not None and candidate_result.concurrency_metrics is not None:
        conc_gain = concurrency_gain(original_concurrency_metrics, candidate_result.concurrency_metrics)
        concurrency_improved = conc_gain > MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD
        concurrency_is_best = (
            best_concurrency_ratio_until_now is None
            or candidate_result.concurrency_metrics.concurrency_ratio > best_concurrency_ratio_until_now
        )

    # Accept if ANY of: runtime, throughput, or concurrency improves significantly
    if original_async_throughput is not None and candidate_result.async_throughput is not None:
        # When throughput data is available, accept if EITHER throughput OR runtime improves significantly
        throughput_acceptance = throughput_improved and throughput_is_best
        runtime_acceptance = runtime_improved and runtime_is_best
        return throughput_acceptance or runtime_acceptance
        concurrency_acceptance = concurrency_improved and concurrency_is_best
        return throughput_acceptance or runtime_acceptance or concurrency_acceptance
    return runtime_improved and runtime_is_best


def get_acceptance_reason(
    original_runtime_ns: int,
    optimized_runtime_ns: int,
    *,
    original_async_throughput: int | None = None,
    optimized_async_throughput: int | None = None,
    original_concurrency_metrics: ConcurrencyMetrics | None = None,
    optimized_concurrency_metrics: ConcurrencyMetrics | None = None,
) -> AcceptanceReason:
    """Determine why an optimization was accepted.

    Returns the primary reason for acceptance, with priority:
    concurrency > throughput > runtime (for async code).
    """
    noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_runtime_ns < 10000 else MIN_IMPROVEMENT_THRESHOLD
    if env_utils.is_ci():
        noise_floor = noise_floor * 2

    perf_gain = performance_gain(original_runtime_ns=original_runtime_ns, optimized_runtime_ns=optimized_runtime_ns)
    runtime_improved = perf_gain > noise_floor

    throughput_improved = False
    if (
        original_async_throughput is not None
        and optimized_async_throughput is not None
        and original_async_throughput > 0
    ):
        throughput_gain_value = throughput_gain(
            original_throughput=original_async_throughput, optimized_throughput=optimized_async_throughput
        )
        throughput_improved = throughput_gain_value > MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD

    concurrency_improved = False
    if original_concurrency_metrics is not None and optimized_concurrency_metrics is not None:
        conc_gain = concurrency_gain(original_concurrency_metrics, optimized_concurrency_metrics)
        concurrency_improved = conc_gain > MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD

    # Return reason with priority: concurrency > throughput > runtime
    if original_async_throughput is not None and optimized_async_throughput is not None:
        if concurrency_improved:
            return AcceptanceReason.CONCURRENCY
        if throughput_improved:
            return AcceptanceReason.THROUGHPUT
        if runtime_improved:
            return AcceptanceReason.RUNTIME
        return AcceptanceReason.NONE

    if runtime_improved:
        return AcceptanceReason.RUNTIME
    return AcceptanceReason.NONE
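
A usage sketch of the priority logic (hypothetical values; the thresholds come from config_consts):

reason = get_acceptance_reason(
    original_runtime_ns=1_000_000,
    optimized_runtime_ns=800_000,      # runtime improved
    original_async_throughput=100,
    optimized_async_throughput=150,    # throughput improved by 50%
)
# Async data is present and concurrency metrics are absent, so the priority
# order makes this AcceptanceReason.THROUGHPUT (assuming the 50% gain clears
# MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD)
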

def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | OriginalCodeBaseline) -> bool:
    test_results = candidate_result.behavior_test_results
    report = test_results.get_test_pass_fail_report_by_type()

@@ -11,8 +11,8 @@ from rich.table import Table

from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.models.models import BenchmarkDetail, TestResults
from codeflash.result.critic import throughput_gain
from codeflash.models.models import BenchmarkDetail, ConcurrencyMetrics, TestResults
from codeflash.result.critic import AcceptanceReason, concurrency_gain, throughput_gain


@dataclass(frozen=True, config={"arbitrary_types_allowed": True})

@@ -27,31 +27,44 @@ class Explanation:
    benchmark_details: Optional[list[BenchmarkDetail]] = None
    original_async_throughput: Optional[int] = None
    best_async_throughput: Optional[int] = None
    original_concurrency_metrics: Optional[ConcurrencyMetrics] = None
    best_concurrency_metrics: Optional[ConcurrencyMetrics] = None
    acceptance_reason: AcceptanceReason = AcceptanceReason.RUNTIME

    @property
    def perf_improvement_line(self) -> str:
        # speedup property already handles choosing between runtime and throughput
        improvement_type = {
            AcceptanceReason.RUNTIME: "runtime",
            AcceptanceReason.THROUGHPUT: "throughput",
            AcceptanceReason.CONCURRENCY: "concurrency",
            AcceptanceReason.NONE: "",
        }.get(self.acceptance_reason, "")

        if improvement_type:
            return f"{self.speedup_pct} {improvement_type} improvement ({self.speedup_x} faster)."
        return f"{self.speedup_pct} improvement ({self.speedup_x} faster)."

    @property
    def speedup(self) -> float:
        runtime_improvement = (self.original_runtime_ns / self.best_runtime_ns) - 1

        # Use throughput improvement if we have async metrics and throughput is better
        """Returns the improvement value for the metric that caused acceptance."""
        if (
            self.original_async_throughput is not None
            self.acceptance_reason == AcceptanceReason.CONCURRENCY
            and self.original_concurrency_metrics
            and self.best_concurrency_metrics
        ):
            return concurrency_gain(self.original_concurrency_metrics, self.best_concurrency_metrics)

        if (
            self.acceptance_reason == AcceptanceReason.THROUGHPUT
            and self.original_async_throughput is not None
            and self.best_async_throughput is not None
            and self.original_async_throughput > 0
        ):
            throughput_improvement = throughput_gain(
            return throughput_gain(
                original_throughput=self.original_async_throughput, optimized_throughput=self.best_async_throughput
            )

            # Use throughput metrics if throughput improvement is better or runtime got worse
            if throughput_improvement > runtime_improvement or runtime_improvement <= 0:
                return throughput_improvement

        return runtime_improvement
        return (self.original_runtime_ns / self.best_runtime_ns) - 1

    @property
    def speedup_x(self) -> str:

@@ -108,7 +121,22 @@ class Explanation:
        console.print(table)
        benchmark_info = cast("StringIO", console.file).getvalue() + "\n"  # Cast for mypy

        if self.original_async_throughput is not None and self.best_async_throughput is not None:
        if (
            self.acceptance_reason == AcceptanceReason.CONCURRENCY
            and self.original_concurrency_metrics
            and self.best_concurrency_metrics
        ):
            orig_ratio = self.original_concurrency_metrics.concurrency_ratio
            best_ratio = self.best_concurrency_metrics.concurrency_ratio
            performance_description = (
                f"Concurrency ratio improved from {orig_ratio:.2f}x to {best_ratio:.2f}x "
                f"(concurrent execution now {best_ratio:.2f}x faster than sequential)\n\n"
            )
        elif (
            self.acceptance_reason == AcceptanceReason.THROUGHPUT
            and self.original_async_throughput is not None
            and self.best_async_throughput is not None
        ):
            performance_description = (
                f"Throughput improved from {self.original_async_throughput} to {self.best_async_throughput} operations/second "
                f"(runtime: {original_runtime_human} → {best_runtime_human})\n\n"

@@ -138,6 +138,13 @@ def main(args: Namespace | None = None) -> ArgumentParser:
        env["PYTHONPATH"] = f"{project_root_str}{os.pathsep}{pythonpath}"
    else:
        env["PYTHONPATH"] = project_root_str
    # Disable JIT compilation to ensure tracing captures all function calls
    env["NUMBA_DISABLE_JIT"] = str(1)
    env["TORCHDYNAMO_DISABLE"] = str(1)
    env["PYTORCH_JIT"] = str(0)
    env["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
    env["TF_ENABLE_ONEDNN_OPTS"] = str(0)
    env["JAX_DISABLE_JIT"] = str(1)
    processes.append(
        subprocess.Popen(
            [

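A side note on the flags above: each library reads its variable at import or trace time, so they must be in the child's environment before the frameworks load. A minimal standalone check for the numba flag (a hypothetical script, assuming numba is installed):

import os

os.environ["NUMBA_DISABLE_JIT"] = "1"  # set before numba is imported

from numba import njit


@njit
def add(a: int, b: int) -> int:
    return a + b


add(1, 2)  # executes as plain Python, so sys.setprofile/settrace sees the call
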
@@ -175,6 +182,13 @@ def main(args: Namespace | None = None) -> ArgumentParser:
        env["PYTHONPATH"] = f"{project_root_str}{os.pathsep}{pythonpath}"
    else:
        env["PYTHONPATH"] = project_root_str
    # Disable JIT compilation to ensure tracing captures all function calls
    env["NUMBA_DISABLE_JIT"] = str(1)
    env["TORCHDYNAMO_DISABLE"] = str(1)
    env["PYTORCH_JIT"] = str(0)
    env["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
    env["TF_ENABLE_ONEDNN_OPTS"] = str(0)
    env["JAX_DISABLE_JIT"] = str(1)

    subprocess.run(
        [

@@ -15,6 +15,8 @@ from typing import Callable
import dill as pickle
from dill import PicklingWarning

from codeflash.picklepatch.pickle_patcher import PicklePatcher

warnings.filterwarnings("ignore", category=PicklingWarning)

@@ -148,18 +150,29 @@ def codeflash_capture(function_name: str, tmp_dir_path: str, tests_root: str, is
            print(f"!######{test_stdout_tag}######!")

            # Capture instance state after initialization
            if hasattr(args[0], "__dict__"):
                instance_state = args[
                    0
                ].__dict__  # self is always the first argument, this is ensured during instrumentation
            # self is always the first argument, this is ensured during instrumentation
            instance = args[0]
            if hasattr(instance, "__dict__"):
                instance_state = instance.__dict__
            elif hasattr(instance, "__slots__"):
                # For classes using __slots__, capture slot values
                instance_state = {
                    slot: getattr(instance, slot, None) for slot in instance.__slots__ if hasattr(instance, slot)
                }
            else:
                raise ValueError("Instance state could not be captured.")
                # For C extension types or other special classes (e.g., Playwright's Page),
                # capture all non-private, non-callable attributes
                instance_state = {
                    attr: getattr(instance, attr)
                    for attr in dir(instance)
                    if not attr.startswith("_") and not callable(getattr(instance, attr, None))
                }
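
A small illustration of the three capture paths above (hypothetical classes):

class WithDict:
    def __init__(self) -> None:
        self.a = 1


class WithSlots:
    __slots__ = ("b",)

    def __init__(self) -> None:
        self.b = 2


# WithDict() is captured via __dict__          -> {"a": 1}
# WithSlots() via the __slots__ comprehension  -> {"b": 2}
# C-extension instances with neither fall through to the dir()-based scan
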
            codeflash_cur.execute(
                "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)"
            )

            # Write to sqlite
            pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(instance_state)
            pickled_return_value = pickle.dumps(exception) if exception else PicklePatcher.dumps(instance_state)
            codeflash_cur.execute(
                "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
                (

@@ -24,6 +24,7 @@ HAS_TORCH = find_spec("torch") is not None
HAS_JAX = find_spec("jax") is not None
HAS_XARRAY = find_spec("xarray") is not None
HAS_TENSORFLOW = find_spec("tensorflow") is not None
HAS_NUMBA = find_spec("numba") is not None

# Pattern to match pytest temp directories: /tmp/pytest-of-<user>/pytest-<N>/
# These paths vary between test runs but are logically equivalent

@@ -156,6 +157,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001
            range,
            slice,
            OrderedDict,
            types.GenericAlias,
        ),
    ):
        return orig == new

@@ -255,6 +257,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001
                return False
        return True

    # Handle mappingproxy (read-only dict view, commonly seen as class.__dict__)
    if isinstance(orig, types.MappingProxyType):
        return comparator(dict(orig), dict(new), superset_obj)

    # Handle dict view types (dict_keys, dict_values, dict_items)
    # Use type name checking since these are not directly importable types
    type_name = type(orig).__name__

@@ -296,8 +302,8 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001
        # fails at "ufunc 'isfinite' not supported for the input types"
        return np.all([comparator(x, y, superset_obj) for x, y in zip(orig, new)])

    if isinstance(orig, (np.floating, np.complex64, np.complex128)):
        return np.isclose(orig, new)
    if isinstance(orig, (np.floating, np.complexfloating)):
        return np.isclose(orig, new, equal_nan=True)

    if isinstance(orig, (np.integer, np.bool_, np.byte)):
        return orig == new

@@ -383,6 +389,42 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001
    if isinstance(orig, torch.device):
        return orig == new

    if HAS_NUMBA:
        import numba  # type: ignore  # noqa: PGH003
        from numba.core.dispatcher import Dispatcher  # type: ignore  # noqa: PGH003
        from numba.typed import Dict as NumbaDict  # type: ignore  # noqa: PGH003
        from numba.typed import List as NumbaList  # type: ignore  # noqa: PGH003

        # Handle numba typed List
        if isinstance(orig, NumbaList):
            if len(orig) != len(new):
                return False
            return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))

        # Handle numba typed Dict
        if isinstance(orig, NumbaDict):
            if superset_obj:
                # Allow new dict to have more keys, but all orig keys must exist with equal values
                return all(key in new and comparator(orig[key], new[key], superset_obj) for key in orig)
            if len(orig) != len(new):
                return False
            for key in orig:
                if key not in new:
                    return False
                if not comparator(orig[key], new[key], superset_obj):
                    return False
            return True

        # Handle numba type objects (e.g., numba.int64, numba.float64, numba.Array, etc.)
        if isinstance(orig, numba.core.types.Type):
            return orig == new

        # Handle numba JIT-compiled functions (CPUDispatcher, etc.)
        if isinstance(orig, Dispatcher):
            # Compare by identity of the underlying Python function
            # Two JIT functions are equal if they wrap the same Python function
            return orig.py_func is new.py_func
|
||||
|
||||
if HAS_PYRSISTENT:
|
||||
import pyrsistent # type: ignore # noqa: PGH003
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,14 @@ reprlib_repr.maxstring = 1500
test_diff_repr = reprlib_repr.repr


def safe_repr(obj: object) -> str:
    """Safely get repr of an object, handling Mock objects with corrupted state."""
    try:
        return repr(obj)
    except (AttributeError, TypeError, RecursionError) as e:
        return f"<repr failed: {type(e).__name__}: {e}>"


def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]:
    # This is meant to be only called with test results for the first loop index
    if len(original_results) == 0 or len(candidate_results) == 0:
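To make the fallback path concrete, a minimal sketch with a hypothetical object whose `__repr__` is broken (the `BrokenRepr` class is illustrative only):

```python
class BrokenRepr:
    """Hypothetical stand-in for a Mock with corrupted state."""

    def __repr__(self) -> str:
        raise AttributeError("mock state is corrupted")

# repr() would raise here; safe_repr degrades to a diagnostic string instead.
print(safe_repr(BrokenRepr()))  # <repr failed: AttributeError: mock state is corrupted>
```
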
@ -77,8 +85,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
                test_diffs.append(
                    TestDiff(
                        scope=TestDiffScope.RETURN_VALUE,
                        original_value=test_diff_repr(repr(original_test_result.return_value)),
                        candidate_value=test_diff_repr(repr(cdd_test_result.return_value)),
                        original_value=test_diff_repr(safe_repr(original_test_result.return_value)),
                        candidate_value=test_diff_repr(safe_repr(cdd_test_result.return_value)),
                        test_src_code=original_test_result.id.get_src_code(original_test_result.file_name),
                        candidate_pytest_error=cdd_pytest_error,
                        original_pass=original_test_result.did_pass,
@ -20,7 +20,14 @@ from codeflash.code_utils.code_utils import (
    module_name_from_file_path,
)
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType, VerificationType
from codeflash.models.models import (
    ConcurrencyMetrics,
    FunctionTestInvocation,
    InvocationId,
    TestResults,
    TestType,
    VerificationType,
)
from codeflash.verification.coverage_utils import CoverageUtils

if TYPE_CHECKING:
@ -64,6 +71,54 @@ def calculate_function_throughput_from_test_results(test_results: TestResults, f
    return function_throughput


# Pattern for concurrency benchmark output:
# !@######CONC:module:class:test:function:loop_index:seq_time:conc_time:factor######@!
_concurrency_pattern = re.compile(r"!@######CONC:([^:]*):([^:]*):([^:]*):([^:]*):([^:]*):(\d+):(\d+):(\d+)######@!")


def parse_concurrency_metrics(test_results: TestResults, function_name: str) -> ConcurrencyMetrics | None:
    """Parse concurrency benchmark results from test output.

    Format: !@######CONC:module:class:test:function:loop_index:seq_time:conc_time:factor######@!

    Returns ConcurrencyMetrics with:
    - sequential_time_ns: Total time for N sequential executions
    - concurrent_time_ns: Total time for N concurrent executions
    - concurrency_factor: N (number of concurrent executions)
    - concurrency_ratio: sequential_time / concurrent_time (higher = better concurrency)
    """
    if not test_results.perf_stdout:
        return None

    matches = _concurrency_pattern.findall(test_results.perf_stdout)
    if not matches:
        return None

    # Aggregate metrics for the target function
    total_seq, total_conc, factor, count = 0, 0, 0, 0
    for match in matches:
        # match[3] is function_name
        if len(match) >= 8 and match[3] == function_name:
            total_seq += int(match[5])
            total_conc += int(match[6])
            factor = int(match[7])
            count += 1

    if count == 0:
        return None

    avg_seq = total_seq / count
    avg_conc = total_conc / count
    ratio = avg_seq / avg_conc if avg_conc > 0 else 1.0

    return ConcurrencyMetrics(
        sequential_time_ns=int(avg_seq),
        concurrent_time_ns=int(avg_conc),
        concurrency_factor=factor,
        concurrency_ratio=ratio,
    )


def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> Path | None:
    """Resolve test file path from pytest's test class path.

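For illustration, a minimal standalone sketch (module, test, function names, and timings are hypothetical) of how one benchmark line in this format decodes into a concurrency ratio, using the same regex as above:

```python
import re

# Hypothetical benchmark line in the format the parser above expects.
line = "!@######CONC:tests.test_app:TestApp:test_fetch:fetch_all:1:50000000:10000000:5######@!"

pattern = re.compile(
    r"!@######CONC:([^:]*):([^:]*):([^:]*):([^:]*):([^:]*):(\d+):(\d+):(\d+)######@!"
)
groups = pattern.match(line).groups()
seq_ns, conc_ns = int(groups[5]), int(groups[6])

# 50ms sequential vs 10ms concurrent across 5 executions -> ratio 5.0,
# meaning the function overlaps well under concurrency.
print(seq_ns / conc_ns)  # 5.0
```
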
@ -1,2 +1,2 @@
# These version placeholders will be replaced by uv-dynamic-versioning during build.
__version__ = "0.19.1"
__version__ = "0.20.0"
@ -1,191 +0,0 @@
---
title: "Flags Reference"
description: "Complete reference for all Codeflash CLI flags and options"
icon: "list"
sidebarTitle: "Flags Reference"
keywords: ["flags", "options", "arguments", "command line"]
---

# Flags Reference

Complete reference for all Codeflash CLI flags and command-line options.

---

## Main Command Flags

| Flag | Type | Description |
|------|------|-------------|
| `--file` | `PATH` | Optimize only this file |
| `--function` | `NAME` | Optimize only this function (requires `--file`) |
| `--all` | `[PATH]` | Optimize all functions. Optional path to start from |
| `--replay-test` | `PATH` | Path to replay test file(s) |
| `--benchmark` | flag | Enable benchmark mode |
| `--no-pr` | flag | Don't create PR, update locally |
| `--no-gen-tests` | flag | Don't generate tests |
| `--no-draft` | flag | Skip draft PRs |
| `--worktree` | flag | Use git worktree |
| `--staging-review` | flag | Upload to staging |
| `--verbose` / `-v` | flag | Verbose debug output |
| `--verify-setup` | flag | Run setup verification |
| `--version` | flag | Show version |

---

## Configuration Override Flags

Override settings from `pyproject.toml` via command line.

| Flag | Type | Description |
|------|------|-------------|
| `--config-file` | `PATH` | Path to pyproject.toml |
| `--module-root` | `PATH` | Python module root directory |
| `--tests-root` | `PATH` | Tests directory |
| `--benchmarks-root` | `PATH` | Benchmarks directory |

<Accordion title="Complete Examples">
<Tabs>
<Tab title="Linux/macOS">
```bash
# Override config file location
codeflash --file src/app.py --function main --config-file configs/pyproject.toml --no-pr

# Override module root
codeflash --file src/app.py --function main --module-root src --no-pr

# Override tests root
codeflash --file src/app.py --function main --tests-root tests/unit --no-pr

# Combine multiple overrides
codeflash --file src/app.py --function main \
  --module-root src \
  --tests-root tests \
  --no-pr
```
</Tab>
<Tab title="Windows">
```powershell
# Override config file location
codeflash --file src\app.py --function main --config-file configs\pyproject.toml --no-pr

# Override module root
codeflash --file src\app.py --function main --module-root src --no-pr

# Override tests root
codeflash --file src\app.py --function main --tests-root tests\unit --no-pr
```
</Tab>
</Tabs>
</Accordion>

---

## Optimize Subcommand Flags

Flags specific to the `codeflash optimize` command.

| Flag | Type | Description |
|------|------|-------------|
| `--output` | `PATH` | Trace file output path (default: `codeflash.trace`) |
| `--timeout` | `INT` | Maximum trace time in seconds |
| `--max-function-count` | `INT` | Max times to trace a function (default: 100) |
| `--config-file-path` | `PATH` | Path to pyproject.toml |
| `--trace-only` | flag | Only trace, don't optimize |

<Info>
The `--output` flag specifies where to save the trace file. If not specified, it defaults to `codeflash.trace` in the current directory.
</Info>

---

## Behavior Flags

Control how Codeflash behaves during optimization.

| Flag | Description |
|------|-------------|
| `--no-pr` | Run locally without creating a pull request |
| `--no-gen-tests` | Use only existing tests, skip test generation |
| `--no-draft` | Skip optimization for draft PRs (CI mode) |
| `--worktree` | Use git worktree for isolated optimization |
| `--staging-review` | Upload optimizations to staging for review |
| `--verbose` / `-v` | Enable verbose debug logging |

<Accordion title="Complete Examples">
```bash
# Local optimization only
codeflash --file src/app.py --function main --no-pr

# Use only existing tests
codeflash --file src/app.py --function main --no-gen-tests --no-pr

# Enable verbose logging
codeflash --file src/app.py --function main --verbose --no-pr

# Use worktree for isolation
codeflash --file src/app.py --function main --worktree --no-pr

# Upload to staging
codeflash --all --staging-review --no-pr
```
</Accordion>

---

## Flag Combinations

Common flag combinations for different use cases:

### Local Development

```bash
# Optimize locally with verbose output
codeflash --file src/app.py --function main --no-pr --verbose
```

### CI/CD Pipeline

```bash
# Skip draft PRs and use existing tests only
codeflash --all --no-draft --no-gen-tests
```

### Debugging

```bash
# Trace only with custom output and timeout
codeflash optimize app.py --trace-only --output debug.trace --timeout 60
```

### Custom Configuration

```bash
# Override multiple config settings
codeflash --file src/app.py --function main \
  --module-root src \
  --tests-root tests/unit \
  --benchmarks-root tests/benchmarks \
  --no-pr
```

---

## Next Steps

<CardGroup cols={2}>
<Card
  title="Optimization Commands"
  icon="bullseye"
  href="/cli-reference/optimization"
>
  Learn how to use optimization commands
</Card>
<Card
  title="Troubleshooting"
  icon="wrench"
  href="/cli-reference/troubleshooting"
>
  Fix common issues
</Card>
</CardGroup>
@ -1,208 +0,0 @@
---
title: "CLI Reference"
description: "Complete command-line reference for Codeflash CLI commands, flags, and options"
icon: "terminal"
sidebarTitle: "Overview"
keywords:
  [
    "CLI",
    "command line",
    "commands",
    "flags",
    "options",
    "reference",
    "terminal",
  ]
---

# Codeflash CLI Reference

Complete command-line reference for all Codeflash commands, flags, and options with practical examples you can run directly in your terminal.

<Info>
**Prerequisites** - Ensure Codeflash is installed in your Python environment and you have a configured `pyproject.toml` in your project.
</Info>

---

## Quick Start

<Tabs>
<Tab title="Linux/macOS">
```bash
# Activate virtual environment (if using one)
source .venv/bin/activate

# Verify installation
codeflash --version
```
</Tab>
<Tab title="Windows">
```powershell
# Activate virtual environment (if using one)
.venv\Scripts\activate

# Verify installation
codeflash --version
```
</Tab>
</Tabs>

---

## Common Workflows

### 1. First-Time Setup

<Steps>
<Step title="Install Codeflash">
```bash
pip install codeflash
```
</Step>
<Step title="Initialize Project">
```bash
codeflash init
```
</Step>
<Step title="Verify Setup">
```bash
codeflash --verify-setup
```
</Step>
<Step title="Run First Optimization">
```bash
codeflash --file src/main.py --function my_function --no-pr
```
</Step>
</Steps>

---

### 2. Optimize a Workflow

<Steps>
<Step title="Trace Your Script">
```bash
codeflash optimize my_script.py --arg1 value1
```
</Step>
<Step title="Review Optimizations">
Check the generated PR or local changes for optimization suggestions.
</Step>
</Steps>

---

### 3. CI/CD Integration

<Steps>
<Step title="Set Up GitHub Actions">
```bash
codeflash init-actions
```
</Step>
<Step title="Merge the Workflow PR">
Review and merge the generated GitHub Actions workflow.
</Step>
<Step title="Automatic Optimization">
Codeflash will now optimize code in every PR automatically!
</Step>
</Steps>

---

## Help & Version

```bash
# Display version
codeflash --version

# Main help
codeflash --help

# Subcommand help
codeflash optimize --help
codeflash init --help
```

---

## Documentation Structure

This CLI reference is organized into the following sections:

<CardGroup cols={2}>
<Card
  title="Setup Commands"
  icon="wrench"
  href="/cli-reference/setup"
>
  Initialize projects, set up GitHub Actions, and verify installation
</Card>
<Card
  title="Optimization Commands"
  icon="bullseye"
  href="/cli-reference/optimization"
>
  Optimize single functions or entire codebases
</Card>
<Card
  title="Tracing & Workflows"
  icon="route"
  href="/cli-reference/tracing"
>
  Trace script execution and optimize based on real usage
</Card>
<Card
  title="Flags Reference"
  icon="list"
  href="/cli-reference/flags"
>
  Complete reference for all command-line flags
</Card>
<Card
  title="Troubleshooting"
  icon="wrench"
  href="/cli-reference/troubleshooting"
>
  Solutions for common CLI issues
</Card>
</CardGroup>

---

## Next Steps

<CardGroup cols={2}>
<Card
  title="Optimize a Function"
  icon="bullseye"
  href="/optimizing-with-codeflash/one-function"
>
  Learn how to optimize individual functions
</Card>
<Card
  title="Trace & Optimize"
  icon="route"
  href="/optimizing-with-codeflash/trace-and-optimize"
>
  Optimize entire workflows with tracing
</Card>
<Card
  title="GitHub Actions"
  icon="github"
  href="/optimizing-with-codeflash/codeflash-github-actions"
>
  Set up continuous optimization
</Card>
<Card
  title="Configuration"
  icon="gear"
  href="/configuration"
>
  Advanced configuration options
</Card>
</CardGroup>
@ -1,172 +0,0 @@
---
title: "Optimization Commands"
description: "Optimize single functions or entire codebases with Codeflash CLI"
icon: "bullseye"
sidebarTitle: "Optimization Commands"
keywords: ["optimization", "function", "file", "all", "commands"]
---

# Optimization Commands

Commands for optimizing individual functions or entire codebases.

---

## Optimize a Single Function

Target a specific function in a file for optimization.

```bash
codeflash --file <path/to/file.py> --function <function_name>
```

<Accordion title="Complete Examples">
<Tabs>
<Tab title="Linux/macOS">
```bash
# Basic optimization (creates PR)
codeflash --file src/utils.py --function calculate_metrics

# Local optimization only (no PR)
codeflash --file src/utils.py --function calculate_metrics --no-pr

# With verbose output
codeflash --file src/utils.py --function calculate_metrics --no-pr --verbose
```
</Tab>
<Tab title="Windows">
```powershell
# Basic optimization (creates PR)
codeflash --file src\utils.py --function calculate_metrics

# Local optimization only (no PR)
codeflash --file src\utils.py --function calculate_metrics --no-pr

# With verbose output
codeflash --file src\utils.py --function calculate_metrics --no-pr --verbose
```
</Tab>
</Tabs>
</Accordion>

<Warning>
**Important**: The file must be within your configured `module-root` directory. Files outside `module-root` will be ignored with a "Functions outside module-root" message.
</Warning>

---

## Optimize All Functions

Optimize all functions in your entire codebase or a specific directory.

```bash
# Optimize entire codebase
codeflash --all

# Optimize specific directory
codeflash --all src/core/
```

<Accordion title="Complete Examples">
<Tabs>
<Tab title="Linux/macOS">
```bash
# Optimize all (creates PRs)
codeflash --all

# Optimize all locally (no PRs)
codeflash --all --no-pr

# Optimize specific directory
codeflash --all src/algorithms/ --no-pr

# Skip draft PRs in CI
codeflash --all --no-draft
```
</Tab>
<Tab title="Windows">
```powershell
# Optimize all (creates PRs)
codeflash --all

# Optimize all locally (no PRs)
codeflash --all --no-pr

# Optimize specific directory
codeflash --all src\algorithms\ --no-pr

# Skip draft PRs in CI
codeflash --all --no-draft
```
</Tab>
</Tabs>
</Accordion>

<Info>
When using `--all`, Codeflash will:
- Discover all optimizable functions in your codebase
- Create separate PRs for each function (or update locally with `--no-pr`)
- Process functions in batches to avoid overwhelming your repository
</Info>

---

## Benchmark Mode

Optimize code based on performance benchmarks using the pytest-benchmark format.

```bash
codeflash --file <file.py> --benchmark --benchmarks-root <path>
```

<Accordion title="Complete Examples">
<Tabs>
<Tab title="Linux/macOS">
```bash
# With benchmarks-root flag
codeflash --file src/core.py --benchmark --benchmarks-root tests/benchmarks --no-pr

# If benchmarks-root is in pyproject.toml
codeflash --file src/core.py --benchmark --no-pr
```
</Tab>
<Tab title="Windows">
```powershell
# With benchmarks-root flag
codeflash --file src\core.py --benchmark --benchmarks-root tests\benchmarks --no-pr

# If benchmarks-root is in pyproject.toml
codeflash --file src\core.py --benchmark --no-pr
```
</Tab>
</Tabs>
</Accordion>

<Warning>
The `--benchmarks-root` directory must exist and be configured either via `pyproject.toml` or the command-line flag.
</Warning>

---

## Next Steps

<CardGroup cols={2}>
<Card
  title="Tracing & Workflows"
  icon="route"
  href="/cli-reference/tracing"
>
  Learn about trace-based optimization
</Card>
<Card
  title="Flags Reference"
  icon="list"
  href="/cli-reference/flags"
>
  Complete flag reference
</Card>
</CardGroup>
@ -1,125 +0,0 @@
---
title: "Setup Commands"
description: "Initialize projects, set up GitHub Actions, and verify Codeflash installation"
icon: "wrench"
sidebarTitle: "Setup Commands"
keywords: ["setup", "init", "installation", "github actions", "verify"]
---

# Setup Commands

Commands for initializing Codeflash in your project, setting up continuous optimization, and verifying your installation.

---

## `codeflash init`

Initialize Codeflash for your Python project. This creates the configuration in `pyproject.toml`.

<CodeGroup>
```bash Basic
codeflash init
```

```bash With Formatter Override
codeflash init --override-formatter-check
```
</CodeGroup>

<Tip>
The `init` command will guide you through an interactive setup process, including API key configuration, module selection, and GitHub App installation.
</Tip>

**What it does:**

- Prompts for your Python module directory (`module-root`)
- Prompts for your test directory (`tests-root`)
- Configures code formatter preferences
- Sets up telemetry preferences
- Optionally installs the Codeflash VS Code extension
- Optionally sets up GitHub Actions workflow

---

## `codeflash init-actions`

Set up GitHub Actions workflow for continuous optimization on every pull request.

```bash
codeflash init-actions
```

**What it does:**

- Creates a workflow file in `.github/workflows/`
- Opens a PR with the workflow configuration
- Requires the Codeflash GitHub App to be installed

<Warning>
This command requires the Codeflash GitHub App to be installed on your repository. If you haven't installed it, you'll be prompted with a link to do so.
</Warning>

---

## `codeflash vscode-install`

Install the Codeflash extension for VS Code, Cursor, or Windsurf.

```bash
codeflash vscode-install
```

**What it does:**

- Detects which editor you're using (VS Code, Cursor, or Windsurf)
- Downloads and installs the appropriate extension
- Works with both Marketplace and Open VSX sources

<Tip>
This command is also run automatically during `codeflash init` if you choose to install the extension.
</Tip>

---

## `codeflash --verify-setup`

Verify your Codeflash installation by running a sample optimization.

```bash
codeflash --verify-setup
```

**What it does:**

- Creates a temporary demo file
- Runs a sample optimization
- Verifies all components are working correctly
- Cleans up the demo file afterward

<Note>
This command takes about 3 minutes to complete. It's a great way to ensure everything is set up correctly before optimizing your actual code.
</Note>

---

## Next Steps

<CardGroup cols={2}>
<Card
  title="Optimization Commands"
  icon="bullseye"
  href="/cli-reference/optimization"
>
  Learn how to optimize functions
</Card>
<Card
  title="Flags Reference"
  icon="list"
  href="/cli-reference/flags"
>
  Complete flag reference
</Card>
</CardGroup>
@ -1,213 +0,0 @@
---
title: "Tracing & Workflows"
description: "Trace script execution and optimize functions based on real-world usage"
icon: "route"
sidebarTitle: "Tracing & Workflows"
keywords: ["tracing", "optimize", "workflow", "replay test", "pytest"]
---

# Tracing & Workflows

Trace Python script execution and optimize functions based on real-world usage patterns.

---

## `codeflash optimize`

Trace a Python script's execution and optimize functions based on real-world usage.

```bash
codeflash optimize <script.py> [script_args]
```

<Accordion title="Complete Examples">
<Tabs>
<Tab title="Linux/macOS">
```bash
# Basic trace and optimize
codeflash optimize app.py

# With script arguments
codeflash optimize process.py --input data.csv --output results.json

# Custom trace output file
codeflash optimize app.py --output custom_trace.trace

# With timeout (30 seconds)
codeflash optimize long_running_script.py --timeout 30

# Limit function trace count
codeflash optimize app.py --max-function-count 50

# Specify config file
codeflash optimize app.py --config-file-path pyproject.toml

# Local only (no PR)
codeflash optimize app.py --no-pr
```
</Tab>
<Tab title="Windows">
```powershell
# Basic trace and optimize
codeflash optimize app.py

# With script arguments
codeflash optimize process.py --input data.csv --output results.json

# Custom trace output file
codeflash optimize app.py --output custom_trace.trace

# With timeout (30 seconds)
codeflash optimize long_running_script.py --timeout 30

# Limit function trace count
codeflash optimize app.py --max-function-count 50

# Specify config file
codeflash optimize app.py --config-file-path pyproject.toml

# Local only (no PR)
codeflash optimize app.py --no-pr
```
</Tab>
</Tabs>
</Accordion>

**How it works:**

1. Runs your script with the provided arguments
2. Traces all function calls during execution
3. Identifies which functions are called and how often
4. Generates replay tests based on actual usage
5. Optimizes the traced functions

---

## Trace with pytest

Optimize functions called during pytest test execution.

<Tabs>
<Tab title="Linux/macOS">
```bash
# Trace pytest tests
codeflash optimize -m pytest tests/

# Trace specific test file
codeflash optimize -m pytest tests/test_core.py

# With pytest arguments
codeflash optimize -m pytest tests/ -v --tb=short
```
</Tab>
<Tab title="Windows">
```powershell
# Trace pytest tests
codeflash optimize -m pytest tests\

# Trace specific test file
codeflash optimize -m pytest tests\test_core.py

# With pytest arguments
codeflash optimize -m pytest tests\ -v --tb=short
```
</Tab>
</Tabs>

<Tip>
Tracing pytest tests is great for optimizing functions that are heavily used in your test suite, ensuring optimizations work correctly with your existing tests.
</Tip>

---

## Trace Only (Generate Replay Tests)

Create trace files and replay tests without running optimization.

<Tabs>
<Tab title="Linux/macOS">
```bash
# Trace only - generates replay test
codeflash optimize app.py --output trace_file.trace --trace-only

# Then optimize using the replay test
codeflash --replay-test tests/test_app_py__replay_test_0.py --no-pr
```
</Tab>
<Tab title="Windows">
```powershell
# Trace only - generates replay test
codeflash optimize app.py --output trace_file.trace --trace-only

# Then optimize using the replay test
codeflash --replay-test tests\test_app_py__replay_test_0.py --no-pr
```
</Tab>
</Tabs>

<Note>
**Replay test naming**: Files are named based on the traced script path. For `src/app.py`, the replay test will be named like `test_srcapp_py__replay_test_0.py`.
</Note>

**Use cases for trace-only:**

- Generate replay tests for later optimization
- Debug tracing issues without running full optimization
- Create reusable test cases from script execution

---

## Replay Test Optimization

Optimize functions using previously generated replay tests.

```bash
codeflash --replay-test <path/to/replay_test.py>
```

<Accordion title="Complete Examples">
<Tabs>
<Tab title="Linux/macOS">
```bash
# Optimize using replay test
codeflash --replay-test tests/test_app_py__replay_test_0.py --no-pr

# Multiple replay tests
codeflash --replay-test tests/test_*.py --no-pr
```
</Tab>
<Tab title="Windows">
```powershell
# Optimize using replay test
codeflash --replay-test tests\test_app_py__replay_test_0.py --no-pr

# Multiple replay tests (use Get-ChildItem for globbing)
codeflash --replay-test (Get-ChildItem tests\test_*.py) --no-pr
```
</Tab>
</Tabs>
</Accordion>

---

## Next Steps

<CardGroup cols={2}>
<Card
  title="Optimization Commands"
  icon="bullseye"
  href="/cli-reference/optimization"
>
  Learn about function optimization
</Card>
<Card
  title="Flags Reference"
  icon="list"
  href="/cli-reference/flags"
>
  Complete flag reference
</Card>
</CardGroup>
@ -1,157 +0,0 @@
---
title: "CLI Troubleshooting"
description: "Solutions for common Codeflash CLI issues and errors"
icon: "wrench"
sidebarTitle: "Troubleshooting"
keywords: ["troubleshooting", "errors", "issues", "problems", "debugging"]
---

# CLI Troubleshooting

Solutions for common issues when using the Codeflash CLI.

---

## Common Issues

<AccordionGroup>
<Accordion title="'Functions outside module-root' Error">
**Problem**: Function not found because the file is outside `module-root`.

**Solution**: Ensure your file is within the `module-root` directory specified in `pyproject.toml`.

```bash
# Check your module-root
grep "module-root" pyproject.toml

# Use the correct path (e.g., if module-root is "src")
codeflash --file src/myfile.py --function my_function --no-pr
```
</Accordion>

<Accordion title="'benchmarks-root must be specified' Error">
**Problem**: Using `--benchmark` without specifying a benchmarks directory.

**Solution**: Either add `benchmarks-root` to `pyproject.toml` or use the flag:

```bash
codeflash --file src/app.py --benchmark --benchmarks-root tests/benchmarks --no-pr
```
</Accordion>

<Accordion title="Replay Test File Not Found">
**Problem**: The replay test filename doesn't match the expected path.

**Solution**: Replay tests include the module path in their name. Check the actual filename:

```bash
# Linux/macOS
ls tests/test_*replay*.py

# Windows
dir tests\test_*replay*.py
```

<Note>
Replay test files are named based on the traced script path. For `src/app.py`, the replay test will be named like `test_srcapp_py__replay_test_0.py`.
</Note>
</Accordion>

<Accordion title="GitHub App Required">
**Problem**: PR creation fails due to a missing GitHub App.

**Solution**: Install the Codeflash GitHub App or use `--no-pr` for local optimization:

```bash
# Local optimization
codeflash --file src/app.py --function main --no-pr

# Or install the GitHub App
# https://github.com/apps/codeflash-ai/installations/select_target
```
</Accordion>

<Accordion title="Module Not Found Errors">
**Problem**: Codeflash can't find your Python modules.

**Solution**:

1. Verify `module-root` is correctly set in `pyproject.toml`
2. Ensure you're running from the project root
3. Check that your Python environment has all dependencies installed

```bash
# Verify module-root
grep module-root pyproject.toml

# Check Python path
python -c "import sys; print(sys.path)"
```
</Accordion>

<Accordion title="Test Generation Fails">
**Problem**: Codeflash can't generate tests for your function.

**Solution**:

1. Ensure your function has a return statement
2. Check that the function is not a property or a class method with special decorators
3. Use `--no-gen-tests` to skip test generation and use existing tests only

```bash
codeflash --file src/app.py --function main --no-gen-tests --no-pr
```
</Accordion>

<Accordion title="Optimization Timeout">
**Problem**: Optimization takes too long or times out.

**Solution**:

1. Use `--verbose` to see what's happening
2. For tracing, use `--timeout` to limit trace duration
3. For large functions, consider breaking them down

```bash
# Limit trace time
codeflash optimize app.py --timeout 30

# See detailed progress
codeflash --file src/app.py --function main --verbose --no-pr
```
</Accordion>
</AccordionGroup>

---

## Getting Help

If you're still experiencing issues:

1. **Check the logs**: Use the `--verbose` flag to see detailed output
2. **Verify setup**: Run `codeflash --verify-setup` to check your installation
3. **Check configuration**: Ensure `pyproject.toml` is correctly configured
4. **View help**: Run `codeflash --help` or `codeflash <command> --help`

---

## Next Steps

<CardGroup cols={2}>
<Card
  title="Setup Commands"
  icon="wrench"
  href="/cli-reference/setup"
>
  Review setup and initialization
</Card>
<Card
  title="Flags Reference"
  icon="list"
  href="/cli-reference/flags"
>
  Complete flag reference
</Card>
</CardGroup>
@ -18,33 +18,13 @@
{
  "tab": "Documentation",
  "groups": [
    {
      "group": "🚀 Quickstart",
      "pages": ["getting-started/local-installation"]
    },
    {
      "group": "🏠 Overview",
      "pages": ["index"]
    },
    {
      "group": "📖 Codeflash CLI",
      "pages": [
        "cli-reference/index",
        "cli-reference/setup",
        "cli-reference/optimization",
        "cli-reference/tracing",
        "cli-reference/flags",
        "cli-reference/troubleshooting"
      ]
    },
    {
      "group": "🛠 IDE Extension",
      "pages": [
        "editor-plugins/vscode/index",
        "editor-plugins/vscode/features",
        "editor-plugins/vscode/configuration",
        "editor-plugins/vscode/troubleshooting"
      ]
      "group": "🚀 Quickstart",
      "pages": ["getting-started/local-installation"]
    },
    {
      "group": "⚡ Optimizing with Codeflash",
@ -62,6 +42,15 @@
        "optimizing-with-codeflash/review-optimizations"
      ]
    },
    {
      "group": "🛠 IDE Extension",
      "pages": [
        "editor-plugins/vscode/index",
        "editor-plugins/vscode/features",
        "editor-plugins/vscode/configuration",
        "editor-plugins/vscode/troubleshooting"
      ]
    },
    {
      "group": "🧠 Core Concepts",
      "pages": [
@ -146,7 +146,4 @@ When configuration issues are detected, the extension displays clear error messa
  <Card title="Project Configuration" icon="file-code" href="/configuration">
    Complete pyproject.toml reference
  </Card>
  <Card title="Codeflash CLI" icon="terminal" href="/cli-reference/index">
    Command-line options
  </Card>
</CardGroup>
@ -204,7 +204,6 @@ The extension works alongside the Codeflash CLI. You can:
- **Use extension for interactive work** — Optimize individual functions as you code
- **Mix both** — The extension picks up CLI results when you return to the editor

For CLI documentation, see the [Codeflash CLI](/cli-reference/index).

---
@ -220,9 +219,6 @@ For CLI documentation, see the [Codeflash CLI](/cli-reference/index).
  <Card title="Troubleshooting" icon="wrench" href="/editor-plugins/vscode/troubleshooting">
    Fix common issues
  </Card>
  <Card title="Codeflash CLI" icon="terminal" href="/cli-reference/index">
    Command-line interface docs
  </Card>
</CardGroup>

---
@ -208,8 +208,5 @@ If you're still experiencing issues:
  <Card title="Configuration" icon="gear" href="/editor-plugins/vscode/configuration">
    Customize extension settings
  </Card>
  <Card title="Codeflash CLI" icon="terminal" href="/cli-reference/index">
    Command-line interface docs
  </Card>
</CardGroup>
@ -8,6 +8,7 @@ def run_test(expected_improvement_pct: int) -> bool:
    config = TestConfig(
        file_path="main.py",
        min_improvement_x=0.1,
        expected_acceptance_reason="concurrency",
        coverage_expectations=[
            CoverageExpectation(
                function_name="retry_with_backoff",
@ -37,6 +37,7 @@ class TestConfig:
    benchmarks_root: Optional[pathlib.Path] = None
    use_worktree: bool = False
    no_gen_tests: bool = False
    expected_acceptance_reason: Optional[str] = None  # "runtime", "throughput", "concurrency"


def clear_directory(directory_path: str | pathlib.Path) -> None:
@ -176,7 +177,7 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
        logging.error("Failed to find performance improvement message")
        return False

    improvement_match = re.search(r"📈 ([\d,]+)% improvement", stdout)
    improvement_match = re.search(r"📈 ([\d,]+)% (?:(\w+) )?improvement", stdout)
    if not improvement_match:
        logging.error("Could not find improvement percentage in output")
        return False
@ -193,6 +194,15 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
        logging.error(f"Performance improvement rate {improvement_x}x not above {config.min_improvement_x}x")
        return False

    if config.expected_acceptance_reason is not None:
        actual_reason = improvement_match.group(2)
        if not actual_reason:
            logging.error("Could not find acceptance reason type in output")
            return False
        if actual_reason != config.expected_acceptance_reason:
            logging.error(f"Expected acceptance reason '{config.expected_acceptance_reason}', got '{actual_reason}'")
            return False

    if config.expected_unit_tests_count is not None:
        # Match the global test discovery message from optimizer.py which counts test invocations
        # Format: "Discovered X existing unit tests and Y replay tests in Z.Zs at /path/to/tests"
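For context, a quick sketch of what the widened regex above accepts: the optional second group captures a reason word (such as the "runtime", "throughput", or "concurrency" values named in `TestConfig`). The sample strings here are hypothetical:

```python
import re

# The updated pattern tolerates an optional reason word before "improvement".
pattern = re.compile(r"📈 ([\d,]+)% (?:(\w+) )?improvement")

print(pattern.search("📈 1,250% improvement").groups())            # ('1,250', None)
print(pattern.search("📈 320% throughput improvement").groups())   # ('320', 'throughput')
print(pattern.search("📈 95% concurrency improvement").groups())   # ('95', 'concurrency')
```
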
304
tests/test_async_concurrency_decorator.py
Normal file
@ -0,0 +1,304 @@
from __future__ import annotations

import asyncio
import os
import sys
import time

import pytest

from codeflash.code_utils.codeflash_wrap_decorator import codeflash_concurrency_async
from codeflash.models.models import ConcurrencyMetrics, TestResults
from codeflash.verification.parse_test_output import parse_concurrency_metrics


@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
class TestConcurrencyAsyncDecorator:
    """Integration tests for codeflash_concurrency_async decorator."""

    @pytest.fixture
    def concurrency_env_setup(self, request):
        """Set up environment variables for concurrency testing."""
        original_env = {}
        test_env = {
            "CODEFLASH_LOOP_INDEX": "1",
            "CODEFLASH_TEST_MODULE": __name__,
            "CODEFLASH_TEST_CLASS": "TestConcurrencyAsyncDecorator",
            "CODEFLASH_TEST_FUNCTION": request.node.name,
            "CODEFLASH_CONCURRENCY_FACTOR": "5",  # Use smaller factor for faster tests
        }

        for key, value in test_env.items():
            original_env[key] = os.environ.get(key)
            os.environ[key] = value

        yield test_env

        for key, original_value in original_env.items():
            if original_value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = original_value

    @pytest.mark.asyncio
    async def test_concurrency_decorator_nonblocking_function(self, concurrency_env_setup, capsys):
        """Test that non-blocking async functions show high concurrency ratio."""

        @codeflash_concurrency_async
        async def nonblocking_sleep(duration: float) -> str:
            await asyncio.sleep(duration)
            return "done"

        result = await nonblocking_sleep(0.01)

        assert result == "done"

        captured = capsys.readouterr()
        output = captured.out

        # Verify the output format
        assert "!@######CONC:" in output
        assert "######@!" in output

        # Parse the output manually to verify format
        lines = [line for line in output.strip().split("\n") if "!@######CONC:" in line]
        assert len(lines) == 1

        line = lines[0]
        # Format: !@######CONC:{test_module}:{test_class}:{test_function}:{function_name}:{loop_index}:{seq_time}:{conc_time}:{factor}######@!
        assert "nonblocking_sleep" in line
        assert ":5######@!" in line  # concurrency factor

        # Extract timing values
        parts = line.replace("!@######CONC:", "").replace("######@!", "").split(":")
        # parts should be: [test_module, test_class, test_function, function_name, loop_index, seq_time, conc_time, factor]
        assert len(parts) == 8

        seq_time = int(parts[5])
        conc_time = int(parts[6])
        factor = int(parts[7])

        assert seq_time > 0
        assert conc_time > 0
        assert factor == 5

        # For non-blocking async, concurrent time should be much less than sequential
        # Sequential runs 5 iterations of 10ms = ~50ms
        # Concurrent runs 5 iterations in parallel = ~10ms
        # So ratio should be around 5 (with some overhead tolerance)
        ratio = seq_time / conc_time if conc_time > 0 else 1.0
        assert ratio > 2.0, f"Non-blocking function should have ratio > 2.0, got {ratio}"

    @pytest.mark.asyncio
    async def test_concurrency_decorator_blocking_function(self, concurrency_env_setup, capsys):
        """Test that blocking functions show low concurrency ratio (~1.0)."""

        @codeflash_concurrency_async
        async def blocking_sleep(duration: float) -> str:
            time.sleep(duration)  # Blocking sleep
            return "done"

        result = await blocking_sleep(0.005)  # 5ms blocking

        assert result == "done"

        captured = capsys.readouterr()
        output = captured.out

        assert "!@######CONC:" in output

        lines = [line for line in output.strip().split("\n") if "!@######CONC:" in line]
        assert len(lines) == 1

        line = lines[0]
        parts = line.replace("!@######CONC:", "").replace("######@!", "").split(":")
        assert len(parts) == 8

        seq_time = int(parts[5])
        conc_time = int(parts[6])

        # For blocking code, sequential and concurrent times should be similar
        # Because time.sleep blocks the entire event loop
        ratio = seq_time / conc_time if conc_time > 0 else 1.0
        # Blocking code should have ratio close to 1.0 (within reasonable tolerance)
        assert ratio < 2.0, f"Blocking function should have ratio < 2.0, got {ratio}"

    @pytest.mark.asyncio
    async def test_concurrency_decorator_with_computation(self, concurrency_env_setup, capsys):
        """Test concurrency with CPU-bound computation."""

        @codeflash_concurrency_async
        async def compute_intensive(n: int) -> int:
            # CPU-bound work (blocked by GIL in concurrent execution)
            total = 0
            for i in range(n):
                total += i * i
            return total

        result = await compute_intensive(10000)

        assert result == sum(i * i for i in range(10000))

        captured = capsys.readouterr()
        output = captured.out

        assert "!@######CONC:" in output
        assert "compute_intensive" in output


@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
class TestParseConcurrencyMetrics:
    """Integration tests for parse_concurrency_metrics function."""

    def test_parse_concurrency_metrics_from_real_output(self):
        """Test parsing concurrency metrics from simulated stdout."""
        # Simulate stdout from codeflash_concurrency_async decorator
        perf_stdout = """Some other output
!@######CONC:test_module:TestClass:test_func:my_async_func:1:50000000:10000000:5######@!
More output here
"""
        test_results = TestResults(
            test_results=[],
            perf_stdout=perf_stdout,
        )

        metrics = parse_concurrency_metrics(test_results, "my_async_func")

        assert metrics is not None
        assert isinstance(metrics, ConcurrencyMetrics)
        assert metrics.sequential_time_ns == 50000000
        assert metrics.concurrent_time_ns == 10000000
        assert metrics.concurrency_factor == 5
        assert metrics.concurrency_ratio == 5.0  # 50M / 10M = 5.0

    def test_parse_concurrency_metrics_multiple_entries(self):
        """Test parsing when multiple concurrency entries exist."""
        perf_stdout = """!@######CONC:test_module:TestClass:test_func:target_func:1:40000000:10000000:5######@!
!@######CONC:test_module:TestClass:test_func:target_func:2:60000000:10000000:5######@!
!@######CONC:test_module:TestClass:test_func:other_func:1:30000000:15000000:5######@!
"""
        test_results = TestResults(
            test_results=[],
            perf_stdout=perf_stdout,
        )

        metrics = parse_concurrency_metrics(test_results, "target_func")

        assert metrics is not None
        # Should average the two entries for target_func
        # (40M + 60M) / 2 = 50M seq, (10M + 10M) / 2 = 10M conc
        assert metrics.sequential_time_ns == 50000000
        assert metrics.concurrent_time_ns == 10000000
        assert metrics.concurrency_ratio == 5.0

    def test_parse_concurrency_metrics_no_match(self):
        """Test parsing when function name doesn't match."""
        perf_stdout = """!@######CONC:test_module:TestClass:test_func:other_func:1:50000000:10000000:5######@!
"""
        test_results = TestResults(
            test_results=[],
            perf_stdout=perf_stdout,
        )

        metrics = parse_concurrency_metrics(test_results, "nonexistent_func")

        assert metrics is None

    def test_parse_concurrency_metrics_empty_stdout(self):
        """Test parsing with empty stdout."""
        test_results = TestResults(
            test_results=[],
            perf_stdout="",
        )

        metrics = parse_concurrency_metrics(test_results, "any_func")

        assert metrics is None

    def test_parse_concurrency_metrics_none_stdout(self):
        """Test parsing with None stdout."""
        test_results = TestResults(
            test_results=[],
            perf_stdout=None,
        )

        metrics = parse_concurrency_metrics(test_results, "any_func")

        assert metrics is None


@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
class TestConcurrencyRatioComparison:
    """Test comparing blocking vs non-blocking concurrency ratios."""

    @pytest.fixture
    def comparison_env_setup(self, request):
        """Set up environment variables for comparison testing."""
        original_env = {}
        test_env = {
            "CODEFLASH_LOOP_INDEX": "1",
            "CODEFLASH_TEST_MODULE": __name__,
            "CODEFLASH_TEST_CLASS": "TestConcurrencyRatioComparison",
            "CODEFLASH_TEST_FUNCTION": request.node.name,
            "CODEFLASH_CONCURRENCY_FACTOR": "10",
        }

        for key, value in test_env.items():
            original_env[key] = os.environ.get(key)
            os.environ[key] = value

        yield test_env

        for key, original_value in original_env.items():
            if original_value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = original_value

    @pytest.mark.asyncio
    async def test_blocking_vs_nonblocking_comparison(self, comparison_env_setup, capsys):
        """Compare concurrency ratios between blocking and non-blocking implementations."""

        @codeflash_concurrency_async
        async def blocking_impl() -> str:
            time.sleep(0.002)  # 2ms blocking
            return "blocking"

        @codeflash_concurrency_async
        async def nonblocking_impl() -> str:
            await asyncio.sleep(0.002)  # 2ms non-blocking
            return "nonblocking"

        # Run blocking version
        await blocking_impl()
        blocking_output = capsys.readouterr().out

        # Run non-blocking version
        await nonblocking_impl()
        nonblocking_output = capsys.readouterr().out

        # Parse blocking metrics
        blocking_line = [l for l in blocking_output.split("\n") if "!@######CONC:" in l][0]
        blocking_parts = blocking_line.replace("!@######CONC:", "").replace("######@!", "").split(":")
        blocking_seq = int(blocking_parts[5])
        blocking_conc = int(blocking_parts[6])
        blocking_ratio = blocking_seq / blocking_conc if blocking_conc > 0 else 1.0

        # Parse non-blocking metrics
        nonblocking_line = [l for l in nonblocking_output.split("\n") if "!@######CONC:" in l][0]
        nonblocking_parts = nonblocking_line.replace("!@######CONC:", "").replace("######@!", "").split(":")
        nonblocking_seq = int(nonblocking_parts[5])
        nonblocking_conc = int(nonblocking_parts[6])
        nonblocking_ratio = nonblocking_seq / nonblocking_conc if nonblocking_conc > 0 else 1.0

        # Non-blocking should have significantly higher concurrency ratio
        assert nonblocking_ratio > blocking_ratio, (
            f"Non-blocking ratio ({nonblocking_ratio:.2f}) should be greater than "
            f"blocking ratio ({blocking_ratio:.2f})"
        )

        # The difference should be substantial (non-blocking should be at least 2x better)
        ratio_improvement = nonblocking_ratio / blocking_ratio if blocking_ratio > 0 else 0
        assert ratio_improvement > 2.0, (
            f"Non-blocking should show >2x improvement in concurrency ratio, got {ratio_improvement:.2f}x"
        )
File diff suppressed because it is too large
@ -2119,7 +2119,6 @@ print("Hello world")
    expected_code = """import numpy as np

a = 6

if 2<3:
    a=4
else:
@ -1602,7 +1602,94 @@ def calculate_portfolio_metrics(
|
|||
# now the test should match and no diffs should be found
|
||||
assert len(diffs) == 0
|
||||
assert matched
|
||||
|
||||
|
||||
finally:
|
||||
test_path.unlink(missing_ok=True)
|
||||
fto_file_path.unlink(missing_ok=True)
|
||||
fto_file_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_codeflash_capture_with_slots_class() -> None:
    """Test that codeflash_capture works with classes that use __slots__ instead of __dict__."""
    test_code = """
from code_to_optimize.tests.pytest.sample_code import SlotsClass
import unittest

def test_slots_class():
    obj = SlotsClass(10, "test")
    assert obj.x == 10
    assert obj.y == "test"
"""
    test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
    tmp_dir_path = get_run_tmp_file(Path("test_return_values"))
    sample_code = f"""
from codeflash.verification.codeflash_capture import codeflash_capture

class SlotsClass:
    __slots__ = ('x', 'y')

    @codeflash_capture(function_name="SlotsClass.__init__", tmp_dir_path="{tmp_dir_path.as_posix()}", tests_root="{test_dir.as_posix()}")
    def __init__(self, x, y):
        self.x = x
        self.y = y
"""
    test_file_name = "test_slots_class_temp.py"
    test_path = test_dir / test_file_name
    test_path_perf = test_dir / "test_slots_class_temp_perf.py"

    tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
    project_root_path = (Path(__file__).parent / "..").resolve()
    sample_code_path = test_dir / "sample_code.py"

    try:
        with test_path.open("w") as f:
            f.write(test_code)
        with sample_code_path.open("w") as f:
            f.write(sample_code)

        test_env = os.environ.copy()
        test_env["CODEFLASH_TEST_ITERATION"] = "0"
        test_env["CODEFLASH_LOOP_INDEX"] = "1"
        test_type = TestType.EXISTING_UNIT_TEST
        test_config = TestConfig(
            tests_root=tests_root,
            tests_project_rootdir=project_root_path,
            project_root_path=project_root_path,
            test_framework="pytest",
            pytest_cmd="pytest",
        )
        fto = FunctionToOptimize(
            function_name="__init__",
            file_path=sample_code_path,
            parents=[FunctionParent(name="SlotsClass", type="ClassDef")],
        )
        func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
        func_optimizer.test_files = TestFiles(
            test_files=[
                TestFile(
                    instrumented_behavior_file_path=test_path,
                    test_type=test_type,
                    original_file_path=test_path,
                    benchmarking_file_path=test_path_perf,
                )
            ]
        )
        test_results, coverage_data = func_optimizer.run_and_parse_tests(
            testing_type=TestingMode.BEHAVIOR,
            test_env=test_env,
            test_files=func_optimizer.test_files,
            optimization_iteration=0,
            pytest_min_loops=1,
            pytest_max_loops=1,
            testing_time=0.1,
        )

        # Test should pass and capture the slots values
        assert len(test_results) == 1
        assert test_results[0].did_pass
        # The return value should contain the slot values
        assert test_results[0].return_value[0]["x"] == 10
        assert test_results[0].return_value[0]["y"] == "test"

    finally:
        test_path.unlink(missing_ok=True)
        sample_code_path.unlink(missing_ok=True)
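The assertions above read the captured state as a plain dict even though SlotsClass defines no __dict__; a capture helper that supports both layouts could snapshot instances like this (a hypothetical helper, not codeflash's actual implementation):

def snapshot_instance_state(obj):
    # Prefer __dict__ when present; otherwise walk the declared slots.
    if hasattr(obj, "__dict__"):
        return dict(obj.__dict__)
    return {
        slot: getattr(obj, slot)
        for slot in getattr(type(obj), "__slots__", ())
        if hasattr(obj, slot)  # a slot may be declared but never assigned
    }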
|
||||
|
|
@ -17,7 +17,14 @@ import pytest
|
|||
|
||||
from codeflash.either import Failure, Success
|
||||
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType
|
||||
from codeflash.verification.comparator import comparator, _extract_exception_from_message, _get_wrapped_exception
|
||||
from codeflash.verification.comparator import (
|
||||
PYTEST_TEMP_PATH_PATTERN,
|
||||
_extract_exception_from_message,
|
||||
_get_wrapped_exception,
|
||||
_is_temp_path,
|
||||
_normalize_temp_path,
|
||||
comparator,
|
||||
)
|
||||
from codeflash.verification.equivalence import compare_test_results
|
||||
|
||||
|
||||
|
|
@@ -2309,6 +2316,77 @@ def test_dict_views() -> None:
    assert not comparator(d.items(), [("a", 1), ("b", 2)])


def test_mappingproxy() -> None:
    """Test comparator support for types.MappingProxyType (read-only dict view)."""
    import types

    # Basic equality
    mp1 = types.MappingProxyType({"a": 1, "b": 2, "c": 3})
    mp2 = types.MappingProxyType({"a": 1, "b": 2, "c": 3})
    assert comparator(mp1, mp2)

    # Different values
    mp3 = types.MappingProxyType({"a": 1, "b": 2, "c": 4})
    assert not comparator(mp1, mp3)

    # Different keys
    mp4 = types.MappingProxyType({"a": 1, "b": 2, "d": 3})
    assert not comparator(mp1, mp4)

    # Different length
    mp5 = types.MappingProxyType({"a": 1, "b": 2})
    assert not comparator(mp1, mp5)

    # Order doesn't matter (like dict)
    mp6 = types.MappingProxyType({"c": 3, "a": 1, "b": 2})
    assert comparator(mp1, mp6)

    # Empty mappingproxy
    empty1 = types.MappingProxyType({})
    empty2 = types.MappingProxyType({})
    assert comparator(empty1, empty2)

    # Nested values
    nested1 = types.MappingProxyType({"a": [1, 2, 3], "b": {"x": 1}})
    nested2 = types.MappingProxyType({"a": [1, 2, 3], "b": {"x": 1}})
    nested3 = types.MappingProxyType({"a": [1, 2, 4], "b": {"x": 1}})
    assert comparator(nested1, nested2)
    assert not comparator(nested1, nested3)

    # mappingproxy is not equal to dict (different types)
    d = {"a": 1, "b": 2}
    mp = types.MappingProxyType({"a": 1, "b": 2})
    assert not comparator(mp, d)
    assert not comparator(d, mp)

    # Verify class __dict__ is indeed a mappingproxy
    class MyClass:
        x = 1
        y = 2

    assert isinstance(MyClass.__dict__, types.MappingProxyType)


def test_mappingproxy_superset() -> None:
    """Test comparator superset_obj support for mappingproxy."""
    import types

    mp1 = types.MappingProxyType({"a": 1, "b": 2})
    mp2 = types.MappingProxyType({"a": 1, "b": 2, "c": 3})

    # mp2 is a superset of mp1
    assert comparator(mp1, mp2, superset_obj=True)
    # mp1 is not a superset of mp2
    assert not comparator(mp2, mp1, superset_obj=True)

    # Same mappingproxy with superset_obj=True
    assert comparator(mp1, mp1, superset_obj=True)

    # Different values even with superset
    mp3 = types.MappingProxyType({"a": 1, "b": 99, "c": 3})
    assert not comparator(mp1, mp3, superset_obj=True)
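Taken together, these assertions pin down the mappingproxy rule: exact type match (a proxy never equals a plain dict), order-insensitive key/value comparison, and key-superset semantics under superset_obj. A sketch of that rule, with `compare` standing in for the real recursive comparator (an illustration, not the library's code):

import types

def mappingproxy_rule(orig, new, compare, superset_obj=False):
    if type(orig) is not types.MappingProxyType or type(new) is not types.MappingProxyType:
        return False  # proxy vs plain dict is a type mismatch
    if not superset_obj and len(orig) != len(new):
        return False
    # Every original key must exist in new with a recursively equal value.
    return all(k in new and compare(v, new[k]) for k, v in orig.items())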


def test_tensorflow_tensor() -> None:
    """Test comparator support for TensorFlow Tensor objects."""
    try:
@@ -2911,16 +2989,378 @@ def test_numpy_dtypes() -> None:
    assert not comparator(dtypes.Int32DType(), np.dtype('float32'))


def test_numpy_extended_precision_types() -> None:
    """Test comparator for numpy extended precision types like clongdouble."""
    try:
        import numpy as np
    except ImportError:
        pytest.skip("numpy not available")

    # Test np.clongdouble (extended precision complex)
    c1 = np.clongdouble(1 + 2j)
    c2 = np.clongdouble(1 + 2j)
    c3 = np.clongdouble(1 + 3j)
    assert comparator(c1, c2)
    assert not comparator(c1, c3)

    # Test np.longdouble (extended precision float)
    l1 = np.longdouble(1.5)
    l2 = np.longdouble(1.5)
    l3 = np.longdouble(2.5)
    assert comparator(l1, l2)
    assert not comparator(l1, l3)

    # Test NaN handling for extended precision complex
    nan_c1 = np.clongdouble(complex(np.nan, 2))
    nan_c2 = np.clongdouble(complex(np.nan, 2))
    assert comparator(nan_c1, nan_c2)

    # Test NaN handling for extended precision float
    nan_l1 = np.longdouble(np.nan)
    nan_l2 = np.longdouble(np.nan)
    assert comparator(nan_l1, nan_l2)
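The NaN cases make the intended scalar rule explicit: same-type values compare equal when both carry NaN, even though NaN != NaN under IEEE semantics. A minimal sketch of that rule (deliberately simplified; it treats any NaN-carrying pair of the same type as equal, which may be looser than the real comparator):

import numpy as np

def scalar_equal_sketch(a, b):
    if type(a) is not type(b):
        return False
    if bool(np.isnan(a)) and bool(np.isnan(b)):
        return True
    return bool(a == b)

assert scalar_equal_sketch(np.longdouble("nan"), np.longdouble("nan"))
assert not scalar_equal_sketch(np.longdouble(1.5), np.longdouble(2.5))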


def test_numpy_typing_types() -> None:
    """Test comparator for numpy.typing types like NDArray type aliases."""
    try:
        import numpy as np
        import numpy.typing as npt
    except ImportError:
        pytest.skip("numpy or numpy.typing not available")

    # Test NDArray type alias comparisons
    arr_type1 = npt.NDArray[np.float64]
    arr_type2 = npt.NDArray[np.float64]
    arr_type3 = npt.NDArray[np.int64]
    assert comparator(arr_type1, arr_type2)
    assert not comparator(arr_type1, arr_type3)

    # Test NBitBase (if it can be instantiated)
    try:
        nbit1 = npt.NBitBase()
        nbit2 = npt.NBitBase()
        # NBitBase instances with empty __dict__ should compare as equal
        assert comparator(nbit1, nbit2)
        # Also test with superset_obj=True
        assert comparator(nbit1, nbit2, superset_obj=True)
    except TypeError:
        # NBitBase may not be instantiable in all numpy versions
        pass
def test_numpy_typing_superset_obj() -> None:
    """Test comparator with superset_obj=True for numpy types."""
    try:
        import numpy as np
        import numpy.typing as npt
    except ImportError:
        pytest.skip("numpy or numpy.typing not available")

    # Test numpy arrays with object dtype containing dicts (superset scenario)
    a1 = np.array([{'a': 1}], dtype=object)
    a2 = np.array([{'a': 1, 'b': 2}], dtype=object)  # superset
    assert comparator(a1, a2, superset_obj=True)
    assert not comparator(a1, a2, superset_obj=False)

    # Test extended precision types with superset_obj=True
    c1 = np.clongdouble(1 + 2j)
    c2 = np.clongdouble(1 + 2j)
    assert comparator(c1, c2, superset_obj=True)

    l1 = np.longdouble(1.5)
    l2 = np.longdouble(1.5)
    assert comparator(l1, l2, superset_obj=True)

    # Test NDArray type alias with superset_obj=True
    arr_type1 = npt.NDArray[np.float64]
    arr_type2 = npt.NDArray[np.float64]
    assert comparator(arr_type1, arr_type2, superset_obj=True)

    # Test numpy structured arrays (np.void) with superset_obj=True
    dt = np.dtype([('name', 'S10'), ('age', np.int32)])
    a_struct = np.array([('Alice', 25)], dtype=dt)
    b_struct = np.array([('Alice', 25)], dtype=dt)
    assert comparator(a_struct[0], b_struct[0], superset_obj=True)

    # Test numpy random generators with superset_obj=True
    rng1 = np.random.default_rng(seed=42)
    rng2 = np.random.default_rng(seed=42)
    assert comparator(rng1, rng2, superset_obj=True)

    rs1 = np.random.RandomState(seed=42)
    rs2 = np.random.RandomState(seed=42)
    assert comparator(rs1, rs2, superset_obj=True)


def test_numba_typed_list() -> None:
    """Test comparator for numba.typed.List."""
    try:
        import numba
        from numba.typed import List as NumbaList
    except ImportError:
        pytest.skip("numba not available")

    # Test equal lists
    a = NumbaList([1, 2, 3])
    b = NumbaList([1, 2, 3])
    assert comparator(a, b)

    # Test different values
    c = NumbaList([1, 2, 4])
    assert not comparator(a, c)

    # Test different lengths
    d = NumbaList([1, 2, 3, 4])
    assert not comparator(a, d)

    # Test empty lists
    e = NumbaList.empty_list(item_type=numba.int64)
    f = NumbaList.empty_list(item_type=numba.int64)
    assert comparator(e, f)

    # Test nested values (floats)
    g = NumbaList([1.0, 2.0, 3.0])
    h = NumbaList([1.0, 2.0, 3.0])
    assert comparator(g, h)

    i = NumbaList([1.0, 2.0, 4.0])
    assert not comparator(g, i)


def test_numba_typed_dict() -> None:
    """Test comparator for numba.typed.Dict."""
    try:
        import numba
        from numba.typed import Dict as NumbaDict
    except ImportError:
        pytest.skip("numba not available")

    # Test equal dicts
    a = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    a["x"] = 1
    a["y"] = 2

    b = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    b["x"] = 1
    b["y"] = 2
    assert comparator(a, b)

    # Test different values
    c = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    c["x"] = 1
    c["y"] = 3
    assert not comparator(a, c)

    # Test different keys
    d = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    d["x"] = 1
    d["z"] = 2
    assert not comparator(a, d)

    # Test different lengths
    e = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    e["x"] = 1
    assert not comparator(a, e)

    # Test empty dicts
    f = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    g = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    assert comparator(f, g)


def test_numba_types() -> None:
    """Test comparator for numba type objects."""
    try:
        import numba
        from numba import types
    except ImportError:
        pytest.skip("numba not available")

    # Test basic numeric types from numba module
    assert comparator(numba.int64, numba.int64)
    assert comparator(numba.float64, numba.float64)
    assert comparator(numba.int32, numba.int32)
    assert comparator(numba.float32, numba.float32)

    # Test basic numeric types from numba.types module
    assert comparator(types.int64, types.int64)
    assert comparator(types.float64, types.float64)
    assert comparator(types.int8, types.int8)
    assert comparator(types.int16, types.int16)
    assert comparator(types.uint8, types.uint8)
    assert comparator(types.uint16, types.uint16)
    assert comparator(types.uint32, types.uint32)
    assert comparator(types.uint64, types.uint64)
    assert comparator(types.complex64, types.complex64)
    assert comparator(types.complex128, types.complex128)

    # Test different types
    assert not comparator(numba.int64, numba.float64)
    assert not comparator(numba.int32, numba.int64)
    assert not comparator(numba.float32, numba.float64)
    assert not comparator(types.int8, types.int16)
    assert not comparator(types.uint32, types.int32)
    assert not comparator(types.complex64, types.complex128)

    # Test boolean type
    assert comparator(numba.boolean, numba.boolean)
    assert comparator(types.boolean, types.boolean)
    assert not comparator(numba.boolean, numba.int64)

    # Test special types
    assert comparator(types.none, types.none)
    assert comparator(types.void, types.void)
    assert comparator(types.pyobject, types.pyobject)
    assert comparator(types.unicode_type, types.unicode_type)
    # Note: types.none and types.void are the same object in numba
    assert comparator(types.none, types.void)
    assert not comparator(types.unicode_type, types.pyobject)
    assert not comparator(types.none, types.int64)

    # Test array types
    arr_type1 = types.Array(numba.float64, 1, 'C')
    arr_type2 = types.Array(numba.float64, 1, 'C')
    arr_type3 = types.Array(numba.float64, 2, 'C')
    arr_type4 = types.Array(numba.int64, 1, 'C')
    arr_type5 = types.Array(numba.float64, 1, 'F')  # Fortran order

    assert comparator(arr_type1, arr_type2)
    assert not comparator(arr_type1, arr_type3)  # different ndim
    assert not comparator(arr_type1, arr_type4)  # different dtype
    assert not comparator(arr_type1, arr_type5)  # different layout

    # Test tuple types
    tuple_type1 = types.UniTuple(types.int64, 3)
    tuple_type2 = types.UniTuple(types.int64, 3)
    tuple_type3 = types.UniTuple(types.int64, 4)
    tuple_type4 = types.UniTuple(types.float64, 3)

    assert comparator(tuple_type1, tuple_type2)
    assert not comparator(tuple_type1, tuple_type3)  # different count
    assert not comparator(tuple_type1, tuple_type4)  # different dtype

    # Test heterogeneous tuple types
    hetero_tuple1 = types.Tuple([types.int64, types.float64])
    hetero_tuple2 = types.Tuple([types.int64, types.float64])
    hetero_tuple3 = types.Tuple([types.int64, types.int64])

    assert comparator(hetero_tuple1, hetero_tuple2)
    assert not comparator(hetero_tuple1, hetero_tuple3)

    # Test ListType and DictType
    list_type1 = types.ListType(types.int64)
    list_type2 = types.ListType(types.int64)
    list_type3 = types.ListType(types.float64)

    assert comparator(list_type1, list_type2)
    assert not comparator(list_type1, list_type3)

    dict_type1 = types.DictType(types.unicode_type, types.int64)
    dict_type2 = types.DictType(types.unicode_type, types.int64)
    dict_type3 = types.DictType(types.unicode_type, types.float64)
    dict_type4 = types.DictType(types.int64, types.int64)

    assert comparator(dict_type1, dict_type2)
    assert not comparator(dict_type1, dict_type3)  # different value type
    assert not comparator(dict_type1, dict_type4)  # different key type


def test_numba_jit_functions() -> None:
    """Test comparator for numba JIT-compiled functions."""
    try:
        from numba import jit
    except ImportError:
        pytest.skip("numba not available")

    @jit(nopython=True)
    def add(x, y):
        return x + y

    @jit(nopython=True)
    def add2(x, y):
        return x + y

    @jit(nopython=True)
    def multiply(x, y):
        return x * y

    # Compile the functions by calling them
    add(1, 2)
    add2(1, 2)
    multiply(1, 2)

    # Same function should compare equal to itself
    assert comparator(add, add)

    # Different functions (even with same code) should not compare equal
    # since they are distinct function objects
    assert not comparator(add, add2)

    # Different functions with different code should not compare equal
    assert not comparator(add, multiply)


def test_numba_superset_obj() -> None:
    """Test comparator for numba types with superset_obj=True."""
    try:
        import numba
        from numba.typed import Dict as NumbaDict
        from numba.typed import List as NumbaList
    except ImportError:
        pytest.skip("numba not available")

    # Test NumbaDict with superset_obj=True
    orig_dict = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    orig_dict["x"] = 1
    orig_dict["y"] = 2

    # New dict with same keys - should pass
    new_dict_same = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    new_dict_same["x"] = 1
    new_dict_same["y"] = 2
    assert comparator(orig_dict, new_dict_same, superset_obj=True)

    # New dict with extra keys - should pass with superset_obj=True
    new_dict_superset = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    new_dict_superset["x"] = 1
    new_dict_superset["y"] = 2
    new_dict_superset["z"] = 3
    assert comparator(orig_dict, new_dict_superset, superset_obj=True)
    # But should fail with superset_obj=False
    assert not comparator(orig_dict, new_dict_superset, superset_obj=False)

    # New dict missing keys - should fail even with superset_obj=True
    new_dict_subset = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    new_dict_subset["x"] = 1
    assert not comparator(orig_dict, new_dict_subset, superset_obj=True)

    # New dict with different values - should fail
    new_dict_diff = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    new_dict_diff["x"] = 1
    new_dict_diff["y"] = 99
    assert not comparator(orig_dict, new_dict_diff, superset_obj=True)

    # Test NumbaList with superset_obj=True (lists don't support superset semantics)
    orig_list = NumbaList([1, 2, 3])
    new_list_same = NumbaList([1, 2, 3])
    new_list_longer = NumbaList([1, 2, 3, 4])

    assert comparator(orig_list, new_list_same, superset_obj=True)
    # Lists must have same length regardless of superset_obj
    assert not comparator(orig_list, new_list_longer, superset_obj=True)

    # Test empty dict with superset_obj=True
    empty_orig = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    non_empty_new = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
    non_empty_new["a"] = 1
    # Empty orig should match any superset
    assert comparator(empty_orig, non_empty_new, superset_obj=True)
    assert not comparator(empty_orig, non_empty_new, superset_obj=False)


# =============================================================================
# Tests for pytest temp path normalization (lines 28-69 in comparator.py)
# =============================================================================

from codeflash.verification.comparator import (
    PYTEST_TEMP_PATH_PATTERN,
    _is_temp_path,
    _normalize_temp_path,
)


class TestIsTempPath:
    """Tests for the _is_temp_path() function."""
@@ -5,6 +5,7 @@ from unittest.mock import Mock
 from codeflash.code_utils.env_utils import get_pr_number
 from codeflash.models.models import (
     CodeOptimizationContext,
+    ConcurrencyMetrics,
     CoverageData,
     CoverageStatus,
     FunctionCoverage,
@@ -15,12 +16,14 @@ from codeflash.models.models import (
     TestType,
 )
 from codeflash.result.critic import (
+    concurrency_gain,
     coverage_critic,
     performance_gain,
     quantity_of_tests_critic,
     speedup_critic,
     throughput_gain,
 )
+from codeflash.verification.parse_test_output import parse_concurrency_metrics


 def test_performance_gain() -> None:
@@ -569,3 +572,238 @@ def test_speedup_critic_with_async_throughput() -> None:
        best_throughput_until_now=None,
        disable_gh_action_noise=True
    )


def test_concurrency_gain() -> None:
    """Test concurrency_gain calculation."""
    # Test basic concurrency improvement (blocking -> non-blocking)
    original = ConcurrencyMetrics(
        sequential_time_ns=10_000_000,  # 10ms
        concurrent_time_ns=10_000_000,  # 10ms (no speedup - blocking)
        concurrency_factor=10,
        concurrency_ratio=1.0,  # sequential/concurrent = 1.0
    )
    optimized = ConcurrencyMetrics(
        sequential_time_ns=10_000_000,  # 10ms
        concurrent_time_ns=1_000_000,  # 1ms (10x speedup - non-blocking)
        concurrency_factor=10,
        concurrency_ratio=10.0,  # sequential/concurrent = 10.0
    )
    # 900% improvement: (10 - 1) / 1 = 9.0
    assert concurrency_gain(original, optimized) == 9.0

    # Test no improvement
    same = ConcurrencyMetrics(
        sequential_time_ns=10_000_000,
        concurrent_time_ns=10_000_000,
        concurrency_factor=10,
        concurrency_ratio=1.0,
    )
    assert concurrency_gain(original, same) == 0.0

    # Test slight improvement
    slightly_better = ConcurrencyMetrics(
        sequential_time_ns=10_000_000,
        concurrent_time_ns=8_000_000,
        concurrency_factor=10,
        concurrency_ratio=1.25,
    )
    # 25% improvement: (1.25 - 1.0) / 1.0 = 0.25
    assert concurrency_gain(original, slightly_better) == 0.25

    # Test zero original ratio (edge case)
    zero_ratio = ConcurrencyMetrics(
        sequential_time_ns=0,
        concurrent_time_ns=1_000_000,
        concurrency_factor=10,
        concurrency_ratio=0.0,
    )
    assert concurrency_gain(zero_ratio, optimized) == 0.0
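For reference, these cases pin down a relative-gain formula along the following lines (a sketch consistent with the assertions, assumed rather than copied from codeflash.result.critic):

# Hypothetical reimplementation matching the cases above.
def concurrency_gain_sketch(original, optimized) -> float:
    if original.concurrency_ratio <= 0:
        return 0.0  # the zero-ratio edge case guards against division by zero
    return (optimized.concurrency_ratio - original.concurrency_ratio) / original.concurrency_ratio

# (10.0 - 1.0) / 1.0 == 9.0 and (1.25 - 1.0) / 1.0 == 0.25, as asserted above.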


def test_speedup_critic_with_concurrency_metrics() -> None:
    """Test speedup_critic with concurrency metrics evaluation."""
    original_code_runtime = 10000  # 10 microseconds
    original_async_throughput = 100

    # Original concurrency metrics (blocking code - ratio ~= 1.0)
    original_concurrency = ConcurrencyMetrics(
        sequential_time_ns=10_000_000,
        concurrent_time_ns=10_000_000,
        concurrency_factor=10,
        concurrency_ratio=1.0,
    )

    # Test case 1: Concurrency improves significantly (blocking -> non-blocking)
    candidate_result = OptimizedCandidateResult(
        max_loop_count=5,
        best_test_runtime=10000,  # Same runtime
        behavior_test_results=TestResults(),
        benchmarking_test_results=TestResults(),
        optimization_candidate_index=0,
        total_candidate_timing=10000,
        async_throughput=100,  # Same throughput
        concurrency_metrics=ConcurrencyMetrics(
            sequential_time_ns=10_000_000,
            concurrent_time_ns=1_000_000,  # 10x faster concurrent execution
            concurrency_factor=10,
            concurrency_ratio=10.0,  # 900% improvement
        ),
    )

    # Should pass due to concurrency improvement even though runtime/throughput unchanged
    assert speedup_critic(
        candidate_result=candidate_result,
        original_code_runtime=original_code_runtime,
        best_runtime_until_now=None,
        original_async_throughput=original_async_throughput,
        best_throughput_until_now=None,
        original_concurrency_metrics=original_concurrency,
        best_concurrency_ratio_until_now=None,
        disable_gh_action_noise=True,
    )

    # Test case 2: No concurrency improvement (should fall back to other metrics)
    candidate_result_no_conc = OptimizedCandidateResult(
        max_loop_count=5,
        best_test_runtime=8000,  # 20% runtime improvement
        behavior_test_results=TestResults(),
        benchmarking_test_results=TestResults(),
        optimization_candidate_index=0,
        total_candidate_timing=8000,
        async_throughput=100,
        concurrency_metrics=ConcurrencyMetrics(
            sequential_time_ns=10_000_000,
            concurrent_time_ns=10_000_000,
            concurrency_factor=10,
            concurrency_ratio=1.0,  # No improvement
        ),
    )

    # Should pass due to runtime improvement
    assert speedup_critic(
        candidate_result=candidate_result_no_conc,
        original_code_runtime=original_code_runtime,
        best_runtime_until_now=None,
        original_async_throughput=original_async_throughput,
        best_throughput_until_now=None,
        original_concurrency_metrics=original_concurrency,
        best_concurrency_ratio_until_now=None,
        disable_gh_action_noise=True,
    )

    # Test case 3: Concurrency below threshold (20% required)
    candidate_result_below_threshold = OptimizedCandidateResult(
        max_loop_count=5,
        best_test_runtime=10000,  # Same runtime
        behavior_test_results=TestResults(),
        benchmarking_test_results=TestResults(),
        optimization_candidate_index=0,
        total_candidate_timing=10000,
        async_throughput=100,  # Same throughput
        concurrency_metrics=ConcurrencyMetrics(
            sequential_time_ns=10_000_000,
            concurrent_time_ns=9_000_000,  # Only 11% improvement
            concurrency_factor=10,
            concurrency_ratio=1.11,
        ),
    )

    # Should fail - no metric improves enough
    assert not speedup_critic(
        candidate_result=candidate_result_below_threshold,
        original_code_runtime=original_code_runtime,
        best_runtime_until_now=None,
        original_async_throughput=original_async_throughput,
        best_throughput_until_now=None,
        original_concurrency_metrics=original_concurrency,
        best_concurrency_ratio_until_now=None,
        disable_gh_action_noise=True,
    )

    # Test case 4: best_concurrency_ratio_until_now comparison
    candidate_result_good = OptimizedCandidateResult(
        max_loop_count=5,
        best_test_runtime=10000,
        behavior_test_results=TestResults(),
        benchmarking_test_results=TestResults(),
        optimization_candidate_index=0,
        total_candidate_timing=10000,
        async_throughput=100,
        concurrency_metrics=ConcurrencyMetrics(
            sequential_time_ns=10_000_000,
            concurrent_time_ns=2_000_000,
            concurrency_factor=10,
            concurrency_ratio=5.0,
        ),
    )

    # Should fail when there's a better concurrency ratio already
    assert not speedup_critic(
        candidate_result=candidate_result_good,
        original_code_runtime=original_code_runtime,
        best_runtime_until_now=None,
        original_async_throughput=original_async_throughput,
        best_throughput_until_now=None,
        original_concurrency_metrics=original_concurrency,
        best_concurrency_ratio_until_now=10.0,  # Better ratio already exists
        disable_gh_action_noise=True,
    )
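Taken together, the four cases imply an acceptance rule roughly like the following (a hedged sketch; the names, the exact thresholds, and the check ordering are inferred from the tests, not read from critic.py):

MIN_CONCURRENCY_GAIN = 0.20  # case 3: an 11% ratio gain is rejected

def accepts_sketch(runtime_gain, conc_gain, conc_ratio, best_conc_ratio_until_now):
    # Case 4: a candidate with a better concurrency ratio already exists.
    if best_conc_ratio_until_now is not None and conc_ratio <= best_conc_ratio_until_now:
        return False
    # Case 1: a large enough concurrency gain is sufficient on its own.
    if conc_gain > MIN_CONCURRENCY_GAIN:
        return True
    # Case 2: otherwise fall back to runtime (a 20% speedup passes here);
    # the real critic also weighs throughput and a noise floor.
    return runtime_gain >= 0.20  # placeholder threshold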


def test_concurrency_ratio_display_formatting() -> None:
    orig_ratio = 0.05
    cand_ratio = 0.15
    conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
    display_string = f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)"
    assert display_string == "Concurrency ratio: 0.05x → 0.15x (+200.0%)"

    orig_ratio = 1.0
    cand_ratio = 10.0
    conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
    display_string = f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)"
    assert display_string == "Concurrency ratio: 1.00x → 10.00x (+900.0%)"

    orig_ratio = 0.01
    cand_ratio = 0.03
    conc_gain = ((cand_ratio - orig_ratio) / orig_ratio * 100) if orig_ratio > 0 else 0
    display_string = f"Concurrency ratio: {orig_ratio:.2f}x → {cand_ratio:.2f}x ({conc_gain:+.1f}%)"
    assert display_string == "Concurrency ratio: 0.01x → 0.03x (+200.0%)"


def test_parse_concurrency_metrics() -> None:
    """Test parse_concurrency_metrics function."""
    # Test with valid concurrency output
    stdout = (
        "!@######CONC:test_module:TestClass:test_func:my_function:0:10000000:1000000:10######@!\n"
        "!@######CONC:test_module:TestClass:test_func:my_function:1:10000000:1000000:10######@!\n"
    )
    test_results = TestResults(perf_stdout=stdout)

    metrics = parse_concurrency_metrics(test_results, "my_function")
    assert metrics is not None
    assert metrics.sequential_time_ns == 10_000_000  # Average of both matches
    assert metrics.concurrent_time_ns == 1_000_000
    assert metrics.concurrency_factor == 10
    assert metrics.concurrency_ratio == 10.0  # 10000000 / 1000000

    # Test with no matching function
    metrics_wrong_func = parse_concurrency_metrics(test_results, "other_function")
    assert metrics_wrong_func is None

    # Test with empty stdout
    empty_results = TestResults(perf_stdout="")
    metrics_empty = parse_concurrency_metrics(empty_results, "my_function")
    assert metrics_empty is None

    # Test with None stdout
    none_results = TestResults(perf_stdout=None)
    metrics_none = parse_concurrency_metrics(none_results, "my_function")
    assert metrics_none is None

    # Test with no class name
    stdout_no_class = "!@######CONC:test_module::test_func:my_function:0:5000000:2500000:10######@!\n"
    test_results_no_class = TestResults(perf_stdout=stdout_no_class)
    metrics_no_class = parse_concurrency_metrics(test_results_no_class, "my_function")
    assert metrics_no_class is not None
    assert metrics_no_class.concurrency_ratio == 2.0  # 5000000 / 2500000
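The behavior pinned down here suggests an aggregation like the following (an illustrative sketch, not the code in codeflash.verification.parse_test_output): collect every CONC line naming the target function, average the sequential and concurrent timings across loop iterations, derive the ratio, and return None when there is nothing to parse.

def aggregate_conc_lines_sketch(perf_stdout, function_name):
    # Returns (sequential_ns, concurrent_ns, factor, ratio) or None.
    if not perf_stdout:
        return None
    seq_total = conc_total = factor = count = 0
    for line in perf_stdout.splitlines():
        if not line.startswith("!@######CONC:"):
            continue
        fields = line.replace("!@######CONC:", "").replace("######@!", "").split(":")
        if fields[3] != function_name:  # field 3 names the measured function
            continue
        seq_total += int(fields[5])
        conc_total += int(fields[6])
        factor = int(fields[7])
        count += 1
    if count == 0:
        return None
    seq_avg, conc_avg = seq_total // count, conc_total // count
    return seq_avg, conc_avg, factor, (seq_avg / conc_avg if conc_avg else 0.0)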
@@ -218,8 +218,28 @@ def test_no_targets_found() -> None:
         def target(self):
             pass
     """
+    result = parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
+    expected = dedent("""
+        class MyClass:
+            def method(self):
+                pass
+
+            class Inner:
+                def target(self):
+                    pass
+    """)
+    assert result.strip() == expected.strip()
+
+
+def test_no_targets_found_raises_for_nonexistent() -> None:
+    """Test that ValueError is raised when the target function doesn't exist at all."""
+    code = """
+class MyClass:
+    def method(self):
+        pass
+"""
     with pytest.raises(ValueError, match="No target functions found in the provided code"):
-        parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"MyClass.Inner.target"})
+        parse_code_and_prune_cst(dedent(code), CodeContextType.READ_WRITABLE, {"NonExistent.target"})


 def test_module_var() -> None:
@@ -131,6 +131,46 @@ async def async_function(x: int, y: int) -> int:
    assert modified_code.strip() == expected_decorated_code.strip()


@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_decorator_application_concurrency_mode(temp_dir):
    """Test that CONCURRENCY mode applies the codeflash_concurrency_async decorator."""
    async_function_code = '''
import asyncio

async def async_function(x: int, y: int) -> int:
    """Simple async function for testing."""
    await asyncio.sleep(0.01)
    return x * y
'''

    expected_decorated_code = '''
import asyncio

from codeflash.code_utils.codeflash_wrap_decorator import \\
    codeflash_concurrency_async


@codeflash_concurrency_async
async def async_function(x: int, y: int) -> int:
    """Simple async function for testing."""
    await asyncio.sleep(0.01)
    return x * y
'''

    test_file = temp_dir / "test_async.py"
    test_file.write_text(async_function_code)

    func = FunctionToOptimize(
        function_name="async_function", file_path=test_file, parents=[], is_async=True
    )

    decorator_added = add_async_decorator_to_function(test_file, func, TestingMode.CONCURRENCY)

    assert decorator_added
    modified_code = test_file.read_text()
    assert modified_code.strip() == expected_decorated_code.strip()
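What the decorator being applied here might measure, sketched under assumed names (codeflash's actual codeflash_concurrency_async is not shown in this diff): time N awaits of the wrapped coroutine run back-to-back, time the same N awaits run via asyncio.gather, and emit a CONC marker carrying both figures.

import asyncio
import functools
import time

def concurrency_probe(factor: int = 10):  # hypothetical stand-in decorator
    def decorator(fn):
        @functools.wraps(fn)
        async def wrapper(*args, **kwargs):
            start = time.perf_counter_ns()
            for _ in range(factor):  # sequential baseline: one await at a time
                result = await fn(*args, **kwargs)
            sequential_ns = time.perf_counter_ns() - start
            start = time.perf_counter_ns()
            await asyncio.gather(*(fn(*args, **kwargs) for _ in range(factor)))
            concurrent_ns = time.perf_counter_ns() - start
            print(f"!@######CONC:mod::test:{fn.__name__}:0:{sequential_ns}:{concurrent_ns}:{factor}######@!")
            return result
        return wrapper
    return decorator

A fully blocking body keeps the two timings roughly equal (ratio near 1), while a body that genuinely yields lets the concurrent batch finish up to `factor` times faster, which is exactly the ratio the tests above assert on.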


@pytest.mark.skipif(sys.platform == "win32", reason="pending support for asyncio on windows")
def test_async_class_method_decorator_application(temp_dir):
    async_class_code = '''
@@ -124,7 +124,7 @@ def _get_string_usage(text: str) -> Usage:
    helper_file.unlink(missing_ok=True)
    main_file.unlink(missing_ok=True)


    expected_helper = """import re
from collections.abc import Sequence
@@ -481,3 +481,86 @@ def unused_function():
    qualified_functions = {"get_platform_info", "get_loop_result"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    assert result.strip() == expected.strip()


def test_enum_attribute_access_dependency() -> None:
    """Test that enum/class attribute access like MessageKind.VALUE is tracked as a dependency."""
    code = """
from enum import Enum

class MessageKind(Enum):
    VALUE = "value"
    OTHER = "other"

class UnusedEnum(Enum):
    UNUSED = "unused"

UNUSED_VAR = 123

def process_message(kind):
    match kind:
        case MessageKind.VALUE:
            return "got value"
        case MessageKind.OTHER:
            return "got other"
    return "unknown"
"""

    expected = """
from enum import Enum

class MessageKind(Enum):
    VALUE = "value"
    OTHER = "other"

class UnusedEnum(Enum):
    UNUSED = "unused"

def process_message(kind):
    match kind:
        case MessageKind.VALUE:
            return "got value"
        case MessageKind.OTHER:
            return "got other"
    return "unknown"
"""

    qualified_functions = {"process_message"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    # MessageKind should be preserved because process_message uses MessageKind.VALUE
    assert "class MessageKind" in result
    # UNUSED_VAR should be removed
    assert "UNUSED_VAR" not in result
    assert result.strip() == expected.strip()


def test_attribute_access_does_not_track_attr_name() -> None:
    """Test that self.x attribute access doesn't track 'x' as a dependency on module-level x."""
    code = """
x = "module_level_x"
UNUSED_VAR = "unused"

class MyClass:
    def __init__(self):
        self.x = 1  # This 'x' is an attribute, not a reference to module-level 'x'

    def get_x(self):
        return self.x  # This 'x' is also an attribute access
"""

    expected = """
class MyClass:
    def __init__(self):
        self.x = 1  # This 'x' is an attribute, not a reference to module-level 'x'

    def get_x(self):
        return self.x  # This 'x' is also an attribute access
"""

    qualified_functions = {"MyClass.get_x", "MyClass.__init__"}
    result = remove_unused_definitions_by_function_names(code, qualified_functions)
    # Module-level x should NOT be kept (self.x doesn't reference it)
    assert 'x = "module_level_x"' not in result
    # UNUSED_VAR should also be removed
    assert "UNUSED_VAR" not in result
    assert result.strip() == expected.strip()
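The rule this test pins down is easy to state with a small visitor: record identifiers read as plain names, but never the attribute part of value.attr. Illustrated with the stdlib ast module (the project's own pruning code may use a different CST library):

import ast

class LoadedNames(ast.NodeVisitor):
    """Collects identifiers read as plain names; `self.x` contributes only `self`."""
    def __init__(self) -> None:
        self.names: set[str] = set()

    def visit_Name(self, node: ast.Name) -> None:
        if isinstance(node.ctx, ast.Load):
            self.names.add(node.id)

collector = LoadedNames()
collector.visit(ast.parse("def get_x(self):\n    return self.x"))
# The attribute "x" is stored as a plain string on the Attribute node, so it
# is never collected as a Name - only "self" is recorded.
assert collector.names == {"self"}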
11
uv.lock
@@ -925,7 +925,7 @@ name = "exceptiongroup"
 version = "1.3.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "typing-extensions", marker = "python_full_version < '3.13'" },
+    { name = "typing-extensions", marker = "python_full_version < '3.11'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
 wheels = [
@@ -5473,11 +5473,14 @@ wheels = [

 [[package]]
 name = "wheel"
-version = "0.45.1"
+version = "0.46.2"
 source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/8a/98/2d9906746cdc6a6ef809ae6338005b3f21bb568bea3165cfc6a243fdc25c/wheel-0.45.1.tar.gz", hash = "sha256:661e1abd9198507b1409a20c02106d9670b2576e916d58f520316666abca6729", size = 107545, upload-time = "2024-11-23T00:18:23.513Z" }
+dependencies = [
+    { name = "packaging" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/9b/92/3a64fa9639b8e290fe8630d8067a66f7c5510845c6d73686ad880c9b04d9/wheel-0.46.2.tar.gz", hash = "sha256:3d79e48fde9847618a5a181f3cc35764c349c752e2fe911e65fa17faab9809b0", size = 60274, upload-time = "2026-01-21T23:55:25.838Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/0b/2c/87f3254fd8ffd29e4c02732eee68a83a1d3c346ae39bc6822dcbcb697f2b/wheel-0.45.1-py3-none-any.whl", hash = "sha256:708e7481cc80179af0e556bbf0cc00b8444c7321e2700b8d8580231d13017248", size = 72494, upload-time = "2024-11-23T00:18:21.207Z" },
+    { url = "https://files.pythonhosted.org/packages/13/2c/5e079cefe955ae58e5a052fe037c850ce493eb7269dedeb960237e78fb0f/wheel-0.46.2-py3-none-any.whl", hash = "sha256:33ae60725d69eaa249bc1982e739943c23b34b58d51f1cb6253453773aca6e65", size = 29971, upload-time = "2026-01-21T23:55:24.447Z" },
 ]

 [[package]]