Ready to review

aseembits93 2025-06-06 13:18:58 -07:00
commit 0be74c4d52
26 changed files with 2696 additions and 77 deletions

View file

@ -25,6 +25,7 @@ class CodeflashTrace:
"""Set up the database connection for direct writing.
Args:
----
trace_path: Path to the trace database file
"""
@ -52,6 +53,7 @@ class CodeflashTrace:
"""Write function call data directly to the database.
Args:
----
data: List of function call data tuples to write
"""
@ -94,9 +96,11 @@ class CodeflashTrace:
"""Use as a decorator to trace function execution.
Args:
----
func: The function to be decorated
Returns:
-------
The wrapped function
"""

View file

@ -76,10 +76,12 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct
"""Add codeflash_trace to a function.
Args:
----
code: The source code as a string
functions_to_optimize: List of FunctionToOptimize instances containing function details
Returns:
-------
The modified source code as a string
"""

View file

@ -74,9 +74,11 @@ class CodeFlashBenchmarkPlugin:
"""Process the trace file and extract timing data for all functions.
Args:
----
trace_path: Path to the trace file
Returns:
-------
A nested dictionary where:
- Outer keys are module_name.qualified_name (module.class.function)
- Inner keys are of type BenchmarkKey
@ -132,9 +134,11 @@ class CodeFlashBenchmarkPlugin:
"""Extract total benchmark timings from trace files.
Args:
----
trace_path: Path to the trace file
Returns:
-------
A dictionary mapping where:
- Keys are of type BenchmarkKey
- Values are total benchmark timing in milliseconds (with overhead subtracted)

View file

@ -55,12 +55,14 @@ def create_trace_replay_test_code(
"""Create a replay test for functions based on trace data.
Args:
----
trace_file: Path to the SQLite database file
functions_data: List of dictionaries with function info extracted from DB
test_framework: 'pytest' or 'unittest'
max_run_count: Maximum number of runs to include in the test
Returns:
-------
A string containing the test code
"""
@ -218,12 +220,14 @@ def generate_replay_test(
"""Generate multiple replay tests from the traced function calls, grouped by benchmark.
Args:
----
trace_file_path: Path to the SQLite database file
output_dir: Directory to write the generated tests (if None, only returns the code)
test_framework: 'pytest' or 'unittest'
max_run_count: Maximum number of runs to include per function
Returns:
-------
Dictionary mapping benchmark names to generated test code
"""

View file

@ -83,11 +83,13 @@ def process_benchmark_data(
"""Process benchmark data and generate detailed benchmark information.
Args:
----
replay_performance_gain: The performance gain from replay
fto_benchmark_timings: Function to optimize benchmark timings
total_benchmark_timings: Total benchmark timings
Returns:
-------
ProcessedBenchmarkInfo containing processed benchmark details
"""

View file

@ -211,7 +211,7 @@ def collect_setup_info() -> SetupInfo:
# Discover test directory
default_tests_subdir = "tests"
create_for_me_option = f"okay, create a tests{os.pathsep} directory for me!"
test_subdir_options = valid_subdirs
test_subdir_options = [sub_dir for sub_dir in valid_subdirs if sub_dir != module_root]
if "tests" not in valid_subdirs:
test_subdir_options.append(create_for_me_option)
custom_dir_option = "enter a custom directory…"
@ -240,7 +240,16 @@ def collect_setup_info() -> SetupInfo:
apologize_and_exit()
else:
tests_root = Path(curdir) / Path(cast("str", tests_root_answer))
tests_root = tests_root.relative_to(curdir)
resolved_module_root = (Path(curdir) / Path(module_root)).resolve()
resolved_tests_root = (Path(curdir) / Path(tests_root)).resolve()
if resolved_module_root == resolved_tests_root:
logger.warning(
"It looks like your tests root is the same as your module root. This is not recommended and can lead to unexpected behavior."
)
ph("cli-tests-root-provided")
# Autodiscover test framework
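
The new guard above compares fully resolved paths, so answers like "." or the module directory itself are caught regardless of how the user typed them. A standalone sketch of the same check (names mirror the diff; values are illustrative):

from pathlib import Path

curdir = "."
module_root = "myproject"
tests_root_answer = "myproject"  # user accidentally picked the module root

resolved_module_root = (Path(curdir) / Path(module_root)).resolve()
resolved_tests_root = (Path(curdir) / Path(tests_root_answer)).resolve()
if resolved_module_root == resolved_tests_root:
    # the diff emits this via logger.warning
    print("tests root is the same as module root; not recommended")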

View file

@ -27,7 +27,7 @@ def set_level(level: int, *, echo_setting: bool = True) -> None:
],
force=True,
)
logging.info("Verbose DEBUG logging enabled") # noqa: LOG015
logging.info("Verbose DEBUG logging enabled")
else:
logging.info("Logging level set to INFO") # noqa: LOG015
logging.info("Logging level set to INFO")
console.rule()

View file

@ -47,6 +47,7 @@ class CodeflashRunCheckpoint:
"""Add a function to the checkpoint after it has been processed.
Args:
----
function_fully_qualified_name: The fully qualified name of the function
status: Status of optimization (e.g., "optimized", "failed", "skipped")
additional_info: Any additional information to store about the function
@ -104,7 +105,8 @@ class CodeflashRunCheckpoint:
def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dict[str, dict[str, str]]:
"""Get information about all processed functions, regardless of status.
Returns:
Returns
-------
Dictionary mapping function names to their processing information
"""

View file

@ -1,21 +1,47 @@
import os
import sys
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING
from platformdirs import user_config_dir
# os-independent newline
# important for any user-facing output or files we write
# make sure to use this in f-strings e.g. f"some string{LF}"
# you can use "[^f]\".*\{LF\}\" to find any lines in your code that use this without the f-string
LF: str = os.linesep
if TYPE_CHECKING:
codeflash_temp_dir: Path
codeflash_cache_dir: Path
codeflash_cache_db: Path
SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix()
class Compat:
# os-independent newline
LF: str = os.linesep
IS_POSIX = os.name != "nt"
SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix()
IS_POSIX: bool = os.name != "nt"
@property
def codeflash_cache_dir(self) -> Path:
return Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True))
@property
def codeflash_temp_dir(self) -> Path:
temp_dir = Path(tempfile.gettempdir()) / "codeflash"
if not temp_dir.exists():
temp_dir.mkdir(parents=True, exist_ok=True)
return temp_dir
@property
def codeflash_cache_db(self) -> Path:
return self.codeflash_cache_dir / "codeflash_cache.db"
codeflash_cache_dir = Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True))
_compat = Compat()
codeflash_cache_db = codeflash_cache_dir / "codeflash_cache.db"
codeflash_temp_dir = _compat.codeflash_temp_dir
codeflash_cache_dir = _compat.codeflash_cache_dir
codeflash_cache_db = _compat.codeflash_cache_db
LF = _compat.LF
SAFE_SYS_EXECUTABLE = _compat.SAFE_SYS_EXECUTABLE
IS_POSIX = _compat.IS_POSIX
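
The refactor preserves the module-level names, so existing imports keep working; only the internals moved into Compat properties. A consumer-side sketch (module path assumed from the repo layout):

from codeflash.code_utils.compat import IS_POSIX, LF, codeflash_temp_dir

print(f"posix={IS_POSIX}{LF}temp dir={codeflash_temp_dir}")
# codeflash_temp_dir is created on first access if it does not exist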

View file

@ -0,0 +1,141 @@
import re
import libcst as cst
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.time_utils import format_time
from codeflash.models.models import GeneratedTests, GeneratedTestsList, TestResults
def remove_functions_from_generated_tests(
generated_tests: GeneratedTestsList, test_functions_to_remove: list[str]
) -> GeneratedTestsList:
new_generated_tests = []
for generated_test in generated_tests.generated_tests:
for test_function in test_functions_to_remove:
function_pattern = re.compile(
rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\ndef\s|$)",
re.DOTALL,
)
match = function_pattern.search(generated_test.generated_original_test_source)
if match is None or "@pytest.mark.parametrize" in match.group(0):
continue
generated_test.generated_original_test_source = function_pattern.sub(
"", generated_test.generated_original_test_source
)
new_generated_tests.append(generated_test)
return GeneratedTestsList(generated_tests=new_generated_tests)
def add_runtime_comments_to_generated_tests(
generated_tests: GeneratedTestsList, original_test_results: TestResults, optimized_test_results: TestResults
) -> GeneratedTestsList:
"""Add runtime performance comments to function calls in generated tests."""
# Create dictionaries for fast lookup of runtime data
original_runtime_by_test = original_test_results.usable_runtime_data_by_test_case()
optimized_runtime_by_test = optimized_test_results.usable_runtime_data_by_test_case()
class RuntimeCommentTransformer(cst.CSTTransformer):
def __init__(self) -> None:
self.in_test_function = False
self.current_test_name: str | None = None
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
if node.name.value.startswith("test_"):
self.in_test_function = True
self.current_test_name = node.name.value
else:
self.in_test_function = False
self.current_test_name = None
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
if original_node.name.value.startswith("test_"):
self.in_test_function = False
self.current_test_name = None
return updated_node
def leave_SimpleStatementLine(
self,
original_node: cst.SimpleStatementLine, # noqa: ARG002
updated_node: cst.SimpleStatementLine,
) -> cst.SimpleStatementLine:
if not self.in_test_function or not self.current_test_name:
return updated_node
# Look for assignment statements that assign to codeflash_output
# Handle both single statements and multiple statements on one line
codeflash_assignment_found = False
for stmt in updated_node.body:
if isinstance(stmt, cst.Assign) and (
len(stmt.targets) == 1
and isinstance(stmt.targets[0].target, cst.Name)
and stmt.targets[0].target.value == "codeflash_output"
):
codeflash_assignment_found = True
break
if codeflash_assignment_found:
# Find matching test cases by looking for this test function name in the test results
matching_original_times = []
matching_optimized_times = []
for invocation_id, runtimes in original_runtime_by_test.items():
if invocation_id.test_function_name == self.current_test_name:
matching_original_times.extend(runtimes)
for invocation_id, runtimes in optimized_runtime_by_test.items():
if invocation_id.test_function_name == self.current_test_name:
matching_optimized_times.extend(runtimes)
if matching_original_times and matching_optimized_times:
original_time = min(matching_original_times)
optimized_time = min(matching_optimized_times)
# Create the runtime comment
comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)}"
# Add comment to the trailing whitespace
new_trailing_whitespace = cst.TrailingWhitespace(
whitespace=cst.SimpleWhitespace(" "),
comment=cst.Comment(comment_text),
newline=updated_node.trailing_whitespace.newline,
)
return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace)
return updated_node
# Process each generated test
modified_tests = []
for test in generated_tests.generated_tests:
try:
# Parse the test source code
tree = cst.parse_module(test.generated_original_test_source)
# Transform the tree to add runtime comments
transformer = RuntimeCommentTransformer()
modified_tree = tree.visit(transformer)
# Convert back to source code
modified_source = modified_tree.code
# Create a new GeneratedTests object with the modified source
modified_test = GeneratedTests(
generated_original_test_source=modified_source,
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
instrumented_perf_test_source=test.instrumented_perf_test_source,
behavior_file_path=test.behavior_file_path,
perf_file_path=test.perf_file_path,
)
modified_tests.append(modified_test)
except Exception as e:
# If parsing fails, keep the original test
logger.debug(f"Failed to add runtime comments to test: {e}")
modified_tests.append(test)
return GeneratedTestsList(generated_tests=modified_tests)
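
The transformer attaches the timing comment by swapping in a new TrailingWhitespace node on the matched statement line. A self-contained LibCST sketch of that technique, with made-up timings:

import libcst as cst

class CommentAdder(cst.CSTTransformer):
    def leave_SimpleStatementLine(
        self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
    ) -> cst.SimpleStatementLine:
        # append the comment while keeping the line's original newline token
        return updated_node.with_changes(
            trailing_whitespace=cst.TrailingWhitespace(
                whitespace=cst.SimpleWhitespace(" "),
                comment=cst.Comment("# 500μs -> 300μs"),
                newline=updated_node.trailing_whitespace.newline,
            )
        )

module = cst.parse_module("codeflash_output = bubble_sort([3, 1, 2])\n")
print(module.visit(CommentAdder()).code)
# codeflash_output = bubble_sort([3, 1, 2]) # 500μs -> 300μs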

View file

@ -24,6 +24,7 @@ class LineProfilerDecoratorAdder(cst.CSTTransformer):
"""Initialize the transformer.
Args:
----
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func").
decorator_name: The name of the decorator to add.
@ -144,11 +145,13 @@ def add_decorator_to_qualified_function(module: cst.Module, qualified_name: str,
"""Add a decorator to a function with the exact qualified name in the source code.
Args:
----
module: The Python source code as a CST module.
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func").
decorator_name: The name of the decorator to add.
Returns:
-------
The modified CST module.
"""

View file

@ -1,28 +0,0 @@
import re
from codeflash.models.models import GeneratedTestsList
def remove_functions_from_generated_tests(
generated_tests: GeneratedTestsList, test_functions_to_remove: list[str]
) -> GeneratedTestsList:
new_generated_tests = []
for generated_test in generated_tests.generated_tests:
for test_function in test_functions_to_remove:
function_pattern = re.compile(
rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\ndef\s|$)",
re.DOTALL,
)
match = function_pattern.search(generated_test.generated_original_test_source)
if match is None or "@pytest.mark.parametrize" in match.group(0):
continue
generated_test.generated_original_test_source = function_pattern.sub(
"", generated_test.generated_original_test_source
)
new_generated_tests.append(generated_test)
return GeneratedTestsList(generated_tests=new_generated_tests)
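
This deleted module's removal logic (now living in edit_generated_tests.py above) drops an entire test function with a lazy DOTALL match that stops at the next top-level def. A quick demonstration:

import re

src = "def test_a():\n    assert slow() == 1\n\ndef test_b():\n    assert fast() == 2\n"
pattern = re.compile(
    r"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+test_a\(.*?\):.*?(?=\ndef\s|$)",
    re.DOTALL,
)
print(pattern.sub("", src))  # only test_b survives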

View file

@ -49,3 +49,40 @@ def humanize_runtime(time_in_ns: int) -> str:
runtime_human = runtime_human_parts[0]
return f"{runtime_human} {units}"
def format_time(nanoseconds: int) -> str:
"""Format nanoseconds into a human-readable string with 3 significant digits when needed."""
# Inlined significant digit check: >= 3 digits if value >= 100
if nanoseconds < 1_000:
return f"{nanoseconds}ns"
if nanoseconds < 1_000_000:
microseconds_int = nanoseconds // 1_000
if microseconds_int >= 100:
return f"{microseconds_int}μs"
microseconds = nanoseconds / 1_000
# Format with precision: 3 significant digits
if microseconds >= 100:
return f"{microseconds:.0f}μs"
if microseconds >= 10:
return f"{microseconds:.1f}μs"
return f"{microseconds:.2f}μs"
if nanoseconds < 1_000_000_000:
milliseconds_int = nanoseconds // 1_000_000
if milliseconds_int >= 100:
return f"{milliseconds_int}ms"
milliseconds = nanoseconds / 1_000_000
if milliseconds >= 100:
return f"{milliseconds:.0f}ms"
if milliseconds >= 10:
return f"{milliseconds:.1f}ms"
return f"{milliseconds:.2f}ms"
seconds_int = nanoseconds // 1_000_000_000
if seconds_int >= 100:
return f"{seconds_int}s"
seconds = nanoseconds / 1_000_000_000
if seconds >= 100:
return f"{seconds:.0f}s"
if seconds >= 10:
return f"{seconds:.1f}s"
return f"{seconds:.2f}s"

View file

@ -3,11 +3,9 @@ from __future__ import annotations
import os
from collections import defaultdict
from itertools import chain
from pathlib import Path # noqa: TC003
from typing import TYPE_CHECKING
import libcst as cst
from libcst import CSTNode # noqa: TC002
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
@ -24,7 +22,10 @@ from codeflash.models.models import (
from codeflash.optimization.function_context import belongs_to_function_qualified
if TYPE_CHECKING:
from pathlib import Path
from jedi.api.classes import Name
from libcst import CSTNode
def get_code_optimization_context(
@ -150,6 +151,7 @@ def extract_code_string_context_from_files(
imports, and combines them.
Args:
----
helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers
helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions
project_root_path: Root path of the project
@ -157,6 +159,7 @@ def extract_code_string_context_from_files(
code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN)
Returns:
-------
CodeString containing the extracted code context with necessary imports
""" # noqa: D205
@ -257,6 +260,7 @@ def extract_code_markdown_context_from_files(
imports, and combines them into a structured markdown format.
Args:
----
helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers
helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions
project_root_path: Root path of the project
@ -264,6 +268,7 @@ def extract_code_markdown_context_from_files(
code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN)
Returns:
-------
CodeStringsMarkdown containing the extracted code context with necessary imports,
formatted for inclusion in markdown
@ -382,7 +387,7 @@ def get_function_to_optimize_as_function_source(
source_code=name.get_line_code(),
jedi_definition=name,
)
except Exception as e: # noqa: PERF203
except Exception as e:
logger.exception(f"Error while getting function source: {e}")
continue
raise ValueError(
@ -502,7 +507,8 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.
Returns:
Returns
-------
(filtered_node, found_target):
filtered_node: The modified CST node or None if it should be removed.
found_target: True if a target function was found in this node's subtree.
@ -586,7 +592,8 @@ def prune_cst_for_read_only_code( # noqa: PLR0911
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node for read-only context.
Returns:
Returns
-------
(filtered_node, found_target):
filtered_node: The modified CST node or None if it should be removed.
found_target: True if a target function was found in this node's subtree.
@ -690,7 +697,8 @@ def prune_cst_for_testgen_code( # noqa: PLR0911
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node for testgen context.
Returns:
Returns
-------
(filtered_node, found_target):
filtered_node: The modified CST node or None if it should be removed.
found_target: True if a target function was found in this node's subtree.
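
The import hunk at the top of this file moves Path and CSTNode under TYPE_CHECKING; with from __future__ import annotations in effect, signatures may still name them while the runtime never pays for the imports. A minimal sketch of the pattern:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from pathlib import Path  # seen by type checkers only

def rel_to_root(path: Path, root: Path) -> Path:
    # annotations stay unevaluated at runtime, so no pathlib import happens here
    return path.relative_to(root)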

View file

@ -1,10 +1,17 @@
from __future__ import annotations
import ast
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import libcst as cst
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.models.models import CodeOptimizationContext, FunctionSource
@dataclass
class UsageInfo:
@ -311,10 +318,12 @@ def remove_unused_definitions_recursively( # noqa: PLR0911
"""Recursively filter the node to remove unused definitions.
Args:
----
node: The CST node to process
definitions: Dictionary of definition info
Returns:
-------
(filtered_node, used_by_function):
filtered_node: The modified CST node or None if it should be removed
used_by_function: True if this node or any child is used by qualified functions
@ -450,6 +459,7 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
If a class is referenced by a qualified function, we keep the entire class.
Args:
----
code: The code to process
qualified_function_names: Set of function names to keep. For methods, use format 'classname.methodname'
@ -480,3 +490,220 @@ def print_definitions(definitions: dict[str, UsageInfo]) -> None:
print(f" Used by qualified function: {info.used_by_qualified_function}")
print(f" Dependencies: {', '.join(sorted(info.dependencies)) if info.dependencies else 'None'}")
print()
def revert_unused_helper_functions(
project_root: Path, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str]
) -> None:
"""Revert unused helper functions back to their original definitions.
Args:
project_root: Root path of the project
unused_helpers: List of unused helper functions to revert
original_helper_code: Dictionary mapping file paths to their original code
"""
if not unused_helpers:
return
logger.info(f"Reverting {len(unused_helpers)} unused helper function(s) to original definitions")
# Group unused helpers by file path
unused_helpers_by_file = defaultdict(list)
for helper in unused_helpers:
unused_helpers_by_file[helper.file_path].append(helper)
# For each file, revert the unused helper functions to their original definitions
for file_path, helpers_in_file in unused_helpers_by_file.items():
if file_path in original_helper_code:
try:
# Read current file content
current_code = file_path.read_text(encoding="utf8")
# Get original code for this file
original_code = original_helper_code[file_path]
# Use the code replacer to selectively revert only the unused helper functions
helper_names = [helper.qualified_name for helper in helpers_in_file]
reverted_code = replace_function_definitions_in_module(
function_names=helper_names,
optimized_code=original_code, # Use original code as the "optimized" code to revert
module_abspath=file_path,
preexisting_objects=set(), # Empty set since we're reverting
project_root_path=project_root,
)
if reverted_code:
logger.debug(f"Reverted unused helpers in {file_path}: {', '.join(helper_names)}")
except Exception as e:
logger.error(f"Error reverting unused helpers in {file_path}: {e}")
def _analyze_imports_in_optimized_code(
optimized_ast: ast.AST, code_context: CodeOptimizationContext
) -> dict[str, set[str]]:
"""Analyze import statements in optimized code to map imported names to qualified helper names.
Args:
optimized_ast: The AST of the optimized code
code_context: The code optimization context containing helper functions
Returns:
Dictionary mapping imported names to sets of possible qualified helper names
"""
imported_names_map = defaultdict(set)
# Precompute a two-level dict: module_name -> func_name -> [helpers]
helpers_by_file_and_func = defaultdict(dict)
helpers_by_file = defaultdict(list) # preserved for "import module"
helpers_append = helpers_by_file_and_func.setdefault
for helper in code_context.helper_functions:
jedi_type = helper.jedi_definition.type
if jedi_type != "class":
func_name = helper.only_function_name
module_name = helper.file_path.stem
# Cache function lookup for this (module, func)
file_entry = helpers_by_file_and_func[module_name]
if func_name in file_entry:
file_entry[func_name].append(helper)
else:
file_entry[func_name] = [helper]
helpers_by_file[module_name].append(helper)
# Optimize attribute lookups and method binding outside the loop
helpers_by_file_and_func_get = helpers_by_file_and_func.get
helpers_by_file_get = helpers_by_file.get
for node in ast.walk(optimized_ast):
if isinstance(node, ast.ImportFrom):
# Handle "from module import function" statements
module_name = node.module
if module_name:
file_entry = helpers_by_file_and_func_get(module_name, None)
if file_entry:
for alias in node.names:
imported_name = alias.asname if alias.asname else alias.name
original_name = alias.name
helpers = file_entry.get(original_name, None)
if helpers:
for helper in helpers:
imported_names_map[imported_name].add(helper.qualified_name)
imported_names_map[imported_name].add(helper.fully_qualified_name)
elif isinstance(node, ast.Import):
# Handle "import module" statements
for alias in node.names:
imported_name = alias.asname if alias.asname else alias.name
module_name = alias.name
for helper in helpers_by_file_get(module_name, []):
# For "import module" statements, functions would be called as module.function
full_call = f"{imported_name}.{helper.only_function_name}"
imported_names_map[full_call].add(helper.qualified_name)
imported_names_map[full_call].add(helper.fully_qualified_name)
return dict(imported_names_map)
def detect_unused_helper_functions(
function_to_optimize, code_context: CodeOptimizationContext, optimized_code: str
) -> list[FunctionSource]:
"""Detect helper functions that are no longer called by the optimized entrypoint function.
Args:
code_context: The code optimization context containing helper functions
optimized_code: The optimized code to analyze
Returns:
List of FunctionSource objects representing unused helper functions
"""
try:
# Parse the optimized code to analyze function calls and imports
optimized_ast = ast.parse(optimized_code)
# Find the optimized entrypoint function
entrypoint_function_ast = None
for node in ast.walk(optimized_ast):
if isinstance(node, ast.FunctionDef) and node.name == function_to_optimize.function_name:
entrypoint_function_ast = node
break
if not entrypoint_function_ast:
logger.debug(f"Could not find entrypoint function {function_to_optimize.function_name} in optimized code")
return []
# First, analyze imports to build a mapping of imported names to their original qualified names
imported_names_map = _analyze_imports_in_optimized_code(optimized_ast, code_context)
# Extract all function calls in the entrypoint function
called_function_names = set()
for node in ast.walk(entrypoint_function_ast):
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
# Regular function call: function_name()
called_name = node.func.id
called_function_names.add(called_name)
# Also add the qualified name if this is an imported function
if called_name in imported_names_map:
called_function_names.update(imported_names_map[called_name])
elif isinstance(node.func, ast.Attribute):
# Method call: obj.method() or self.method() or module.function()
if isinstance(node.func.value, ast.Name):
if node.func.value.id == "self":
# self.method_name() -> add both method_name and ClassName.method_name
called_function_names.add(node.func.attr)
# For class methods, also add the qualified name
if hasattr(function_to_optimize, "parents") and function_to_optimize.parents:
class_name = function_to_optimize.parents[0].name
called_function_names.add(f"{class_name}.{node.func.attr}")
else:
# obj.method() or module.function()
attr_name = node.func.attr
called_function_names.add(attr_name)
called_function_names.add(f"{node.func.value.id}.{attr_name}")
# Check if this is a module.function call that maps to a helper
full_call = f"{node.func.value.id}.{attr_name}"
if full_call in imported_names_map:
called_function_names.update(imported_names_map[full_call])
# Handle nested attribute access like obj.attr.method()
else:
called_function_names.add(node.func.attr)
logger.debug(f"Functions called in optimized entrypoint: {called_function_names}")
logger.debug(f"Imported names mapping: {imported_names_map}")
# Find helper functions that are no longer called
unused_helpers = []
for helper_function in code_context.helper_functions:
if helper_function.jedi_definition.type != "class":
# Check if the helper function is called using multiple name variants
helper_qualified_name = helper_function.qualified_name
helper_simple_name = helper_function.only_function_name
helper_fully_qualified_name = helper_function.fully_qualified_name
# Create a set of all possible names this helper might be called by
possible_call_names = {helper_qualified_name, helper_simple_name, helper_fully_qualified_name}
# For cross-file helpers, also consider module-based calls
if helper_function.file_path != function_to_optimize.file_path:
# Add potential module.function combinations
module_name = helper_function.file_path.stem
possible_call_names.add(f"{module_name}.{helper_simple_name}")
# Check if any of the possible names are in the called functions
is_called = bool(possible_call_names.intersection(called_function_names))
if not is_called:
unused_helpers.append(helper_function)
logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code")
logger.debug(f" Checked names: {possible_call_names}")
else:
logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code")
logger.debug(f" Called via: {possible_call_names.intersection(called_function_names)}")
return unused_helpers
except Exception as e:
logger.debug(f"Error detecting unused helper functions: {e}")
return []
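
The detection walks the optimized entrypoint's AST and records every name a call could resolve to; a condensed sketch of that collection step:

import ast

src = """
def entry():
    helper_a()
    utils.helper_b()
"""
called = set()
for node in ast.walk(ast.parse(src)):
    if isinstance(node, ast.Call):
        if isinstance(node.func, ast.Name):        # helper_a()
            called.add(node.func.id)
        elif isinstance(node.func, ast.Attribute): # utils.helper_b()
            called.add(node.func.attr)
            if isinstance(node.func.value, ast.Name):
                called.add(f"{node.func.value.id}.{node.func.attr}")
print(called)  # e.g. {'helper_a', 'helper_b', 'utils.helper_b'} (set order varies)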

View file

@ -41,13 +41,17 @@ from codeflash.code_utils.config_consts import (
N_TESTS_TO_GENERATE,
TOTAL_LOOPING_TIME,
)
from codeflash.code_utils.edit_generated_tests import (
add_runtime_comments_to_generated_tests,
remove_functions_from_generated_tests,
)
from codeflash.code_utils.formatter import format_code, sort_imports
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
from codeflash.code_utils.line_profile_utils import add_decorator_imports
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.context import code_context_extractor
from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions
from codeflash.either import Failure, Success, is_successful
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import (
@ -279,10 +283,6 @@ class FunctionOptimizer:
},
)
generated_tests = remove_functions_from_generated_tests(
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
)
if best_optimization:
logger.info("Best candidate:")
code_print(best_optimization.candidate.source_code)
@ -309,10 +309,10 @@ class FunctionOptimizer:
benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None,
)
self.log_successful_optimization(explanation, generated_tests, exp_type)
self.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=best_optimization.candidate.source_code
code_context=code_context,
optimized_code=best_optimization.candidate.source_code,
original_helper_code=original_helper_code,
)
new_code, new_helper_code = self.reformat_code_and_helpers(
@ -335,6 +335,15 @@ class FunctionOptimizer:
if original_code_baseline.coverage_results
else "Coverage data not available"
)
generated_tests = remove_functions_from_generated_tests(
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
)
# Add runtime comments to generated tests before creating the PR
generated_tests = add_runtime_comments_to_generated_tests(
generated_tests,
original_code_baseline.benchmarking_test_results,
best_optimization.winning_benchmarking_test_results,
)
generated_tests_str = "\n\n".join(
[test.generated_original_test_source for test in generated_tests.generated_tests]
)
@ -359,6 +368,8 @@ class FunctionOptimizer:
original_helper_code,
self.function_to_optimize.file_path,
)
self.log_successful_optimization(explanation, generated_tests, exp_type)
if self.args.override_fixtures:
restore_conftest(original_conftest_content)
if not best_optimization:
@ -426,7 +437,9 @@ class FunctionOptimizer:
code_print(candidate.source_code)
try:
did_update = self.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=candidate.source_code
code_context=code_context,
optimized_code=candidate.source_code,
original_helper_code=original_helper_code,
)
if not did_update:
logger.warning(
@ -627,7 +640,7 @@ class FunctionOptimizer:
return new_code, new_helper_code
def replace_function_and_helpers_with_optimized_code(
self, code_context: CodeOptimizationContext, optimized_code: str
self, code_context: CodeOptimizationContext, optimized_code: str, original_helper_code: dict[Path, str]
) -> bool:
did_update = False
read_writable_functions_by_file_path = defaultdict(set)
@ -645,6 +658,12 @@ class FunctionOptimizer:
preexisting_objects=code_context.preexisting_objects,
project_root_path=self.project_root,
)
unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code)
# Revert unused helper functions to their original definitions
if unused_helpers:
revert_unused_helper_functions(self.project_root, unused_helpers, original_helper_code)
return did_update
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:

View file

@ -30,12 +30,14 @@ class PicklePatcher:
"""Safely pickle an object, replacing unpicklable parts with placeholders.
Args:
----
obj: The object to pickle
protocol: The pickle protocol version to use
max_depth: Maximum recursion depth
**kwargs: Additional arguments for pickle/dill.dumps
Returns:
-------
bytes: Pickled data with placeholders for unpicklable objects
"""
@ -46,9 +48,11 @@ class PicklePatcher:
"""Unpickle data that may contain placeholders.
Args:
----
pickled_data: Pickled data with possible placeholders
Returns:
-------
The unpickled object with placeholders for unpicklable parts
"""
@ -59,11 +63,13 @@ class PicklePatcher:
"""Create a placeholder for an unpicklable object.
Args:
----
obj: The original unpicklable object
error_msg: Error message explaining why it couldn't be pickled
path: Path to this object in the object graph
Returns:
-------
PicklePlaceholder: A placeholder object
"""
@ -91,12 +97,14 @@ class PicklePatcher:
"""Try to pickle an object using pickle first, then dill. If both fail, create a placeholder.
Args:
----
obj: The object to pickle
path: Path to this object in the object graph
protocol: The pickle protocol version to use
**kwargs: Additional arguments for pickle/dill.dumps
Returns:
-------
tuple: (success, result) where success is a boolean and result is either:
- Pickled bytes if successful
- Error message if not successful
@ -123,6 +131,7 @@ class PicklePatcher:
"""Recursively try to pickle an object, replacing unpicklable parts with placeholders.
Args:
----
obj: The object to pickle
max_depth: Maximum recursion depth
path: Current path in the object graph
@ -130,6 +139,7 @@ class PicklePatcher:
**kwargs: Additional arguments for pickle/dill.dumps
Returns:
-------
bytes: Pickled data with placeholders for unpicklable objects
"""
@ -185,6 +195,7 @@ class PicklePatcher:
"""Handle pickling for dictionary objects.
Args:
----
obj_dict: The dictionary to pickle
max_depth: Maximum recursion depth
error_msg: Error message from the original pickling attempt
@ -193,6 +204,7 @@ class PicklePatcher:
**kwargs: Additional arguments for pickle/dill.dumps
Returns:
-------
bytes: Pickled data with placeholders for unpicklable objects
"""
@ -249,6 +261,7 @@ class PicklePatcher:
"""Handle pickling for sequence types (list, tuple, set).
Args:
----
obj_seq: The sequence to pickle
max_depth: Maximum recursion depth
error_msg: Error message from the original pickling attempt
@ -257,6 +270,7 @@ class PicklePatcher:
**kwargs: Additional arguments for pickle/dill.dumps
Returns:
-------
bytes: Pickled data with placeholders for unpicklable objects
"""
@ -305,6 +319,7 @@ class PicklePatcher:
"""Handle pickling for custom objects with __dict__.
Args:
----
obj: The object to pickle
max_depth: Maximum recursion depth
error_msg: Error message from the original pickling attempt
@ -313,6 +328,7 @@ class PicklePatcher:
**kwargs: Additional arguments for pickle/dill.dumps
Returns:
-------
bytes: Pickled data with placeholders for unpicklable objects
"""

View file

@ -18,6 +18,7 @@ class PicklePlaceholder:
"""Initialize a placeholder for an unpicklable object.
Args:
----
obj_type (str): The type name of the original object
obj_str (str): String representation of the original object
error_msg (str): The error message that occurred during pickling

View file

@ -55,7 +55,7 @@ class ProfileStats(pstats.Stats):
print(indent, self.total_calls, "function calls", end=" ", file=self.stream)
if self.total_calls != self.prim_calls:
print("(%d primitive calls)" % self.prim_calls, end=" ", file=self.stream) # noqa: UP031
print(f"({self.prim_calls:d} primitive calls)", end=" ", file=self.stream)
time_unit = {"ns": "nanoseconds", "us": "microseconds", "ms": "milliseconds", "s": "seconds"}[self.time_unit]
print(f"in {self.total_tt:.3f} {time_unit}", file=self.stream)
print(file=self.stream)
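
The swap from percent formatting to an f-string is behavior-preserving for integer counts:

prim_calls = 42
assert "(%d primitive calls)" % prim_calls == f"({prim_calls:d} primitive calls)"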

View file

@ -29,7 +29,7 @@ codeflash/code_utils/time_utils.py
codeflash/code_utils/env_utils.py
codeflash/code_utils/config_consts.py
codeflash/code_utils/static_analysis.py
codeflash/code_utils/remove_generated_tests.py
codeflash/code_utils/edit_generated_tests.py
codeflash/cli_cmds/console_constants.py
codeflash/cli_cmds/logging_config.py
codeflash/cli_cmds/__init__.py

View file

@ -191,7 +191,9 @@ ignore = [
"T201",
"PGH004",
"S301",
"D104"
"D104",
"PERF203",
"LOG015"
]
[tool.ruff.lint.flake8-type-checking]

View file

@ -0,0 +1,464 @@
"""Tests for the add_runtime_comments_to_generated_tests functionality."""
from pathlib import Path
from codeflash.code_utils.edit_generated_tests import add_runtime_comments_to_generated_tests
from codeflash.models.models import (
FunctionTestInvocation,
GeneratedTests,
GeneratedTestsList,
InvocationId,
TestResults,
TestType,
VerificationType,
)
class TestAddRuntimeComments:
"""Test cases for add_runtime_comments_to_generated_tests method."""
def create_test_invocation(
self, test_function_name: str, runtime: int, loop_index: int = 1, iteration_id: str = "1", did_pass: bool = True
) -> FunctionTestInvocation:
"""Helper to create test invocation objects."""
return FunctionTestInvocation(
loop_index=loop_index,
id=InvocationId(
test_module_path="test_module",
test_class_name=None,
test_function_name=test_function_name,
function_getting_tested="test_function",
iteration_id=iteration_id,
),
file_name=Path("test.py"),
did_pass=did_pass,
runtime=runtime,
test_framework="pytest",
test_type=TestType.GENERATED_REGRESSION,
return_value=None,
timed_out=False,
verification_type=VerificationType.FUNCTION_CALL,
)
def test_basic_runtime_comment_addition(self):
"""Test basic functionality of adding runtime comments."""
# Create test source code
test_source = """def test_bubble_sort():
codeflash_output = bubble_sort([3, 1, 2])
assert codeflash_output == [1, 2, 3]
"""
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=Path("test_behavior.py"),
perf_file_path=Path("test_perf.py"),
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
# Create test results
original_test_results = TestResults()
optimized_test_results = TestResults()
# Add test invocations with different runtimes
original_invocation = self.create_test_invocation("test_bubble_sort", 500_000) # 500μs
optimized_invocation = self.create_test_invocation("test_bubble_sort", 300_000) # 300μs
original_test_results.add(original_invocation)
optimized_test_results.add(optimized_invocation)
# Test the functionality
result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results)
# Check that comments were added
modified_source = result.generated_tests[0].generated_original_test_source
assert "# 500μs -> 300μs" in modified_source
assert "codeflash_output = bubble_sort([3, 1, 2]) # 500μs -> 300μs" in modified_source
def test_multiple_test_functions(self):
"""Test handling multiple test functions in the same file."""
test_source = """def test_bubble_sort():
codeflash_output = bubble_sort([3, 1, 2])
assert codeflash_output == [1, 2, 3]
def test_quick_sort():
codeflash_output = quick_sort([5, 2, 8])
assert codeflash_output == [2, 5, 8]
def helper_function():
return "not a test"
"""
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=Path("test_behavior.py"),
perf_file_path=Path("test_perf.py"),
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
# Create test results for both functions
original_test_results = TestResults()
optimized_test_results = TestResults()
# Add test invocations for both test functions
original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000))
original_test_results.add(self.create_test_invocation("test_quick_sort", 800_000))
optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000))
optimized_test_results.add(self.create_test_invocation("test_quick_sort", 600_000))
# Test the functionality
result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results)
modified_source = result.generated_tests[0].generated_original_test_source
# Check that comments were added to both test functions
assert "# 500μs -> 300μs" in modified_source
assert "# 800μs -> 600μs" in modified_source
# Helper function should not have comments
assert (
"helper_function():" in modified_source
and "# " not in modified_source.split("helper_function():")[1].split("\n")[0]
)
def test_different_time_formats(self):
"""Test that different time ranges are formatted correctly with new precision rules."""
test_cases = [
(999, 500, "999ns -> 500ns"), # nanoseconds
(25_000, 18_000, "25.0μs -> 18.0μs"), # microseconds with precision
(500_000, 300_000, "500μs -> 300μs"), # microseconds full integers
(1_500_000, 800_000, "1.50ms -> 800μs"), # milliseconds with precision
(365_000_000, 290_000_000, "365ms -> 290ms"), # milliseconds full integers
(2_000_000_000, 1_500_000_000, "2.00s -> 1.50s"), # seconds with precision
]
for original_time, optimized_time, expected_comment in test_cases:
test_source = """def test_function():
codeflash_output = some_function()
assert codeflash_output is not None
"""
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=Path("test_behavior.py"),
perf_file_path=Path("test_perf.py"),
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
# Create test results
original_test_results = TestResults()
optimized_test_results = TestResults()
original_test_results.add(self.create_test_invocation("test_function", original_time))
optimized_test_results.add(self.create_test_invocation("test_function", optimized_time))
# Test the functionality
result = add_runtime_comments_to_generated_tests(
generated_tests, original_test_results, optimized_test_results
)
modified_source = result.generated_tests[0].generated_original_test_source
assert f"# {expected_comment}" in modified_source
def test_missing_test_results(self):
"""Test behavior when test results are missing for a test function."""
test_source = """def test_bubble_sort():
codeflash_output = bubble_sort([3, 1, 2])
assert codeflash_output == [1, 2, 3]
"""
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=Path("test_behavior.py"),
perf_file_path=Path("test_perf.py"),
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
# Create empty test results
original_test_results = TestResults()
optimized_test_results = TestResults()
# Test the functionality
result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results)
# Check that no comments were added
modified_source = result.generated_tests[0].generated_original_test_source
assert modified_source == test_source # Should be unchanged
def test_partial_test_results(self):
"""Test behavior when only one set of test results is available."""
test_source = """def test_bubble_sort():
codeflash_output = bubble_sort([3, 1, 2])
assert codeflash_output == [1, 2, 3]
"""
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=Path("test_behavior.py"),
perf_file_path=Path("test_perf.py"),
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
# Create test results with only original data
original_test_results = TestResults()
optimized_test_results = TestResults()
original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000))
# No optimized results
# Test the functionality
result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results)
# Check that no comments were added
modified_source = result.generated_tests[0].generated_original_test_source
assert modified_source == test_source # Should be unchanged
def test_multiple_runtimes_uses_minimum(self):
"""Test that when multiple runtimes exist, the minimum is used."""
test_source = """def test_bubble_sort():
codeflash_output = bubble_sort([3, 1, 2])
assert codeflash_output == [1, 2, 3]
"""
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=Path("test_behavior.py"),
perf_file_path=Path("test_perf.py"),
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
# Create test results with multiple loop iterations
original_test_results = TestResults()
optimized_test_results = TestResults()
# Add multiple runs with different runtimes
original_test_results.add(self.create_test_invocation("test_bubble_sort", 600_000, loop_index=1))
original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000, loop_index=2))
original_test_results.add(self.create_test_invocation("test_bubble_sort", 550_000, loop_index=3))
optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 350_000, loop_index=1))
optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000, loop_index=2))
optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 320_000, loop_index=3))
# Test the functionality
result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results)
# Check that minimum times were used (500μs -> 300μs)
modified_source = result.generated_tests[0].generated_original_test_source
assert "# 500μs -> 300μs" in modified_source
def test_no_codeflash_output_assignment(self):
"""Test behavior when test doesn't have codeflash_output assignment."""
test_source = """def test_bubble_sort():
result = bubble_sort([3, 1, 2])
assert result == [1, 2, 3]
"""
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=Path("test_behavior.py"),
perf_file_path=Path("test_perf.py"),
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
# Create test results
original_test_results = TestResults()
optimized_test_results = TestResults()
original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000))
optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000))
# Test the functionality
result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results)
# Check that no comments were added (no codeflash_output assignment)
modified_source = result.generated_tests[0].generated_original_test_source
assert modified_source == test_source # Should be unchanged
def test_invalid_python_code_handling(self):
"""Test behavior when test source code is invalid Python."""
test_source = """def test_bubble_sort(:
codeflash_output = bubble_sort([3, 1, 2])
assert codeflash_output == [1, 2, 3]
""" # Invalid syntax: extra colon
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=Path("test_behavior.py"),
perf_file_path=Path("test_perf.py"),
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
# Create test results
original_test_results = TestResults()
optimized_test_results = TestResults()
original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000))
optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000))
# Test the functionality - should handle parse error gracefully
result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results)
# Check that original test is preserved when parsing fails
modified_source = result.generated_tests[0].generated_original_test_source
assert modified_source == test_source # Should be unchanged due to parse error
def test_multiple_generated_tests(self):
"""Test handling multiple generated test objects."""
test_source_1 = """def test_bubble_sort():
codeflash_output = bubble_sort([3, 1, 2])
assert codeflash_output == [1, 2, 3]
"""
test_source_2 = """def test_quick_sort():
codeflash_output = quick_sort([5, 2, 8])
assert codeflash_output == [2, 5, 8]
"""
generated_test_1 = GeneratedTests(
generated_original_test_source=test_source_1,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=Path("test_behavior_1.py"),
perf_file_path=Path("test_perf_1.py"),
)
generated_test_2 = GeneratedTests(
generated_original_test_source=test_source_2,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=Path("test_behavior_2.py"),
perf_file_path=Path("test_perf_2.py"),
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test_1, generated_test_2])
# Create test results
original_test_results = TestResults()
optimized_test_results = TestResults()
original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000))
original_test_results.add(self.create_test_invocation("test_quick_sort", 800_000))
optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000))
optimized_test_results.add(self.create_test_invocation("test_quick_sort", 600_000))
# Test the functionality
result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results)
# Check that comments were added to both test files
modified_source_1 = result.generated_tests[0].generated_original_test_source
modified_source_2 = result.generated_tests[1].generated_original_test_source
assert "# 500μs -> 300μs" in modified_source_1
assert "# 800μs -> 600μs" in modified_source_2
def test_preserved_test_attributes(self):
"""Test that other test attributes are preserved during modification."""
test_source = """def test_bubble_sort():
codeflash_output = bubble_sort([3, 1, 2])
assert codeflash_output == [1, 2, 3]
"""
original_behavior_source = "behavior test source"
original_perf_source = "perf test source"
original_behavior_path = Path("test_behavior.py")
original_perf_path = Path("test_perf.py")
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source=original_behavior_source,
instrumented_perf_test_source=original_perf_source,
behavior_file_path=original_behavior_path,
perf_file_path=original_perf_path,
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
# Create test results
original_test_results = TestResults()
optimized_test_results = TestResults()
original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000))
optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000))
# Test the functionality
result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results)
# Check that other attributes are preserved
modified_test = result.generated_tests[0]
assert modified_test.instrumented_behavior_test_source == original_behavior_source
assert modified_test.instrumented_perf_test_source == original_perf_source
assert modified_test.behavior_file_path == original_behavior_path
assert modified_test.perf_file_path == original_perf_path
# Check that only the generated_original_test_source was modified
assert "# 500μs -> 300μs" in modified_test.generated_original_test_source
def test_multistatement_line_handling(self):
"""Test that runtime comments work correctly with multiple statements on one line."""
test_source = """def test_mutation_of_input():
# Test that the input list is mutated in-place and returned
arr = [3, 1, 2]
codeflash_output = sorter(arr); result = codeflash_output
assert result == [1, 2, 3]
assert arr == [1, 2, 3] # Input should be mutated
"""
generated_test = GeneratedTests(
generated_original_test_source=test_source,
instrumented_behavior_test_source="",
instrumented_perf_test_source="",
behavior_file_path=Path("test_behavior.py"),
perf_file_path=Path("test_perf.py"),
)
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
# Create test results
original_test_results = TestResults()
optimized_test_results = TestResults()
original_test_results.add(self.create_test_invocation("test_mutation_of_input", 19_000)) # 19μs
optimized_test_results.add(self.create_test_invocation("test_mutation_of_input", 14_000)) # 14μs
# Test the functionality
result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results)
# Check that comments were added to the correct line
modified_source = result.generated_tests[0].generated_original_test_source
assert "# 19.0μs -> 14.0μs" in modified_source
# Verify the comment is on the line with codeflash_output assignment
lines = modified_source.split("\n")
codeflash_line = None
for line in lines:
if "codeflash_output = sorter(arr)" in line:
codeflash_line = line
break
assert codeflash_line is not None, "Could not find codeflash_output assignment line"
assert "# 19.0μs -> 14.0μs" in codeflash_line, f"Comment not found in the correct line: {codeflash_line}"

View file

@ -70,7 +70,7 @@ def sorter(arr):
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
final_output = code_path.read_text(encoding="utf-8")
assert "inconsequential_var = '123'" in final_output
@ -805,7 +805,8 @@ class MainClass:
self.name = name
def main_method(self):
return HelperClass(self.name).helper_method()"""
return HelperClass(self.name).helper_method()
"""
file_path = Path(__file__).resolve()
func_top_optimize = FunctionToOptimize(
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
@ -1140,8 +1141,8 @@ class TestResults(BaseModel):
)
assert (
new_code
== """from __future__ import annotations
new_code
== """from __future__ import annotations
import sys
from codeflash.verification.comparator import comparator
from enum import Enum
@ -1346,8 +1347,8 @@ def cosine_similarity_top_k(
project_root_path=Path(__file__).parent.parent.resolve(),
)
assert (
new_code
== '''import numpy as np
new_code
== '''import numpy as np
from pydantic.dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@dataclass(config=dict(arbitrary_types_allowed=True))
@ -1405,8 +1406,8 @@ def cosine_similarity_top_k(
)
assert (
new_helper_code
== '''import numpy as np
new_helper_code
== '''import numpy as np
from pydantic.dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@dataclass(config=dict(arbitrary_types_allowed=True))
@ -1663,7 +1664,6 @@ print("Hello world")
)
assert new_code == original_code
def test_global_reassignment() -> None:
original_code = """a=1
print("Hello world")
@ -1735,7 +1735,7 @@ class NewClass:
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
@ -1811,7 +1811,7 @@ a=2
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
@ -1888,7 +1888,7 @@ class NewClass:
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
@ -1964,7 +1964,7 @@ class NewClass:
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
@@ -2041,7 +2041,7 @@ class NewClass:
    original_helper_code[helper_function_path] = helper_code
    func_optimizer.args = Args()
    func_optimizer.replace_function_and_helpers_with_optimized_code(
        code_context=code_context, optimized_code=optimized_code
        code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
    )
    new_code = code_path.read_text(encoding="utf-8")
    code_path.unlink(missing_ok=True)
@@ -2129,7 +2129,7 @@ a = 6
    original_helper_code[helper_function_path] = helper_code
    func_optimizer.args = Args()
    func_optimizer.replace_function_and_helpers_with_optimized_code(
        code_context=code_context, optimized_code=optimized_code
        code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
    )
    new_code = code_path.read_text(encoding="utf-8")
    code_path.unlink(missing_ok=True)
@@ -2607,3 +2607,4 @@ def test_something():
    code = final_module.code
    # Should have both modifications
    assert code == expected_code

View file

@@ -1,13 +1,18 @@
import tempfile
from pathlib import Path
import os
import unittest.mock

from codeflash.discovery.functions_to_optimize import (
    filter_files_optimized,
    find_all_functions_in_file,
    get_functions_to_optimize,
    inspect_top_level_functions_or_methods,
    filter_functions,
    get_all_files_and_functions
)
from codeflash.verification.verification_utils import TestConfig
from codeflash.code_utils.compat import codeflash_temp_dir

def test_function_eligible_for_optimization() -> None:
@@ -313,3 +318,241 @@ def test_filter_files_optimized():
    assert filter_files_optimized(file_path_same_level, tests_root, ignore_paths, module_root)
    assert filter_files_optimized(file_path_different_level, tests_root, ignore_paths, module_root)
    assert not filter_files_optimized(file_path_above_level, tests_root, ignore_paths, module_root)
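
# Exercises filter_functions across the main exclusion scenarios: tests dir,
# ignore paths, submodules, site-packages, files outside the module root,
# invalid module names, blocklists, and checkpoints.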
def test_filter_functions():
    with tempfile.TemporaryDirectory() as temp_dir_str:
        temp_dir = Path(temp_dir_str)
        # Create a test file in the temporary directory
        test_file_path = temp_dir.joinpath("test_get_functions_to_optimize.py")
        with test_file_path.open("w") as f:
            f.write(
                """
import copy

def propagate_attributes(
    nodes: dict[str, dict], edges: list[dict], source_node_id: str, attribute: str
) -> dict[str, dict]:
    modified_nodes = copy.deepcopy(nodes)

    # Build an adjacency list for faster traversal
    adjacency = {}
    for edge in edges:
        src = edge["source"]
        tgt = edge["target"]
        if src not in adjacency:
            adjacency[src] = []
        adjacency[src].append(tgt)

    # Track visited nodes to avoid cycles
    visited = set()

    def traverse(node_id):
        if node_id in visited:
            return
        visited.add(node_id)

        # Propagate attribute from source node
        if (
            node_id != source_node_id
            and source_node_id in modified_nodes
            and attribute in modified_nodes[source_node_id]
        ):
            if node_id in modified_nodes:
                modified_nodes[node_id][attribute] = modified_nodes[source_node_id][
                    attribute
                ]

        # Continue propagation to neighbors
        for neighbor in adjacency.get(node_id, []):
            traverse(neighbor)

    traverse(source_node_id)
    return modified_nodes

def vanilla_function():
    return "This is a vanilla function."

def not_in_checkpoint_function():
    return "This function is not in the checkpoint."
"""
            )
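        # Baseline: run the discovered functions through filter_functions with no
        # exclusions; all three top-level functions in the file should survive.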
        discovered = find_all_functions_in_file(test_file_path)
        modified_functions = {test_file_path: discovered[test_file_path]}
        filtered, count = filter_functions(
            modified_functions,
            tests_root=Path("tests"),
            ignore_paths=[],
            project_root=temp_dir,
            module_root=temp_dir,
        )
        function_names = [fn.function_name for fn in filtered.get(test_file_path, [])]
        assert "propagate_attributes" in function_names
        assert count == 3
        # Create a tests directory inside our temp directory
        tests_root_dir = temp_dir.joinpath("tests")
        tests_root_dir.mkdir(exist_ok=True)
        test_file_path = tests_root_dir.joinpath("test_functions.py")
        with test_file_path.open("w") as f:
            f.write(
                """
def test_function_in_tests_dir():
    return "This function is in a test directory and should be filtered out."
"""
            )
        discovered_test_file = find_all_functions_in_file(test_file_path)
        modified_functions_test = {test_file_path: discovered_test_file.get(test_file_path, [])}
        filtered_test_file, count_test_file = filter_functions(
            modified_functions_test,
            tests_root=tests_root_dir,
            ignore_paths=[],
            project_root=temp_dir,
            module_root=temp_dir,
        )
        assert not filtered_test_file
        assert count_test_file == 0
        # Test ignored directory
        ignored_dir = temp_dir.joinpath("ignored_dir")
        ignored_dir.mkdir(exist_ok=True)
        ignored_file_path = ignored_dir.joinpath("ignored_file.py")
        with ignored_file_path.open("w") as f:
            f.write("def ignored_func(): return 1")
        discovered_ignored = find_all_functions_in_file(ignored_file_path)
        modified_functions_ignored = {ignored_file_path: discovered_ignored.get(ignored_file_path, [])}
        filtered_ignored, count_ignored = filter_functions(
            modified_functions_ignored,
            tests_root=Path("tests"),
            ignore_paths=[ignored_dir],
            project_root=temp_dir,
            module_root=temp_dir,
        )
        assert not filtered_ignored
        assert count_ignored == 0
        # Test submodule paths
        with unittest.mock.patch(
            "codeflash.discovery.functions_to_optimize.ignored_submodule_paths",
            return_value=[str(temp_dir.joinpath("submodule_dir"))],
        ):
            submodule_dir = temp_dir.joinpath("submodule_dir")
            submodule_dir.mkdir(exist_ok=True)
            submodule_file_path = submodule_dir.joinpath("submodule_file.py")
            with submodule_file_path.open("w") as f:
                f.write("def submodule_func(): return 1")
            discovered_submodule = find_all_functions_in_file(submodule_file_path)
            modified_functions_submodule = {submodule_file_path: discovered_submodule.get(submodule_file_path, [])}
            filtered_submodule, count_submodule = filter_functions(
                modified_functions_submodule,
                tests_root=Path("tests"),
                ignore_paths=[],
                project_root=temp_dir,
                module_root=temp_dir,
            )
            assert not filtered_submodule
            assert count_submodule == 0
        # Test site packages
        with unittest.mock.patch(
            "codeflash.discovery.functions_to_optimize.path_belongs_to_site_packages",
            return_value=True,
        ):
            site_package_file_path = temp_dir.joinpath("site_package_file.py")
            with site_package_file_path.open("w") as f:
                f.write("def site_package_func(): return 1")
            discovered_site_package = find_all_functions_in_file(site_package_file_path)
            modified_functions_site_package = {site_package_file_path: discovered_site_package.get(site_package_file_path, [])}
            filtered_site_package, count_site_package = filter_functions(
                modified_functions_site_package,
                tests_root=Path("tests"),
                ignore_paths=[],
                project_root=temp_dir,
                module_root=temp_dir,
            )
            assert not filtered_site_package
            assert count_site_package == 0
        # Test outside module root
        parent_dir = temp_dir.parent
        outside_module_root_path = parent_dir.joinpath("outside_module_root_file.py")
        try:
            with outside_module_root_path.open("w") as f:
                f.write("def func_outside_module_root(): return 1")
            discovered_outside_module = find_all_functions_in_file(outside_module_root_path)
            modified_functions_outside_module = {outside_module_root_path: discovered_outside_module.get(outside_module_root_path, [])}
            filtered_outside_module, count_outside_module = filter_functions(
                modified_functions_outside_module,
                tests_root=Path("tests"),
                ignore_paths=[],
                project_root=temp_dir,
                module_root=temp_dir,
            )
            assert not filtered_outside_module
            assert count_outside_module == 0
        finally:
            outside_module_root_path.unlink(missing_ok=True)
        # Test invalid module name
        invalid_module_file_path = temp_dir.joinpath("invalid-module-name.py")
        with invalid_module_file_path.open("w") as f:
            f.write("def func_in_invalid_module(): return 1")
        discovered_invalid_module = find_all_functions_in_file(invalid_module_file_path)
        modified_functions_invalid_module = {invalid_module_file_path: discovered_invalid_module.get(invalid_module_file_path, [])}
        filtered_invalid_module, count_invalid_module = filter_functions(
            modified_functions_invalid_module,
            tests_root=Path("tests"),
            ignore_paths=[],
            project_root=temp_dir,
            module_root=temp_dir,
        )
        assert not filtered_invalid_module
        assert count_invalid_module == 0
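        # Test blocklisted functions: names returned by get_blocklisted_functions
        # for this file are dropped from the results.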
        original_file_path = temp_dir.joinpath("test_get_functions_to_optimize.py")
        with unittest.mock.patch(
            "codeflash.discovery.functions_to_optimize.get_blocklisted_functions",
            return_value={original_file_path.name: {"propagate_attributes", "other_blocklisted_function"}},
        ):
            filtered_funcs, count = filter_functions(
                modified_functions,
                tests_root=Path("tests"),
                ignore_paths=[],
                project_root=temp_dir,
                module_root=temp_dir,
            )
            assert "propagate_attributes" not in [fn.function_name for fn in filtered_funcs.get(original_file_path, [])]
            assert count == 2
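        # Test checkpoint filtering: functions already recorded in a previous
        # checkpoint are skipped, leaving only the unprocessed one.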
module_name = "test_get_functions_to_optimize"
qualified_name_for_checkpoint = f"{module_name}.propagate_attributes"
other_qualified_name_for_checkpoint = f"{module_name}.vanilla_function"
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}):
filtered_checkpoint, count_checkpoint = filter_functions(
modified_functions,
tests_root=Path("tests"),
ignore_paths=[],
project_root=temp_dir,
module_root=temp_dir,
previous_checkpoint_functions={qualified_name_for_checkpoint: {"status": "optimized"}, other_qualified_name_for_checkpoint: {}}
)
assert filtered_checkpoint.get(original_file_path)
assert count_checkpoint == 1
remaining_functions = [fn.function_name for fn in filtered_checkpoint.get(original_file_path, [])]
assert "not_in_checkpoint_function" in remaining_functions
assert "propagate_attributes" not in remaining_functions
assert "vanilla_function" not in remaining_functions
        files_and_funcs = get_all_files_and_functions(module_root_path=temp_dir)
        assert len(files_and_funcs) == 6

View file

@@ -1,8 +1,7 @@
from pathlib import Path
import pytest
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
from codeflash.code_utils.edit_generated_tests import remove_functions_from_generated_tests
from codeflash.models.models import GeneratedTests, GeneratedTestsList

File diff suppressed because it is too large