Initial implementation for tracing benchmarks using a pytest plugin, and projecting speedup

Alvin Ryanputra 2025-02-27 16:25:07 -08:00
parent a73b541159
commit 965e2c818c
22 changed files with 790 additions and 571 deletions
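In outline, the change wires three pieces together: a pytest plugin that replaces the pytest-benchmark `benchmark` fixture and runs each benchmarked call under the codeflash Tracer, writing one SQLite `.trace` file per test; helpers that aggregate per-function and total timings out of those trace files; and an Amdahl-style projection of how a function-level speedup would change each benchmark's total runtime. After this change, a run might be started roughly like the sketch below; only `--benchmark` and `--benchmarks-root` are introduced in this commit, and the `--file` flag is shown as an assumption about the pre-existing CLI:

    codeflash --file code_to_optimize/bubble_sort.py --benchmark --benchmarks-root tests/benchmarks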

View file

@ -0,0 +1,28 @@
from code_to_optimize.bubble_sort import sorter


def calculate_pairwise_products(arr):
    """
    Calculate the average of all pairwise products in the array.
    """
    sum_of_products = 0
    count = 0
    for i in range(len(arr)):
        for j in range(len(arr)):
            if i != j:
                sum_of_products += arr[i] * arr[j]
                count += 1
    # The average of all pairwise products
    return sum_of_products / count if count > 0 else 0


def compute_and_sort(arr):
    # Compute the average of all pairwise products
    pairwise_average = calculate_pairwise_products(arr)
    # Call the sorter function on a copy so the input list is left unmodified
    sorter(arr.copy())
    return pairwise_average
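As a quick illustrative check of calculate_pairwise_products (not part of the commit): for [1, 2, 3] the pairwise products over i != j are 1*2 + 1*3 + 2*1 + 2*3 + 3*1 + 3*2 = 22 with count 6, so the function returns 22 / 6:

    >>> calculate_pairwise_products([1, 2, 3])
    3.6666666666666665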

View file

@ -1,6 +1,13 @@
+import pytest
 from code_to_optimize.bubble_sort import sorter

 def test_sort(benchmark):
     result = benchmark(sorter, list(reversed(range(5000))))
     assert result == list(range(5000))

+# This should not be picked up as a benchmark test
+def test_sort2():
+    result = sorter(list(reversed(range(5000))))
+    assert result == list(range(5000))

View file

@ -0,0 +1,8 @@
from code_to_optimize.process_and_bubble_sort import compute_and_sort
from code_to_optimize.bubble_sort2 import sorter


def test_compute_and_sort(benchmark):
    result = benchmark(compute_and_sort, list(reversed(range(5000))))
    assert result == 6247083.5


def test_no_func(benchmark):
    benchmark(sorter, list(reversed(range(5000))))
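The asserted constant can be verified with the closed form: the sum of a_i * a_j over i != j equals S**2 - Q, where S is the sum of the elements and Q the sum of their squares. A minimal sanity check (illustrative, and O(n) instead of the O(n^2) double loop in calculate_pairwise_products):

    n = 5000
    s = sum(range(n))                 # 12_497_500
    q = sum(x * x for x in range(n))  # 41_654_167_500
    assert (s * s - q) / (n * (n - 1)) == 6247083.5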

View file

View file

@ -0,0 +1,112 @@
import sqlite3
from pathlib import Path

from codeflash.discovery.functions_to_optimize import FunctionToOptimize


def get_function_benchmark_timings(trace_dir: Path, all_functions_to_optimize: list[FunctionToOptimize]) -> dict[str, dict[str, float]]:
    """Process all trace files in the given directory and extract timing data for the specified functions.

    Args:
        trace_dir: Path to the directory containing .trace files
        all_functions_to_optimize: List of FunctionToOptimize objects representing functions to include

    Returns:
        A nested dictionary where:
        - Outer keys are function qualified names with file name
        - Inner keys are benchmark names (trace filename without the .trace extension)
        - Values are function timings in milliseconds

    """
    # Create a mapping of (filename, function_name, class_name) -> qualified_name for efficient lookups
    function_lookup = {}
    function_benchmark_timings = {}
    for func in all_functions_to_optimize:
        qualified_name = func.qualified_name_with_file_name
        # The full file path; the SQL query below matches it with a trailing LIKE
        filename = func.file_path
        function_name = func.function_name
        # Get the class name if there's a parent
        class_name = func.parents[0].name if func.parents else None
        # Store in the lookup dictionary
        key = (filename, function_name, class_name)
        function_lookup[key] = qualified_name
        function_benchmark_timings[qualified_name] = {}

    # Find all .trace files in the directory
    trace_files = list(trace_dir.glob("*.trace"))
    for trace_file in trace_files:
        # Extract the benchmark name from the filename (without .trace)
        benchmark_name = trace_file.stem
        # Connect to the trace database
        conn = sqlite3.connect(trace_file)
        cursor = conn.cursor()
        # For each function we're interested in, query the database directly
        for (filename, function_name, class_name), qualified_name in function_lookup.items():
            # Adjust the query based on whether we have a class name
            if class_name:
                cursor.execute(
                    "SELECT total_time_ns FROM pstats WHERE filename LIKE ? AND function = ? AND class_name = ?",
                    (f"%{filename}", function_name, class_name)
                )
            else:
                cursor.execute(
                    "SELECT total_time_ns FROM pstats WHERE filename LIKE ? AND function = ? AND (class_name IS NULL OR class_name = '')",
                    (f"%{filename}", function_name)
                )
            result = cursor.fetchone()
            if result:
                time_ns = result[0]
                function_benchmark_timings[qualified_name][benchmark_name] = time_ns / 1e6  # Convert to milliseconds
        conn.close()
    return function_benchmark_timings


def get_benchmark_timings(trace_dir: Path) -> dict[str, float]:
    """Extract total benchmark timings from trace files.

    Args:
        trace_dir: Path to the directory containing .trace files

    Returns:
        A dictionary mapping benchmark names to their total execution time in milliseconds.

    """
    benchmark_timings = {}
    # Find all .trace files in the directory
    trace_files = list(trace_dir.glob("*.trace"))
    for trace_file in trace_files:
        # Extract the benchmark name from the filename (without the .trace extension)
        benchmark_name = trace_file.stem
        # Connect to the trace database
        conn = sqlite3.connect(trace_file)
        cursor = conn.cursor()
        # Query the total_time table for the benchmark's total execution time
        try:
            cursor.execute("SELECT time_ns FROM total_time")
            result = cursor.fetchone()
            if result:
                time_ns = result[0]
                # Convert nanoseconds to milliseconds
                benchmark_timings[benchmark_name] = time_ns / 1e6
        except sqlite3.OperationalError:
            # Handle the case where the total_time table might not exist
            print(f"Warning: Could not get total time for benchmark {benchmark_name}")
        conn.close()
    return benchmark_timings
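A minimal usage sketch for these helpers (the directory layout and the pre-built list of functions are assumptions; the .codeflash_trace directory name matches the plugin below):

    from pathlib import Path

    trace_dir = Path("tests/benchmarks") / ".codeflash_trace"
    per_function = get_function_benchmark_timings(trace_dir, all_functions_to_optimize)
    totals = get_benchmark_timings(trace_dir)
    # per_function["/abs/path/bubble_sort.py:sorter"]["test_sort"] -> ms spent in that function
    # totals["test_sort"] -> total ms for the whole benchmark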

View file

@ -0,0 +1,79 @@
from pathlib import Path

import pytest

from codeflash.tracer import Tracer


class CodeFlashPlugin:
    @staticmethod
    def pytest_addoption(parser):
        parser.addoption(
            "--codeflash-trace",
            action="store_true",
            default=False,
            help="Enable CodeFlash tracing"
        )
        parser.addoption(
            "--functions",
            action="store",
            default="",
            help="Comma-separated list of additional functions to trace"
        )
        parser.addoption(
            "--benchmarks-root",
            action="store",
            default=".",
            help="Root directory for benchmarks"
        )

    @staticmethod
    def pytest_plugin_registered(plugin, manager):
        if hasattr(plugin, "name") and plugin.name == "pytest-benchmark":
            manager.unregister(plugin)

    @staticmethod
    def pytest_collection_modifyitems(config, items):
        if not config.getoption("--codeflash-trace"):
            return
        skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
        for item in items:
            if hasattr(item, "fixturenames") and "benchmark" in item.fixturenames:
                continue
            item.add_marker(skip_no_benchmark)

    @staticmethod
    @pytest.fixture
    def benchmark(request):
        if not request.config.getoption("--codeflash-trace"):
            return None

        class Benchmark:
            def __call__(self, func, *args, **kwargs):
                func_name = func.__name__
                test_name = request.node.name
                additional_functions = request.config.getoption("--functions").split(",")
                trace_functions = [f for f in additional_functions if f]
                print("Tracing functions: ", trace_functions)
                # Get the benchmarks root directory from the command line option
                benchmarks_root = Path(request.config.getoption("--benchmarks-root"))
                # Create the trace directory if it doesn't exist
                trace_dir = benchmarks_root / '.codeflash_trace'
                trace_dir.mkdir(exist_ok=True)
                # Set the output path inside the trace directory
                output_path = trace_dir / f"{test_name}.trace"
                tracer = Tracer(
                    output=str(output_path),  # Convert Path to string for Tracer
                    functions=trace_functions,
                    max_function_count=256
                )
                with tracer:
                    result = func(*args, **kwargs)
                return result

        return Benchmark()
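Run standalone, the plugin would be driven roughly like this (hypothetical invocation; the flags mirror pytest_new_process_trace_benchmarks.py below, and -p no:benchmark keeps pytest-benchmark's own fixture out of the way so this one is used):

    pytest tests/benchmarks --codeflash-trace --benchmarks-root tests/benchmarks --functions sorter,compute_and_sort -p no:benchmark -s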

View file

@ -0,0 +1,15 @@
import sys

from plugin.plugin import CodeFlashPlugin

benchmarks_root = sys.argv[1]
function_list = sys.argv[2]

if __name__ == "__main__":
    import pytest

    try:
        exitcode = pytest.main(
            [benchmarks_root, "--benchmarks-root", benchmarks_root, "--codeflash-trace", "-p", "no:benchmark", "-s", "--functions", function_list],
            plugins=[CodeFlashPlugin()],
        )
    except Exception as e:
        print(f"Failed to collect tests: {e!s}")
        exitcode = -1

View file

@ -0,0 +1,20 @@
from __future__ import annotations

import subprocess
from pathlib import Path

from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE


def trace_benchmarks_pytest(benchmarks_root: Path, project_root: Path, function_list: list[str] = []) -> None:
    result = subprocess.run(
        [
            SAFE_SYS_EXECUTABLE,
            Path(__file__).parent / "pytest_new_process_trace_benchmarks.py",
            str(benchmarks_root),
            ",".join(function_list)
        ],
        cwd=project_root,
        check=False,
        capture_output=True,
        text=True,
    )
    print("stdout:", result.stdout)
    print("stderr:", result.stderr)

View file

@ -62,6 +62,10 @@ def parse_args() -> Namespace:
     )
     parser.add_argument("-v", "--verbose", action="store_true", help="Print verbose debug logs")
     parser.add_argument("--version", action="store_true", help="Print the version of codeflash")
+    parser.add_argument("--benchmark", action="store_true", help="Trace benchmark tests and calculate optimization impact on benchmarks")
+    parser.add_argument(
+        "--benchmarks-root", type=str, help="Path to the directory of the project, where all the pytest-benchmark tests are located."
+    )
     args: Namespace = parser.parse_args()
     return process_and_validate_cmd_args(args)

@ -116,6 +120,7 @@ def process_pyproject_config(args: Namespace) -> Namespace:
         "disable_telemetry",
         "disable_imports_sorting",
         "git_remote",
+        "benchmarks_root"
     ]
     for key in supported_keys:
         if key in pyproject_config and (

@ -127,23 +132,17 @@ def process_pyproject_config(args: Namespace) -> Namespace:
     assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory"
     assert args.tests_root is not None, "--tests-root must be specified"
     assert Path(args.tests_root).is_dir(), f"--tests-root {args.tests_root} must be a valid directory"
+    if args.benchmark:
+        assert args.benchmarks_root is not None, "--benchmarks-root must be specified when running with --benchmark"
+        assert Path(args.benchmarks_root).is_dir(), f"--benchmarks-root {args.benchmarks_root} must be a valid directory"
-    if env_utils.get_pr_number() is not None:
-        assert env_utils.ensure_codeflash_api_key(), (
-            "Codeflash API key not found. When running in a Github Actions Context, provide the "
-            "'CODEFLASH_API_KEY' environment variable as a secret.\n"
-            "You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n"
-            "Then, click 'New repository secret' and add your api key with the variable name CODEFLASH_API_KEY.\n"
-            f"Here's a direct link: {get_github_secrets_page_url()}\n"
-            "Exiting..."
-        )
-        repo = git.Repo(search_parent_directories=True)
-        owner, repo_name = get_repo_owner_and_name(repo)
-        require_github_app_or_exit(owner, repo_name)
+    assert not (env_utils.get_pr_number() is not None and not env_utils.ensure_codeflash_api_key()), (
+        "Codeflash API key not found. When running in a Github Actions Context, provide the "
+        "'CODEFLASH_API_KEY' environment variable as a secret.\n"
+        "You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n"
+        "Then, click 'New repository secret' and add your api key with the variable name CODEFLASH_API_KEY.\n"
+        f"Here's a direct link: {get_github_secrets_page_url()}\n"
+        "Exiting..."
+    )
     if hasattr(args, "ignore_paths") and args.ignore_paths is not None:
         normalized_ignore_paths = []
         for path in args.ignore_paths:

View file

@ -107,8 +107,6 @@ def discover_tests_pytest(
             test_type = TestType.REPLAY_TEST
         elif "test_concolic_coverage" in test["test_file"]:
             test_type = TestType.CONCOLIC_COVERAGE_TEST
-        elif test["test_type"] == "benchmark":  # New condition for benchmark tests
-            test_type = TestType.BENCHMARK_TEST
         else:
             test_type = TestType.EXISTING_UNIT_TEST

View file

@ -121,7 +121,6 @@ class FunctionToOptimize:
     method extends this with the module name from the project root.
     """

     function_name: str
     file_path: Path
     parents: list[FunctionParent]  # list[ClassDef | FunctionDef | AsyncFunctionDef]

@ -145,6 +144,11 @@ class FunctionToOptimize:
     def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
         return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"

+    @property
+    def qualified_name_with_file_name(self) -> str:
+        class_name = self.parents[0].name if self.parents else None
+        return f"{self.file_path}:{(class_name + ':' if class_name else '')}{self.function_name}"

 def get_functions_to_optimize(
     optimize_all: str | None,
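For example (paths illustrative), a top-level function yields a key like

    "/repo/code_to_optimize/bubble_sort.py:sorter"

and a method gains the class name in the middle: "/repo/code_to_optimize/bubble_sort.py:BubbleSorter:sorter". This string is both the key of the timing dictionaries in get_trace_info.py above and the value matched against the pstats columns there.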

View file

@ -0,0 +1,54 @@
import sys
from typing import Any

# This script should not have any relation to the codeflash package, be careful with imports
cwd = sys.argv[1]
tests_root = sys.argv[2]
pickle_path = sys.argv[3]
collected_tests = []
pytest_rootdir = None
sys.path.insert(1, str(cwd))


class PytestCollectionPlugin:
    def pytest_collection_finish(self, session) -> None:
        global pytest_rootdir
        collected_tests.extend(session.items)
        pytest_rootdir = session.config.rootdir


def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, str]]:
    test_results = []
    for test in pytest_tests:
        test_class = None
        if test.cls:
            test_class = test.parent.name
        # Determine if this is a benchmark test by checking for the benchmark fixture
        is_benchmark = hasattr(test, 'fixturenames') and 'benchmark' in test.fixturenames
        test_type = 'benchmark' if is_benchmark else 'regular'
        test_results.append({
            "test_file": str(test.path),
            "test_class": test_class,
            "test_function": test.name,
            "test_type": test_type
        })
    return test_results


if __name__ == "__main__":
    import pytest

    try:
        exitcode = pytest.main(
            [tests_root, "-pno:logging", "--collect-only", "-m", "not skip"], plugins=[PytestCollectionPlugin()]
        )
    except Exception as e:
        print(f"Failed to collect tests: {e!s}")
        exitcode = -1

    tests = parse_pytest_collection_results(collected_tests)
    import pickle

    with open(pickle_path, "wb") as f:
        pickle.dump((exitcode, tests, pytest_rootdir), f, protocol=pickle.HIGHEST_PROTOCOL)
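The pickled payload consumed by the parent process has this shape (illustrative values):

    # (exitcode, tests, pytest_rootdir), where tests is e.g.
    # [{"test_file": "/repo/tests/benchmarks/test_benchmark_bubble_sort.py",
    #   "test_class": None,
    #   "test_function": "test_sort",
    #   "test_type": "benchmark"}]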

View file

@ -29,17 +29,7 @@ def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, str]]:
         test_class = None
         if test.cls:
             test_class = test.parent.name
-        # Determine if this is a benchmark test by checking for the benchmark fixture
-        is_benchmark = hasattr(test, 'fixturenames') and 'benchmark' in test.fixturenames
-        test_type = 'benchmark' if is_benchmark else 'regular'
-        test_results.append({
-            "test_file": str(test.path),
-            "test_class": test_class,
-            "test_function": test.name,
-            "test_type": test_type
-        })
+        test_results.append({"test_file": str(test.path), "test_class": test_class, "test_function": test.name})
     return test_results

View file

@ -7,7 +7,7 @@ import shutil
 import subprocess
 import time
 import uuid
-from collections import defaultdict, deque
+from collections import defaultdict
 from pathlib import Path
 from typing import TYPE_CHECKING

@ -21,12 +21,12 @@ from rich.tree import Tree
 from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
 from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
 from codeflash.code_utils import env_utils
+from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code
 from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
 from codeflash.code_utils.code_utils import (
     cleanup_paths,
     file_name_from_test_module_name,
     get_run_tmp_file,
-    has_any_async_functions,
     module_name_from_file_path,
 )
 from codeflash.code_utils.config_consts import (

@ -37,7 +37,6 @@ from codeflash.code_utils.config_consts import (
 )
 from codeflash.code_utils.formatter import format_code, sort_imports
 from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
-from codeflash.code_utils.line_profile_utils import add_decorator_imports
 from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
 from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
 from codeflash.code_utils.time_utils import humanize_runtime

@ -49,6 +48,7 @@ from codeflash.models.models import (
     BestOptimization,
     CodeOptimizationContext,
     FunctionCalledInTest,
+    FunctionParent,
     GeneratedTests,
     GeneratedTestsList,
     OptimizationSet,

@ -57,9 +57,8 @@ from codeflash.models.models import (
     TestFile,
     TestFiles,
     TestingMode,
-    TestResults,
-    TestType,
 )
+from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions
 from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
 from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
 from codeflash.result.explanation import Explanation

@ -67,15 +66,18 @@ from codeflash.telemetry.posthog_cf import ph
 from codeflash.verification.concolic_testing import generate_concolic_tests
 from codeflash.verification.equivalence import compare_test_results
 from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
-from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
 from codeflash.verification.parse_test_output import parse_test_results
-from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_line_profile_tests
+from codeflash.verification.test_results import TestResults, TestType
+from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests
 from codeflash.verification.verification_utils import get_test_file_path
 from codeflash.verification.verifier import generate_tests

 if TYPE_CHECKING:
     from argparse import Namespace

+    import numpy as np
+    import numpy.typing as npt
+
     from codeflash.either import Result
     from codeflash.models.models import CoverageData, FunctionSource, OptimizedCandidate
     from codeflash.verification.verification_utils import TestConfig
@ -90,6 +92,8 @@ class FunctionOptimizer:
         function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
         function_to_optimize_ast: ast.FunctionDef | None = None,
         aiservice_client: AiServiceClient | None = None,
+        function_benchmark_timings: dict[str, dict[str, float]] | None = None,
+        total_benchmark_timings: dict[str, float] | None = None,
         args: Namespace | None = None,
     ) -> None:
         self.project_root = test_cfg.project_root_path

@ -118,6 +122,9 @@ class FunctionOptimizer:
         self.function_trace_id: str = str(uuid.uuid4())
         self.original_module_path = module_name_from_file_path(self.function_to_optimize.file_path, self.project_root)
+        self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {}
+        self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {}

     def optimize_function(self) -> Result[BestOptimization, str]:
         should_run_experiment = self.experiment_id is not None
         logger.debug(f"Function Trace ID: {self.function_trace_id}")
@ -134,10 +141,19 @@ class FunctionOptimizer:
             with helper_function_path.open(encoding="utf8") as f:
                 helper_code = f.read()
                 original_helper_code[helper_function_path] = helper_code
-        if has_any_async_functions(code_context.read_writable_code):
-            return Failure("Codeflash does not support async functions in the code to optimize.")
+        logger.info("Code to be optimized:")
         code_print(code_context.read_writable_code)
+        for module_abspath, helper_code_source in original_helper_code.items():
+            code_context.code_to_optimize_with_helpers = add_needed_imports_from_module(
+                helper_code_source,
+                code_context.code_to_optimize_with_helpers,
+                module_abspath,
+                self.function_to_optimize.file_path,
+                self.args.project_root,
+            )
         generated_test_paths = [
             get_test_file_path(
                 self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="unit"

@ -156,7 +172,7 @@ class FunctionOptimizer:
             transient=True,
         ):
             generated_results = self.generate_tests_and_optimizations(
-                testgen_context_code=code_context.testgen_context_code,
+                code_to_optimize_with_helpers=code_context.code_to_optimize_with_helpers,
                 read_writable_code=code_context.read_writable_code,
                 read_only_context_code=code_context.read_only_context_code,
                 helper_functions=code_context.helper_functions,
@ -232,11 +248,10 @@ class FunctionOptimizer:
         ):
             cleanup_paths(paths_to_cleanup)
             return Failure("The threshold for test coverage was not met.")
-        # request for new optimizations but don't block execution, check for completion later
-        # adding to control and experiment set but with same traceid
         best_optimization = None
-        for _u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]):
+        for u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]):
             if candidates is None:
                 continue
@ -270,6 +285,20 @@ class FunctionOptimizer:
                 function_name=function_to_optimize_qualified_name,
                 file_path=self.function_to_optimize.file_path,
             )
+            speedup = explanation.speedup  # e.g. 1.2 means 1.2x faster
+            if self.args.benchmark:
+                fto_benchmark_timings = self.function_benchmark_timings[self.function_to_optimize.qualified_name_with_file_name]
+                for benchmark_name, og_benchmark_timing in fto_benchmark_timings.items():
+                    print(f"Calculating speedup for benchmark {benchmark_name}")
+                    total_benchmark_timing = self.total_benchmark_timings[benchmark_name]
+                    # find the expected new benchmark timing, then calculate how much the total benchmark was sped up; print intermediate values
+                    expected_new_benchmark_timing = total_benchmark_timing - og_benchmark_timing + og_benchmark_timing / speedup
+                    print(f"Expected new benchmark timing: {expected_new_benchmark_timing}")
+                    print(f"Original benchmark timing: {total_benchmark_timing}")
+                    print(f"Benchmark speedup: {total_benchmark_timing / expected_new_benchmark_timing}")
+                    speedup = total_benchmark_timing / expected_new_benchmark_timing
+                    print(f"Speedup: {speedup}")

             self.log_successful_optimization(explanation, generated_tests)
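Worked numbers for the projection above (illustrative): if a benchmark takes 1000 ms in total and 800 ms of that is spent in the optimized function, a 2x function-level speedup gives

    expected_new_benchmark_timing = 1000 - 800 + 800 / 2.0  # 600 ms
    benchmark_speedup = 1000 / 600                           # ~1.67x

so halving an 80%-dominant function speeds the whole benchmark up by about 1.67x, exactly as Amdahl's law predicts. Note that because speedup is reassigned inside the loop, a second benchmark in the same report would be projected from the previous benchmark's ratio rather than the function-level speedup; that appears to be a quirk of this initial implementation.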
@ -359,123 +388,94 @@ class FunctionOptimizer:
             f"{self.function_to_optimize.qualified_name}"
         )
         console.rule()
-        candidates = deque(candidates)
-        # Start a new thread for AI service request, start loop in main thread
-        # check if aiservice request is complete, when it is complete, append result to the candidates list
-        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
-            future_line_profile_results = executor.submit(
-                self.aiservice_client.optimize_python_code_line_profiler,
-                source_code=code_context.read_writable_code,
-                dependency_code=code_context.read_only_context_code,
-                trace_id=self.function_trace_id,
-                line_profiler_results=original_code_baseline.line_profile_results["str_out"],
-                num_candidates=10,
-                experiment_metadata=None,
-            )
-            try:
-                candidate_index = 0
-                done = False
-                original_len = len(candidates)
-                while candidates:
-                    # for candidate_index, candidate in enumerate(candidates, start=1):
-                    done = True if future_line_profile_results is None else future_line_profile_results.done()
-                    if done and (future_line_profile_results is not None):
-                        line_profile_results = future_line_profile_results.result()
-                        candidates.extend(line_profile_results)
-                        original_len += len(line_profile_results)
-                        logger.info(f"Added {len(line_profile_results)} results from line profiler to candidates, total candidates now: {original_len}")
-                        future_line_profile_results = None
-                    candidate_index += 1
-                    candidate = candidates.popleft()
-                    get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
-                    get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
-                    logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
-                    code_print(candidate.source_code)
-                    try:
-                        did_update = self.replace_function_and_helpers_with_optimized_code(
-                            code_context=code_context, optimized_code=candidate.source_code
-                        )
-                        if not did_update:
-                            logger.warning(
-                                "No functions were replaced in the optimized code. Skipping optimization candidate."
-                            )
-                            console.rule()
-                            continue
-                    except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
-                        logger.error(e)
-                        self.write_code_and_helpers(
-                            self.function_to_optimize_source_code,
-                            original_helper_code,
-                            self.function_to_optimize.file_path,
-                        )
-                        continue
-                    # Instrument codeflash capture
-                    run_results = self.run_optimized_candidate(
-                        optimization_candidate_index=candidate_index,
-                        baseline_results=original_code_baseline,
-                        original_helper_code=original_helper_code,
-                        file_path_to_helper_classes=file_path_to_helper_classes,
-                    )
-                    console.rule()
-
-                    if not is_successful(run_results):
-                        optimized_runtimes[candidate.optimization_id] = None
-                        is_correct[candidate.optimization_id] = False
-                        speedup_ratios[candidate.optimization_id] = None
-                    else:
-                        candidate_result: OptimizedCandidateResult = run_results.unwrap()
-                        best_test_runtime = candidate_result.best_test_runtime
-                        optimized_runtimes[candidate.optimization_id] = best_test_runtime
-                        is_correct[candidate.optimization_id] = True
-                        perf_gain = performance_gain(
-                            original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_test_runtime
-                        )
-                        speedup_ratios[candidate.optimization_id] = perf_gain
-
-                        tree = Tree(f"Candidate #{candidate_index} - Runtime Information")
-                        if speedup_critic(
-                            candidate_result, original_code_baseline.runtime, best_runtime_until_now
-                        ) and quantity_of_tests_critic(candidate_result):
-                            tree.add("This candidate is faster than the previous best candidate. 🚀")
-                            tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")
-                            tree.add(
-                                f"Best summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} "
-                                f"(measured over {candidate_result.max_loop_count} "
-                                f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
-                            )
-                            tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
-                            tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X")
-                            best_optimization = BestOptimization(
-                                candidate=candidate,
-                                helper_functions=code_context.helper_functions,
-                                runtime=best_test_runtime,
-                                winning_behavioral_test_results=candidate_result.behavior_test_results,
-                                winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
-                            )
-                            best_runtime_until_now = best_test_runtime
-                        else:
-                            tree.add(
-                                f"Summed runtime: {humanize_runtime(best_test_runtime)} "
-                                f"(measured over {candidate_result.max_loop_count} "
-                                f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
-                            )
-                            tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
-                            tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
-                        console.print(tree)
-                        console.rule()
-
-                self.write_code_and_helpers(
-                    self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
-                )
-            except KeyboardInterrupt as e:
-                self.write_code_and_helpers(
-                    self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
-                )
-                logger.exception(f"Optimization interrupted: {e}")
-                raise
+        try:
+            for candidate_index, candidate in enumerate(candidates, start=1):
+                get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
+                get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
+                logger.info(f"Optimization candidate {candidate_index}/{len(candidates)}:")
+                code_print(candidate.source_code)
+                try:
+                    did_update = self.replace_function_and_helpers_with_optimized_code(
+                        code_context=code_context, optimized_code=candidate.source_code
+                    )
+                    if not did_update:
+                        logger.warning(
+                            "No functions were replaced in the optimized code. Skipping optimization candidate."
+                        )
+                        console.rule()
+                        continue
+                except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
+                    logger.error(e)
+                    self.write_code_and_helpers(
+                        self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
+                    )
+                    continue
+                # Instrument codeflash capture
+                run_results = self.run_optimized_candidate(
+                    optimization_candidate_index=candidate_index,
+                    baseline_results=original_code_baseline,
+                    original_helper_code=original_helper_code,
+                    file_path_to_helper_classes=file_path_to_helper_classes,
+                )
+                console.rule()
+
+                if not is_successful(run_results):
+                    optimized_runtimes[candidate.optimization_id] = None
+                    is_correct[candidate.optimization_id] = False
+                    speedup_ratios[candidate.optimization_id] = None
+                else:
+                    candidate_result: OptimizedCandidateResult = run_results.unwrap()
+                    best_test_runtime = candidate_result.best_test_runtime
+                    optimized_runtimes[candidate.optimization_id] = best_test_runtime
+                    is_correct[candidate.optimization_id] = True
+                    perf_gain = performance_gain(
+                        original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_test_runtime
+                    )
+                    speedup_ratios[candidate.optimization_id] = perf_gain
+
+                    tree = Tree(f"Candidate #{candidate_index} - Runtime Information")
+                    if speedup_critic(
+                        candidate_result, original_code_baseline.runtime, best_runtime_until_now
+                    ) and quantity_of_tests_critic(candidate_result):
+                        tree.add("This candidate is faster than the previous best candidate. 🚀")
+                        tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")
+                        tree.add(
+                            f"Best summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} "
+                            f"(measured over {candidate_result.max_loop_count} "
+                            f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
+                        )
+                        tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
+                        tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X")
+                        best_optimization = BestOptimization(
+                            candidate=candidate,
+                            helper_functions=code_context.helper_functions,
+                            runtime=best_test_runtime,
+                            winning_behavioral_test_results=candidate_result.behavior_test_results,
+                            winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
+                        )
+                        best_runtime_until_now = best_test_runtime
+                    else:
+                        tree.add(
+                            f"Summed runtime: {humanize_runtime(best_test_runtime)} "
+                            f"(measured over {candidate_result.max_loop_count} "
+                            f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
+                        )
+                        tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
+                        tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
+                    console.print(tree)
+                    console.rule()
+
+            self.write_code_and_helpers(
+                self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
+            )
+        except KeyboardInterrupt as e:
+            self.write_code_and_helpers(
+                self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
+            )
+            logger.exception(f"Optimization interrupted: {e}")
+            raise

         self.aiservice_client.log_results(
             function_trace_id=self.function_trace_id,
@ -575,6 +575,50 @@ class FunctionOptimizer:
         return did_update

     def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
+        code_to_optimize, contextual_dunder_methods = extract_code([self.function_to_optimize])
+        if code_to_optimize is None:
+            return Failure("Could not find function to optimize.")
+        (helper_code, helper_functions, helper_dunder_methods) = get_constrained_function_context_and_helper_functions(
+            self.function_to_optimize, self.project_root, code_to_optimize
+        )
+        if self.function_to_optimize.parents:
+            function_class = self.function_to_optimize.parents[0].name
+            same_class_helper_methods = [
+                df
+                for df in helper_functions
+                if df.qualified_name.count(".") > 0 and df.qualified_name.split(".")[0] == function_class
+            ]
+            optimizable_methods = [
+                FunctionToOptimize(
+                    df.qualified_name.split(".")[-1],
+                    df.file_path,
+                    [FunctionParent(df.qualified_name.split(".")[0], "ClassDef")],
+                    None,
+                    None,
+                )
+                for df in same_class_helper_methods
+            ] + [self.function_to_optimize]
+            dedup_optimizable_methods = []
+            added_methods = set()
+            for method in reversed(optimizable_methods):
+                if f"{method.file_path}.{method.qualified_name}" not in added_methods:
+                    dedup_optimizable_methods.append(method)
+                    added_methods.add(f"{method.file_path}.{method.qualified_name}")
+            if len(dedup_optimizable_methods) > 1:
+                code_to_optimize, contextual_dunder_methods = extract_code(list(reversed(dedup_optimizable_methods)))
+                if code_to_optimize is None:
+                    return Failure("Could not find function to optimize.")
+        code_to_optimize_with_helpers = helper_code + "\n" + code_to_optimize
+        code_to_optimize_with_helpers_and_imports = add_needed_imports_from_module(
+            self.function_to_optimize_source_code,
+            code_to_optimize_with_helpers,
+            self.function_to_optimize.file_path,
+            self.function_to_optimize.file_path,
+            self.project_root,
+            helper_functions,
+        )
         try:
             new_code_ctx = code_context_extractor.get_code_optimization_context(
                 self.function_to_optimize, self.project_root
@ -584,7 +628,7 @@ class FunctionOptimizer:
         return Success(
             CodeOptimizationContext(
-                testgen_context_code=new_code_ctx.testgen_context_code,
+                code_to_optimize_with_helpers=code_to_optimize_with_helpers_and_imports,
                 read_writable_code=new_code_ctx.read_writable_code,
                 read_only_context_code=new_code_ctx.read_only_context_code,
                 helper_functions=new_code_ctx.helper_functions,  # only functions that are read writable

@ -686,7 +730,7 @@ class FunctionOptimizer:
     def generate_tests_and_optimizations(
         self,
-        testgen_context_code: str,
+        code_to_optimize_with_helpers: str,
         read_writable_code: str,
         read_only_context_code: str,
         helper_functions: list[FunctionSource],

@ -701,7 +745,7 @@ class FunctionOptimizer:
         # Submit the test generation task as future
         future_tests = self.generate_and_instrument_tests(
             executor,
-            testgen_context_code,
+            code_to_optimize_with_helpers,
             [definition.fully_qualified_name for definition in helper_functions],
             generated_test_paths,
             generated_perf_test_paths,
@ -790,7 +834,6 @@ class FunctionOptimizer:
         original_helper_code: dict[Path, str],
         file_path_to_helper_classes: dict[Path, set[str]],
     ) -> Result[tuple[OriginalCodeBaseline, list[str]], str]:
-        line_profile_results = {"timings": {}, "unit": 0, "str_out": ""}
         # For the original function - run the tests and get the runtime, plus coverage
         with progress_bar(f"Establishing original code baseline for {self.function_to_optimize.function_name}"):
             assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
@ -831,31 +874,11 @@ class FunctionOptimizer:
                 )
                 console.rule()
                 return Failure("Failed to establish a baseline for the original code - behavioral tests failed.")
-            if not coverage_critic(coverage_results, self.args.test_framework):
+            if not coverage_critic(
+                coverage_results, self.args.test_framework
+            ):
                 return Failure("The threshold for test coverage was not met.")
             if test_framework == "pytest":
-                try:
-                    line_profiler_output_file = add_decorator_imports(self.function_to_optimize, code_context)
-                    line_profile_results, _ = self.run_and_parse_tests(
-                        testing_type=TestingMode.LINE_PROFILE,
-                        test_env=test_env,
-                        test_files=self.test_files,
-                        optimization_iteration=0,
-                        testing_time=TOTAL_LOOPING_TIME,
-                        enable_coverage=False,
-                        code_context=code_context,
-                        line_profiler_output_file=line_profiler_output_file,
-                    )
-                finally:
-                    # Remove codeflash capture
-                    self.write_code_and_helpers(
-                        self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
-                    )
-                if line_profile_results["str_out"] == "":
-                    logger.warning(
-                        f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
-                    )
-                    console.rule()
                 benchmarking_results, _ = self.run_and_parse_tests(
                     testing_type=TestingMode.PERFORMANCE,
                     test_env=test_env,
@ -894,6 +917,7 @@ class FunctionOptimizer:
             )
             console.rule()
             total_timing = benchmarking_results.total_passed_runtime()  # caution: doesn't handle the loop index
+
             functions_to_remove = [
                 result.id.test_function_name
@ -927,7 +951,6 @@ class FunctionOptimizer:
                 benchmarking_test_results=benchmarking_results,
                 runtime=total_timing,
                 coverage_results=coverage_results,
-                line_profile_results=line_profile_results,
             ),
             functions_to_remove,
         )
@ -1063,77 +1086,59 @@ class FunctionOptimizer:
         pytest_max_loops: int = 100_000,
         code_context: CodeOptimizationContext | None = None,
         unittest_loop_index: int | None = None,
-        line_profiler_output_file: Path | None = None,
-    ) -> tuple[TestResults | dict, CoverageData | None]:
+    ) -> tuple[TestResults, CoverageData | None]:
         coverage_database_file = None
-        coverage_config_file = None
         try:
             if testing_type == TestingMode.BEHAVIOR:
-                result_file_path, run_result, coverage_database_file, coverage_config_file = run_behavioral_tests(
+                result_file_path, run_result, coverage_database_file = run_behavioral_tests(
                     test_files,
                     test_framework=self.test_cfg.test_framework,
                     cwd=self.project_root,
                     test_env=test_env,
                     pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
-                    pytest_cmd=self.test_cfg.pytest_cmd,
                     verbose=True,
                     enable_coverage=enable_coverage,
                 )
-            elif testing_type == TestingMode.LINE_PROFILE:
-                result_file_path, run_result = run_line_profile_tests(
-                    test_files,
-                    cwd=self.project_root,
-                    test_env=test_env,
-                    pytest_cmd=self.test_cfg.pytest_cmd,
-                    pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
-                    pytest_target_runtime_seconds=testing_time,
-                    pytest_min_loops=pytest_min_loops,
-                    pytest_max_loops=pytest_min_loops,
-                    test_framework=self.test_cfg.test_framework,
-                    line_profiler_output_file=line_profiler_output_file,
-                )
             elif testing_type == TestingMode.PERFORMANCE:
                 result_file_path, run_result = run_benchmarking_tests(
                     test_files,
                     cwd=self.project_root,
                     test_env=test_env,
-                    pytest_cmd=self.test_cfg.pytest_cmd,
                     pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
+                    pytest_cmd=self.test_cfg.pytest_cmd,
                    pytest_target_runtime_seconds=testing_time,
                     pytest_min_loops=pytest_min_loops,
                     pytest_max_loops=pytest_max_loops,
                     test_framework=self.test_cfg.test_framework,
                 )
             else:
-                msg = f"Unexpected testing type: {testing_type}"
-                raise ValueError(msg)
+                raise ValueError(f"Unexpected testing type: {testing_type}")
         except subprocess.TimeoutExpired:
             logger.exception(
-                f"Error running tests in {', '.join(str(f) for f in test_files.test_files)}.\nTimeout Error"
+                f'Error running tests in {", ".join(str(f) for f in test_files.test_files)}.\nTimeout Error'
             )
             return TestResults(), None
         if run_result.returncode != 0 and testing_type == TestingMode.BEHAVIOR:
             logger.debug(
-                f"Nonzero return code {run_result.returncode} when running tests in "
-                f"{', '.join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n"
+                f'Nonzero return code {run_result.returncode} when running tests in '
+                f'{", ".join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n'
                 f"stdout: {run_result.stdout}\n"
                 f"stderr: {run_result.stderr}\n"
             )
-        if testing_type in [TestingMode.BEHAVIOR, TestingMode.PERFORMANCE]:
-            results, coverage_results = parse_test_results(
-                test_xml_path=result_file_path,
-                test_files=test_files,
-                test_config=self.test_cfg,
-                optimization_iteration=optimization_iteration,
-                run_result=run_result,
-                unittest_loop_index=unittest_loop_index,
-                function_name=self.function_to_optimize.function_name,
-                source_file=self.function_to_optimize.file_path,
-                code_context=code_context,
-                coverage_database_file=coverage_database_file,
-                coverage_config_file=coverage_config_file,
-            )
-        else:
-            results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
+        # print(test_files)
+        results, coverage_results = parse_test_results(
+            test_xml_path=result_file_path,
+            test_files=test_files,
+            test_config=self.test_cfg,
+            optimization_iteration=optimization_iteration,
+            run_result=run_result,
+            unittest_loop_index=unittest_loop_index,
+            function_name=self.function_to_optimize.function_name,
+            source_file=self.function_to_optimize.file_path,
+            code_context=code_context,
+            coverage_database_file=coverage_database_file,
+        )
         return results, coverage_results

     def generate_and_instrument_tests(
@ -1163,3 +1168,4 @@ class FunctionOptimizer:
             zip(generated_test_paths, generated_perf_test_paths)
         )
     ]
+

View file

@ -8,7 +8,8 @@ from pathlib import Path
 from typing import TYPE_CHECKING

 from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
-from codeflash.cli_cmds.console import console, logger, progress_bar
+from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
+from codeflash.cli_cmds.console import console, logger
 from codeflash.code_utils import env_utils
 from codeflash.code_utils.code_replacer import normalize_code, normalize_node
 from codeflash.code_utils.code_utils import get_run_tmp_file

@ -16,10 +17,12 @@ from codeflash.code_utils.static_analysis import analyze_imported_modules, get_first_top_level_function_or_method_ast
 from codeflash.discovery.discover_unit_tests import discover_unit_tests
 from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
 from codeflash.either import is_successful
-from codeflash.models.models import TestType, ValidCode
+from codeflash.models.models import TestFiles, ValidCode
 from codeflash.optimization.function_optimizer import FunctionOptimizer
 from codeflash.telemetry.posthog_cf import ph
+from codeflash.verification.test_results import TestType
 from codeflash.verification.verification_utils import TestConfig
+from codeflash.benchmarking.get_trace_info import get_function_benchmark_timings, get_benchmark_timings

 if TYPE_CHECKING:
     from argparse import Namespace
@ -50,6 +53,8 @@ class Optimizer:
         function_to_optimize_ast: ast.FunctionDef | None = None,
         function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
         function_to_optimize_source_code: str | None = "",
+        function_benchmark_timings: dict[str, dict[str, float]] | None = None,
+        total_benchmark_timings: dict[str, float] | None = None,
     ) -> FunctionOptimizer:
         return FunctionOptimizer(
             function_to_optimize=function_to_optimize,

@ -59,6 +64,8 @@ class Optimizer:
             function_to_optimize_ast=function_to_optimize_ast,
             aiservice_client=self.aiservice_client,
             args=self.args,
+            function_benchmark_timings=function_benchmark_timings if function_benchmark_timings else None,
+            total_benchmark_timings=total_benchmark_timings if total_benchmark_timings else None,
         )

     def run(self) -> None:
@ -80,6 +87,23 @@ class Optimizer:
             project_root=self.args.project_root,
             module_root=self.args.module_root,
         )
+        if self.args.benchmark:
+            all_functions_to_optimize = [
+                function
+                for functions_list in file_to_funcs_to_optimize.values()
+                for function in functions_list
+            ]
+            logger.info(f"Tracing existing benchmarks for {len(all_functions_to_optimize)} functions")
+            trace_benchmarks_pytest(self.args.benchmarks_root, self.args.project_root, [fto.qualified_name_with_file_name for fto in all_functions_to_optimize])
+            logger.info("Finished tracing existing benchmarks")
+            trace_dir = Path(self.args.benchmarks_root) / ".codeflash_trace"
+            function_benchmark_timings = get_function_benchmark_timings(trace_dir, all_functions_to_optimize)
+            print(function_benchmark_timings)
+            total_benchmark_timings = get_benchmark_timings(trace_dir)
+            print("Total benchmark timings:")
+            print(total_benchmark_timings)
+            # for function in fully_qualified_function_names:

         optimizations_found: int = 0
         function_iterator_count: int = 0
@ -93,6 +117,8 @@ class Optimizer:
             logger.info("No functions found to optimize. Exiting…")
             return

+        console.rule()
+        logger.info(f"Discovering existing unit tests in {self.test_cfg.tests_root}")
         console.rule()
         function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg)
         num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])
@ -136,7 +162,6 @@ class Optimizer:
                     validated_original_code[analysis.file_path] = ValidCode(
                         source_code=callee_original_code, normalized_code=normalized_callee_original_code
                     )
-
                 if has_syntax_error:
                     continue
@ -146,7 +171,7 @@ class Optimizer:
                     f"Optimizing function {function_iterator_count} of {num_optimizable_functions}: "
                     f"{function_to_optimize.qualified_name}"
                 )
-
+                console.rule()
                 if not (
                     function_to_optimize_ast := get_first_top_level_function_or_method_ast(
                         function_to_optimize.function_name, function_to_optimize.parents, original_module_ast
@ -157,12 +182,17 @@ class Optimizer:
                         f"Skipping optimization."
                     )
                     continue
-                function_optimizer = self.create_function_optimizer(
-                    function_to_optimize,
-                    function_to_optimize_ast,
-                    function_to_tests,
-                    validated_original_code[original_module_path].source_code,
-                )
+                if self.args.benchmark:
+                    function_optimizer = self.create_function_optimizer(
+                        function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code, function_benchmark_timings, total_benchmark_timings
+                    )
+                else:
+                    function_optimizer = self.create_function_optimizer(
+                        function_to_optimize, function_to_optimize_ast, function_to_tests,
+                        validated_original_code[original_module_path].source_code
+                    )

                 best_optimization = function_optimizer.optimize_function()
                 if is_successful(best_optimization):
                     optimizations_found += 1
@ -191,6 +221,7 @@ class Optimizer:
         get_run_tmp_file.tmpdir.cleanup()

+
 def run_with_args(args: Namespace) -> None:
     optimizer = Optimizer(args)
     optimizer.run()

View file

@ -18,21 +18,19 @@ import marshal
 import os
 import pathlib
 import pickle
-import re
 import sqlite3
 import sys
-import threading
 import time
-from argparse import ArgumentParser
 from collections import defaultdict
-from copy import copy
-from io import StringIO
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, ClassVar
+from types import FrameType
+from typing import Any, ClassVar, List

 import dill
 import isort
-from rich.align import Align
-from rich.panel import Panel
-from rich.table import Table
-from rich.text import Text

 from codeflash.cli_cmds.cli import project_root_from_module_root
 from codeflash.cli_cmds.console import console
from codeflash.cli_cmds.console import console from codeflash.cli_cmds.console import console
@ -42,34 +40,14 @@ from codeflash.discovery.functions_to_optimize import filter_files_optimized
 from codeflash.tracing.replay_test import create_trace_replay_test
 from codeflash.tracing.tracing_utils import FunctionModules
 from codeflash.verification.verification_utils import get_test_file_path

-if TYPE_CHECKING:
-    from types import FrameType, TracebackType
-
-
-class FakeCode:
-    def __init__(self, filename: str, line: int, name: str) -> None:
-        self.co_filename = filename
-        self.co_line = line
-        self.co_name = name
-        self.co_firstlineno = 0
-
-    def __repr__(self) -> str:
-        return repr((self.co_filename, self.co_line, self.co_name, None))
-
-
-class FakeFrame:
-    def __init__(self, code: FakeCode, prior: FakeFrame | None) -> None:
-        self.f_code = code
-        self.f_back = prior
-        self.f_locals: dict = {}
+# import warnings
+# warnings.filterwarnings("ignore", category=dill.PickleWarning)
+# warnings.filterwarnings("ignore", category=DeprecationWarning)

 # Debug this file by simply adding print statements. This file is not meant to be debugged by the debugger.
 class Tracer:
-    """Use this class as a 'with' context manager to trace a function call.
-
-    Traces function calls, input arguments, and profiling info.
-    """
+    """Use this class as a 'with' context manager to trace a function call,
+    input arguments, and profiling info.
+    """
def __init__( def __init__(
@ -81,9 +59,7 @@ class Tracer:
         max_function_count: int = 256,
         timeout: int | None = None,  # seconds
     ) -> None:
-        """Use this class to trace function calls.
-
-        :param output: The path to the output trace file
+        """:param output: The path to the output trace file
         :param functions: List of functions to trace. If None, trace all functions
         :param disable: Disable the tracer if True
         :param config_file_path: Path to the pyproject.toml file, if None then it will be auto-discovered
@ -94,9 +70,7 @@ class Tracer:
if functions is None: if functions is None:
functions = [] functions = []
if os.environ.get("CODEFLASH_TRACER_DISABLE", "0") == "1": if os.environ.get("CODEFLASH_TRACER_DISABLE", "0") == "1":
console.rule( console.print("Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE")
"Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE", style="bold red"
)
disable = True disable = True
self.disable = disable self.disable = disable
if self.disable: if self.disable:
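
Note: because the CODEFLASH_TRACER_DISABLE check runs before any state is set up, tracing can be switched off per process without touching call sites. A small sketch:

import os

# Any Tracer entered in this process after this line becomes a no-op,
# per the environment check above.
os.environ["CODEFLASH_TRACER_DISABLE"] = "1"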
@@ -111,7 +85,7 @@ class Tracer:
         self.con = None
         self.output_file = Path(output).resolve()
         self.functions = functions
-        self.function_modules: list[FunctionModules] = []
+        self.function_modules: List[FunctionModules] = []
         self.function_count = defaultdict(int)
         self.current_file_path = Path(__file__).resolve()
         self.ignored_qualified_functions = {
@@ -121,10 +95,10 @@ class Tracer:
         self.max_function_count = max_function_count
         self.config, found_config_path = parse_config_file(config_file_path)
         self.project_root = project_root_from_module_root(Path(self.config["module_root"]), found_config_path)
-        console.rule(f"Project Root: {self.project_root}", style="bold blue")
+        print("project_root", self.project_root)
         self.ignored_functions = {"<listcomp>", "<genexpr>", "<dictcomp>", "<setcomp>", "<lambda>", "<module>"}
-        self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_")  # noqa: SLF001
+        self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_")

         assert timeout is None or timeout > 0, "Timeout should be greater than 0"
         self.timeout = timeout
@@ -145,44 +119,48 @@ class Tracer:
     def __enter__(self) -> None:
         if self.disable:
             return
-        if getattr(Tracer, "used_once", False):
-            console.print(
-                "Codeflash: Tracer can only be used once per program run. "
-                "Please only enable the Tracer once. Skipping tracing this section."
-            )
-            self.disable = True
-            return
-        Tracer.used_once = True
+        # if getattr(Tracer, "used_once", False):
+        #     console.print(
+        #         "Codeflash: Tracer can only be used once per program run. "
+        #         "Please only enable the Tracer once. Skipping tracing this section."
+        #     )
+        #     self.disable = True
+        #     return
+        # Tracer.used_once = True

         if pathlib.Path(self.output_file).exists():
-            console.rule("Removing existing trace file", style="bold red")
-            console.rule()
+            console.print("Codeflash: Removing existing trace file")
         pathlib.Path(self.output_file).unlink(missing_ok=True)
-        self.con = sqlite3.connect(self.output_file, check_same_thread=False)
+        self.con = sqlite3.connect(self.output_file)
         cur = self.con.cursor()
         cur.execute("""PRAGMA synchronous = OFF""")
-        cur.execute("""PRAGMA journal_mode = WAL""")
         # TODO: Check out if we need to export the function test name as well
         cur.execute(
             "CREATE TABLE function_calls(type TEXT, function TEXT, classname TEXT, filename TEXT, "
             "line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)"
         )
-        console.rule("Codeflash: Traced Program Output Begin", style="bold blue")
-        frame = sys._getframe(0)  # Get this frame and simulate a call to it  # noqa: SLF001
+        console.print("Codeflash: Tracing started!")
+        frame = sys._getframe(0)  # Get this frame and simulate a call to it
         self.dispatch["call"](self, frame, 0)
         self.start_time = time.time()
         sys.setprofile(self.trace_callback)
-        threading.setprofile(self.trace_callback)

-    def __exit__(
-        self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
-    ) -> None:
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
         if self.disable:
             return
         sys.setprofile(None)
         self.con.commit()
-        console.rule("Codeflash: Traced Program Output End", style="bold blue")
+        # Check if any functions were actually traced
+        if self.trace_count == 0:
+            self.con.close()
+            # Delete the trace file if no functions were traced
+            if self.output_file.exists():
+                self.output_file.unlink()
+            console.print("Codeflash: No functions were traced. Removing trace database.")
+            return
         self.create_stats()

         cur = self.con.cursor()
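
Note: the function_calls table created in __enter__ is ordinary SQLite, so a finished trace can be inspected directly. A sketch (the trace file name is hypothetical):

import sqlite3

con = sqlite3.connect("my_program.trace")
cur = con.cursor()
# Aggregate call counts and total traced nanoseconds per function,
# using the schema from the CREATE TABLE statement above.
cur.execute(
    "SELECT function, classname, COUNT(*), SUM(time_ns) "
    "FROM function_calls GROUP BY function, classname"
)
for function, classname, calls, total_ns in cur.fetchall():
    name = f"{classname}.{function}" if classname else function
    print(f"{name}: {calls} calls, {total_ns / 1e6:.3f} ms")
con.close()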
@@ -226,13 +204,14 @@ class Tracer:
             test_framework=self.config["test_framework"],
             max_run_count=self.max_function_count,
         )
-        function_path = "_".join(self.functions) if self.functions else self.file_being_called_from
+        # Need a better way to store the replay test
+        # function_path = "_".join(self.functions) if self.functions else self.file_being_called_from
+        function_path = self.file_being_called_from
         test_file_path = get_test_file_path(
             test_dir=Path(self.config["tests_root"]), function_name=function_path, test_type="replay"
         )
         replay_test = isort.code(replay_test)
-        with Path(test_file_path).open("w", encoding="utf8") as file:
+        with open(test_file_path, "w", encoding="utf8") as file:
             file.write(replay_test)

         console.print(
@@ -242,27 +221,25 @@ class Tracer:
             overflow="ignore",
         )

-    def tracer_logic(self, frame: FrameType, event: str) -> None:
+    def tracer_logic(self, frame: FrameType, event: str):
         if event != "call":
             return
-        if self.timeout is not None and (time.time() - self.start_time) > self.timeout:
-            sys.setprofile(None)
-            threading.setprofile(None)
-            console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.")
-            return
+        if self.timeout is not None:
+            if (time.time() - self.start_time) > self.timeout:
+                sys.setprofile(None)
+                console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.")
+                return
         code = frame.f_code
         file_name = Path(code.co_filename).resolve()
         # TODO : It currently doesn't log the last return call from the first function
         if code.co_name in self.ignored_functions:
             return
-        if not file_name.is_relative_to(self.project_root):
-            return
         if not file_name.exists():
             return
-        if self.functions and code.co_name not in self.functions:
-            return
+        # if self.functions:
+        #     if code.co_name not in self.functions:
+        #         return
         class_name = None
         arguments = frame.f_locals
         try:
@@ -274,12 +251,16 @@ class Tracer:
                 class_name = arguments["self"].__class__.__name__
             elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
                 class_name = arguments["cls"].__name__
-        except:  # noqa: E722
+        except:
             # someone can override the getattr method and raise an exception. I'm looking at you wrapt
             return
         function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}"
         if function_qualified_name in self.ignored_qualified_functions:
             return
+        if self.functions and function_qualified_name not in self.functions:
+            return
         if function_qualified_name not in self.function_count:
             # seeing this function for the first time
             self.function_count[function_qualified_name] = 0
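
Note: the new filter compares against the fully qualified form built a few lines earlier, so entries passed via functions= must have the file:Class:function shape rather than a bare name. A sketch reproducing the format:

from pathlib import Path

file_name = Path("code_to_optimize/bubble_sort.py").resolve()
class_name = None  # e.g. "Sorter" for a method
code_name = "sorter"
# Mirrors the f-string in tracer_logic above.
qualified = f"{file_name}:{(class_name + ':' if class_name else '')}{code_name}"
print(qualified)  # e.g. /abs/path/code_to_optimize/bubble_sort.py:sorter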
@@ -354,14 +335,17 @@ class Tracer:
                 self.next_insert = 1000
                 self.con.commit()

-    def trace_callback(self, frame: FrameType, event: str, arg: str | None) -> None:
+    def trace_callback(self, frame: FrameType, event: str, arg: Any) -> None:
         # profiler section
         timer = self.timer
         t = timer() - self.t - self.bias

         if event == "c_call":
             self.c_func_name = arg.__name__

-        prof_success = bool(self.dispatch[event](self, frame, t))
+        if self.dispatch[event](self, frame, t):
+            prof_success = True
+        else:
+            prof_success = False

         # tracer section
         self.tracer_logic(frame, event)
         # measure the time as the last thing before return
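
Note: trace_callback is installed via sys.setprofile, so it follows the interpreter's standard profile-hook contract. A self-contained sketch of that contract:

import sys

events = []

# The hook fires once per call/return/c_call/c_return/exception event;
# for c_call events `arg` is the C function object, hence arg.__name__ above.
def hook(frame, event, arg):
    if event in ("call", "return"):
        events.append((event, frame.f_code.co_name))

def work():
    return sum([1, 2, 3])

sys.setprofile(hook)
work()
sys.setprofile(None)
print(events)  # [('call', 'work'), ('return', 'work')]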
@@ -370,60 +354,45 @@ class Tracer:
         else:
             self.t = timer() - t  # put back unrecorded delta

-    def trace_dispatch_call(self, frame: FrameType, t: int) -> int:
-        """Handle call events in the profiler."""
-        try:
-            # In multi-threaded contexts, we need to be more careful about frame comparisons
-            if self.cur and frame.f_back is not self.cur[-2]:
-                # This happens when we're in a different thread
-                rpt, rit, ret, rfn, rframe, rcur = self.cur
-                # Only attempt to handle the frame mismatch if we have a valid rframe
-                if (
-                    not isinstance(rframe, FakeFrame)
-                    and hasattr(rframe, "f_back")
-                    and hasattr(frame, "f_back")
-                    and rframe.f_back is frame.f_back
-                ):
-                    self.trace_dispatch_return(rframe, 0)
-            # Get function information
-            fcode = frame.f_code
-            arguments = frame.f_locals
-            class_name = None
-            try:
-                if (
-                    "self" in arguments
-                    and hasattr(arguments["self"], "__class__")
-                    and hasattr(arguments["self"].__class__, "__name__")
-                ):
-                    class_name = arguments["self"].__class__.__name__
-                elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
-                    class_name = arguments["cls"].__name__
-            except Exception:  # noqa: BLE001, S110
-                pass
-            fn = (fcode.co_filename, fcode.co_firstlineno, fcode.co_name, class_name)
-            self.cur = (t, 0, 0, fn, frame, self.cur)
-            timings = self.timings
-            if fn in timings:
-                cc, ns, tt, ct, callers = timings[fn]
-                timings[fn] = cc, ns + 1, tt, ct, callers
-            else:
-                timings[fn] = 0, 0, 0, 0, {}
-            return 1  # noqa: TRY300
-        except Exception:  # noqa: BLE001
-            # Handle any errors gracefully
-            return 0
-
-    def trace_dispatch_exception(self, frame: FrameType, t: int) -> int:
+    def trace_dispatch_call(self, frame, t):
+        if self.cur and frame.f_back is not self.cur[-2]:
+            rpt, rit, ret, rfn, rframe, rcur = self.cur
+            if not isinstance(rframe, Tracer.fake_frame):
+                assert rframe.f_back is frame.f_back, ("Bad call", rfn, rframe, rframe.f_back, frame, frame.f_back)
+                self.trace_dispatch_return(rframe, 0)
+                assert self.cur is None or frame.f_back is self.cur[-2], ("Bad call", self.cur[-3])
+        fcode = frame.f_code
+        arguments = frame.f_locals
+        class_name = None
+        try:
+            if (
+                "self" in arguments
+                and hasattr(arguments["self"], "__class__")
+                and hasattr(arguments["self"].__class__, "__name__")
+            ):
+                class_name = arguments["self"].__class__.__name__
+            elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
+                class_name = arguments["cls"].__name__
+        except:
+            pass
+        fn = (fcode.co_filename, fcode.co_firstlineno, fcode.co_name, class_name)
+        self.cur = (t, 0, 0, fn, frame, self.cur)
+        timings = self.timings
+        if fn in timings:
+            cc, ns, tt, ct, callers = timings[fn]
+            timings[fn] = cc, ns + 1, tt, ct, callers
+        else:
+            timings[fn] = 0, 0, 0, 0, {}
+        return 1
+
+    def trace_dispatch_exception(self, frame, t):
         rpt, rit, ret, rfn, rframe, rcur = self.cur
         if (rframe is not frame) and rcur:
             return self.trace_dispatch_return(rframe, t)
         self.cur = rpt, rit + t, ret, rfn, rframe, rcur
         return 1
-    def trace_dispatch_c_call(self, frame: FrameType, t: int) -> int:
+    def trace_dispatch_c_call(self, frame, t):
         fn = ("", 0, self.c_func_name, None)
         self.cur = (t, 0, 0, fn, frame, self.cur)
         timings = self.timings
@@ -434,27 +403,15 @@ class Tracer:
             timings[fn] = 0, 0, 0, 0, {}
         return 1

-    def trace_dispatch_return(self, frame: FrameType, t: int) -> int:
-        if not self.cur or not self.cur[-2]:
-            return 0
-        # In multi-threaded environments, frames can get mismatched
+    def trace_dispatch_return(self, frame, t):
         if frame is not self.cur[-2]:
-            # Don't assert in threaded environments - frames can legitimately differ
-            if hasattr(frame, "f_back") and hasattr(self.cur[-2], "f_back") and frame is self.cur[-2].f_back:
-                self.trace_dispatch_return(self.cur[-2], 0)
-            else:
-                # We're in a different thread or context, can't continue with this frame
-                return 0
+            assert frame is self.cur[-2].f_back, ("Bad return", self.cur[-3])
+            self.trace_dispatch_return(self.cur[-2], 0)

         # Prefix "r" means part of the Returning or exiting frame.
         # Prefix "p" means part of the Previous or Parent or older frame.
         rpt, rit, ret, rfn, frame, rcur = self.cur
-        # Guard against invalid rcur (w threading)
-        if not rcur:
-            return 0
         rit = rit + t
         frame_total = rit + ret
@@ -462,9 +419,6 @@ class Tracer:
         self.cur = ppt, pit + rpt, pet + frame_total, pfn, pframe, pcur

         timings = self.timings
-        if rfn not in timings:
-            # w threading, rfn can be missing
-            timings[rfn] = 0, 0, 0, 0, {}
         cc, ns, tt, ct, callers = timings[rfn]
         if not ns:
             # This is the only occurrence of the function on the stack.
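
Note: the five-element tuple threaded through these dispatchers follows the stdlib profile module's bookkeeping, which this code mirrors. A short legend as runnable code:

# cc: completed primitive (non-recursive) call count
# ns: number of activations of this function currently on the stack
# tt: internal time, excluding time spent in sub-calls
# ct: cumulative time, credited only when the outermost activation returns
# callers: mapping of caller entry -> call count
fn = ("bubble_sort.py", 1, "sorter", None)  # (filename, first line, name, class) as built above
timings = {fn: (0, 0, 0, 0, {})}
print(timings[fn])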
@@ -486,7 +440,7 @@ class Tracer:
         return 1

-    dispatch: ClassVar[dict[str, Callable[[Tracer, FrameType, int], int]]] = {
+    dispatch: ClassVar[dict[str, callable]] = {
         "call": trace_dispatch_call,
         "exception": trace_dispatch_exception,
         "return": trace_dispatch_return,
@@ -495,13 +449,32 @@ class Tracer:
         "c_return": trace_dispatch_return,
     }

-    def simulate_call(self, name: str) -> None:
-        code = FakeCode("profiler", 0, name)
-        pframe = self.cur[-2] if self.cur else None
-        frame = FakeFrame(code, pframe)
+    class fake_code:
+        def __init__(self, filename, line, name):
+            self.co_filename = filename
+            self.co_line = line
+            self.co_name = name
+            self.co_firstlineno = 0
+
+        def __repr__(self):
+            return repr((self.co_filename, self.co_line, self.co_name, None))
+
+    class fake_frame:
+        def __init__(self, code, prior):
+            self.f_code = code
+            self.f_back = prior
+            self.f_locals = {}
+
+    def simulate_call(self, name):
+        code = self.fake_code("profiler", 0, name)
+        if self.cur:
+            pframe = self.cur[-2]
+        else:
+            pframe = None
+        frame = self.fake_frame(code, pframe)
         self.dispatch["call"](self, frame, 0)

-    def simulate_cmd_complete(self) -> None:
+    def simulate_cmd_complete(self):
         get_time = self.timer
         t = get_time() - self.t
         while self.cur[-1]:
@@ -511,174 +484,60 @@ class Tracer:
             t = 0
         self.t = get_time() - t
-    def print_stats(self, sort: str | int | tuple = -1) -> None:
-        if not self.stats:
-            console.print("Codeflash: No stats available to print")
-            self.total_tt = 0
-            return
-        if not isinstance(sort, tuple):
-            sort = (sort,)
-        # First, convert stats to make them pstats-compatible
-        try:
-            # Initialize empty collections for pstats
-            self.files = []
-            self.top_level = []
-            # Create entirely new dictionaries instead of modifying existing ones
-            new_stats = {}
-            new_timings = {}
-            # Convert stats dictionary
-            stats_items = list(self.stats.items())
-            for func, stats_data in stats_items:
-                try:
-                    # Make sure we have 5 elements in stats_data
-                    if len(stats_data) != 5:
-                        console.print(f"Skipping malformed stats data for {func}: {stats_data}")
-                        continue
-                    cc, nc, tt, ct, callers = stats_data
-                    if len(func) == 4:
-                        file_name, line_num, func_name, class_name = func
-                        new_func_name = f"{class_name}.{func_name}" if class_name else func_name
-                        new_func = (file_name, line_num, new_func_name)
-                    else:
-                        new_func = func  # Keep as is if already in correct format
-                    new_callers = {}
-                    callers_items = list(callers.items())
-                    for caller_func, count in callers_items:
-                        if isinstance(caller_func, tuple):
-                            if len(caller_func) == 4:
-                                caller_file, caller_line, caller_name, caller_class = caller_func
-                                caller_new_name = f"{caller_class}.{caller_name}" if caller_class else caller_name
-                                new_caller_func = (caller_file, caller_line, caller_new_name)
-                            else:
-                                new_caller_func = caller_func
-                        else:
-                            console.print(f"Unexpected caller format: {caller_func}")
-                            new_caller_func = str(caller_func)
-                        new_callers[new_caller_func] = count
-                    # Store with new format
-                    new_stats[new_func] = (cc, nc, tt, ct, new_callers)
-                except Exception as e:  # noqa: BLE001
-                    console.print(f"Error converting stats for {func}: {e}")
-                    continue
-            timings_items = list(self.timings.items())
-            for func, timing_data in timings_items:
-                try:
-                    if len(timing_data) != 5:
-                        console.print(f"Skipping malformed timing data for {func}: {timing_data}")
-                        continue
-                    cc, ns, tt, ct, callers = timing_data
-                    if len(func) == 4:
-                        file_name, line_num, func_name, class_name = func
-                        new_func_name = f"{class_name}.{func_name}" if class_name else func_name
-                        new_func = (file_name, line_num, new_func_name)
-                    else:
-                        new_func = func
-                    new_callers = {}
-                    callers_items = list(callers.items())
-                    for caller_func, count in callers_items:
-                        if isinstance(caller_func, tuple):
-                            if len(caller_func) == 4:
-                                caller_file, caller_line, caller_name, caller_class = caller_func
-                                caller_new_name = f"{caller_class}.{caller_name}" if caller_class else caller_name
-                                new_caller_func = (caller_file, caller_line, caller_new_name)
-                            else:
-                                new_caller_func = caller_func
-                        else:
-                            console.print(f"Unexpected caller format: {caller_func}")
-                            new_caller_func = str(caller_func)
-                        new_callers[new_caller_func] = count
-                    new_timings[new_func] = (cc, ns, tt, ct, new_callers)
-                except Exception as e:  # noqa: BLE001
-                    console.print(f"Error converting timings for {func}: {e}")
-                    continue
-            self.stats = new_stats
-            self.timings = new_timings
-            self.total_tt = sum(tt for _, _, tt, _, _ in self.stats.values())
-            total_calls = sum(cc for cc, _, _, _, _ in self.stats.values())
-            total_primitive = sum(nc for _, nc, _, _, _ in self.stats.values())
-            summary = Text.assemble(
-                f"{total_calls:,} function calls ",
-                ("(" + f"{total_primitive:,} primitive calls" + ")", "dim"),
-                f" in {self.total_tt / 1e6:.3f}milliseconds",
-            )
-            console.print(Align.center(Panel(summary, border_style="blue", width=80, padding=(0, 2), expand=False)))
-            table = Table(
-                show_header=True,
-                header_style="bold magenta",
-                border_style="blue",
-                title="[bold]Function Profile[/bold] (ordered by internal time)",
-                title_style="cyan",
-                caption=f"Showing top 25 of {len(self.stats)} functions",
-            )
-            table.add_column("Calls", justify="right", style="green", width=10)
-            table.add_column("Time (ms)", justify="right", style="cyan", width=10)
-            table.add_column("Per Call", justify="right", style="cyan", width=10)
-            table.add_column("Cum (ms)", justify="right", style="yellow", width=10)
-            table.add_column("Cum/Call", justify="right", style="yellow", width=10)
-            table.add_column("Function", style="blue")
-            sorted_stats = sorted(
-                ((func, stats) for func, stats in self.stats.items() if isinstance(func, tuple) and len(func) == 3),
-                key=lambda x: x[1][2],  # Sort by tt (internal time)
-                reverse=True,
-            )[:25]  # Limit to top 25
-            # Format and add each row to the table
-            for func, (cc, nc, tt, ct, _) in sorted_stats:
-                filename, lineno, funcname = func
-                # Format calls - show recursive format if different
-                calls_str = f"{cc}/{nc}" if cc != nc else f"{cc:,}"
-                # Convert to milliseconds
-                tt_ms = tt / 1e6
-                ct_ms = ct / 1e6
-                # Calculate per-call times
-                per_call = tt_ms / cc if cc > 0 else 0
-                cum_per_call = ct_ms / nc if nc > 0 else 0
-                base_filename = Path(filename).name
-                file_link = f"[link=file://{filename}]{base_filename}[/link]"
-                table.add_row(
-                    calls_str,
-                    f"{tt_ms:.3f}",
-                    f"{per_call:.3f}",
-                    f"{ct_ms:.3f}",
-                    f"{cum_per_call:.3f}",
-                    f"{funcname} [dim]({file_link}:{lineno})[/dim]",
-                )
-            console.print(Align.center(table))
-        except Exception as e:  # noqa: BLE001
-            console.print(f"[bold red]Error in stats processing:[/bold red] {e}")
-            console.print(f"Traced {self.trace_count:,} function calls")
-            self.total_tt = 0
+    def print_stats(self, sort=-1):
+        import pstats
+
+        if not isinstance(sort, tuple):
+            sort = (sort,)
+        # The following code customizes the default printing behavior to
+        # print in milliseconds.
+        s = StringIO()
+        stats_obj = pstats.Stats(copy(self), stream=s)
+        stats_obj.strip_dirs().sort_stats(*sort).print_stats(100)
+        self.total_tt = stats_obj.total_tt
+        console.print("total_tt", self.total_tt)
+        raw_stats = s.getvalue()
+        m = re.search(r"function calls?.*in (\d+)\.\d+ (seconds?)", raw_stats)
+        total_time = None
+        if m:
+            total_time = int(m.group(1))
+        if total_time is None:
+            console.print("Failed to get total time from stats")
+        total_time_ms = total_time / 1e6
+        raw_stats = re.sub(
+            r"(function calls?.*)in (\d+)\.\d+ (seconds?)", rf"\1 in {total_time_ms:.3f} milliseconds", raw_stats
+        )
+        match_pattern = r"^ *[\d\/]+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +"
+        m = re.findall(match_pattern, raw_stats, re.MULTILINE)
+        ms_times = []
+        for tottime, percall, cumtime, percall_cum in m:
+            tottime_ms = int(tottime) / 1e6
+            percall_ms = int(percall) / 1e6
+            cumtime_ms = int(cumtime) / 1e6
+            percall_cum_ms = int(percall_cum) / 1e6
+            ms_times.append([tottime_ms, percall_ms, cumtime_ms, percall_cum_ms])
+        split_stats = raw_stats.split("\n")
+        new_stats = []
+        replace_pattern = r"^( *[\d\/]+) +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(.*)"
+        times_index = 0
+        for line in split_stats:
+            if times_index >= len(ms_times):
+                replaced = line
+            else:
+                replaced, n = re.subn(
+                    replace_pattern,
+                    rf"\g<1>{ms_times[times_index][0]:8.3f} {ms_times[times_index][1]:8.3f} {ms_times[times_index][2]:8.3f} {ms_times[times_index][3]:8.3f} \g<6>",
+                    line,
+                    count=1,
+                )
+                if n > 0:
+                    times_index += 1
+            new_stats.append(replaced)
+        console.print("\n".join(new_stats))

-    def make_pstats_compatible(self) -> None:
+    def make_pstats_compatible(self):
         # delete the extra class_name item from the function tuple
         self.files = []
         self.top_level = []
@@ -693,33 +552,36 @@ class Tracer:
         self.stats = new_stats
         self.timings = new_timings

-    def dump_stats(self, file: str) -> None:
-        with Path(file).open("wb") as f:
+    def dump_stats(self, file):
+        with open(file, "wb") as f:
+            self.create_stats()
             marshal.dump(self.stats, f)

-    def create_stats(self) -> None:
+    def create_stats(self):
         self.simulate_cmd_complete()
         self.snapshot_stats()

-    def snapshot_stats(self) -> None:
+    def snapshot_stats(self):
         self.stats = {}
-        for func, (cc, _ns, tt, ct, caller_dict) in self.timings.items():
-            callers = caller_dict.copy()
+        for func, (cc, ns, tt, ct, callers) in self.timings.items():
+            callers = callers.copy()
             nc = 0
             for callcnt in callers.values():
                 nc += callcnt
             self.stats[func] = cc, nc, tt, ct, callers

-    def runctx(self, cmd: str, global_vars: dict[str, Any], local_vars: dict[str, Any]) -> Tracer | None:
+    def runctx(self, cmd, globals, locals):
         self.__enter__()
         try:
-            exec(cmd, global_vars, local_vars)  # noqa: S102
+            exec(cmd, globals, locals)
         finally:
             self.__exit__(None, None, None)
         return self

-def main() -> ArgumentParser:
+def main():
+    from argparse import ArgumentParser
+
     parser = ArgumentParser(allow_abbrev=False)
     parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to <outfile>", required=True)
     parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None)
@@ -776,13 +638,16 @@ def main() -> ArgumentParser:
         "__cached__": None,
     }
     try:
-        Tracer(
+        tracer = Tracer(
             output=args.outfile,
             functions=args.only_functions,
             max_function_count=args.max_function_count,
             timeout=args.tracer_timeout,
             config_file_path=args.codeflash_config,
-        ).runctx(code, globs, None)
+        )
+        tracer.runctx(code, globs, None)
+        print(tracer.functions)
     except BrokenPipeError as exc:
         # Prevent "Exception ignored" during interpreter shutdown.

View file

@@ -29,7 +29,6 @@ class TestType(Enum):
     REPLAY_TEST = 4
     CONCOLIC_COVERAGE_TEST = 5
     INIT_STATE_TEST = 6
-    BENCHMARK_TEST = 7

     def to_name(self) -> str:
         if self == TestType.INIT_STATE_TEST:
@@ -40,7 +39,6 @@ class TestType(Enum):
             TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
             TestType.REPLAY_TEST: "⏪ Replay Tests",
             TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests",
-            TestType.BENCHMARK_TEST: "📏 Benchmark Tests",
         }
         return names[self]

View file

@@ -75,3 +75,4 @@ class TestConfig:
     # or for unittest - project_root_from_module_root(args.tests_root, pyproject_file_path)
     concolic_test_root_dir: Optional[Path] = None
     pytest_cmd: str = "pytest"
+    benchmark_tests_root: Optional[Path] = None
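
A sketch of wiring the new field; the surrounding fields mirror the discovery test at the end of this diff, and the exact required-field set is an assumption:

from pathlib import Path
from codeflash.verification.verification_utils import TestConfig

project = Path("code_to_optimize")
config = TestConfig(
    tests_root=project / "tests" / "pytest",
    project_root_path=project,
    test_framework="pytest",
    tests_project_rootdir=project / "tests",
    benchmark_tests_root=project / "tests" / "pytest" / "benchmarks",
)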

View file

@@ -69,7 +69,7 @@ exclude = [
 [tool.poetry.dependencies]
 python = ">=3.9"
 unidiff = ">=0.7.4"
-pytest = ">=7.0.0,<8.3.4"
+pytest = ">=7.0.0"
 gitpython = ">=3.1.31"
 libcst = ">=1.0.1"
 jedi = ">=0.19.1"
@@ -92,7 +92,6 @@ rich = ">=13.8.1"
 lxml = ">=5.3.0"
 crosshair-tool = ">=0.0.78"
 coverage = ">=7.6.4"
-line_profiler=">=4.2.0" #this is the minimum version which supports python 3.13

 [tool.poetry.group.dev]
 optional = true
@@ -120,7 +119,7 @@ types-gevent = "^24.11.0.20241230"
 types-greenlet = "^3.1.0.20241221"
 types-pexpect = "^4.9.0.20241208"
 types-unidiff = "^0.7.0.20240505"
-uv = ">=0.6.2"
+sqlalchemy = "^2.0.38"

 [tool.poetry.build]
 script = "codeflash/update_license_version.py"
@@ -152,7 +151,7 @@ warn_required_dynamic_aliases = true
 line-length = 120
 fix = true
 show-fixes = true
-exclude = ["code_to_optimize/", "pie_test_set/", "tests/"]
+exclude = ["code_to_optimize/", "pie_test_set/"]

 [tool.ruff.lint]
 select = ["ALL"]
@@ -164,11 +163,10 @@ ignore = [
     "D103",
     "D105",
     "D107",
-    "D203", # incorrect-blank-line-before-class (incompatible with D211)
-    "D213", # multi-line-summary-second-line (incompatible with D212)
     "S101",
     "S603",
     "S607",
+    "ANN101",
     "COM812",
     "FIX002",
     "PLR0912",
@@ -177,14 +175,13 @@ ignore = [
     "TD002",
     "TD003",
     "TD004",
-    "PLR2004",
-    "UP007" # remove once we drop 3.9 support.
+    "PLR2004"
 ]

 [tool.ruff.lint.flake8-type-checking]
 strict = true
 runtime-evaluated-base-classes = ["pydantic.BaseModel"]
-runtime-evaluated-decorators = ["pydantic.validate_call", "pydantic.dataclasses.dataclass"]
+runtime-evaluated-decorators = ["pydantic.validate_call"]

 [tool.ruff.lint.pep8-naming]
 classmethod-decorators = [
@@ -192,9 +189,6 @@ classmethod-decorators = [
     "pydantic.validator",
 ]

-[tool.ruff.lint.isort]
-split-on-trailing-comma = false
-
 [tool.ruff.format]
 docstring-code-format = true
 skip-magic-trailing-comma = true
@@ -217,13 +211,13 @@ initial-content = """
 [tool.codeflash]
-module-root = "codeflash"
-tests-root = "tests"
+# All paths are relative to this pyproject.toml's directory.
+module-root = "code_to_optimize"
+tests-root = "code_to_optimize/tests"
+benchmarks-root = "code_to_optimize/tests/pytest/benchmarks"
 test-framework = "pytest"
-formatter-cmds = [
-    "uvx ruff check --exit-zero --fix $file",
-    "uvx ruff format $file",
-]
+ignore-paths = []
+formatter-cmds = ["ruff check --exit-zero --fix $file", "ruff format $file"]

 [build-system]
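
Note: the tracer reads this section through parse_config_file(), which returns the [tool.codeflash] table together with the pyproject.toml it was found in. A sketch; the import path and the dash-to-underscore key normalization are assumptions based on the config["module_root"] lookups earlier in this diff:

from codeflash.code_utils.config_parser import parse_config_file  # path assumed

config, found_config_path = parse_config_file(None)  # None auto-discovers pyproject.toml
print(found_config_path)
print(config["module_root"], config["tests_root"])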

View file

@@ -0,0 +1,8 @@
+from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
+from pathlib import Path
+
+def test_trace_benchmarks():
+    # Test the trace_benchmarks function
+    project_root = Path(__file__).parent.parent / "code_to_optimize"
+    benchmarks_root = project_root / "tests" / "pytest" / "benchmarks"
+    trace_benchmarks_pytest(benchmarks_root, project_root, ["sorter"])
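
A possible follow-up assertion, assuming the plugin leaves one .trace database per traced benchmark (the output layout is not shown in this hunk):

trace_files = list(benchmarks_root.glob("*.trace"))  # reuses benchmarks_root from the test above
assert trace_files, "expected at least one .trace file from the benchmark run"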

View file

@@ -3,6 +3,7 @@ import tempfile
 from pathlib import Path

 from codeflash.discovery.discover_unit_tests import discover_unit_tests
+from codeflash.verification.test_results import TestType
 from codeflash.verification.verification_utils import TestConfig

@@ -21,7 +22,7 @@ def test_unit_test_discovery_pytest():

 def test_benchmark_test_discovery_pytest():
     project_path = Path(__file__).parent.parent.resolve() / "code_to_optimize"
-    tests_path = project_path / "tests" / "pytest"
+    tests_path = project_path / "tests" / "pytest" / "benchmarks" / "test_benchmark_bubble_sort.py"
     test_config = TestConfig(
         tests_root=tests_path,
         project_root_path=project_path,
@@ -29,9 +30,10 @@ def test_benchmark_test_discovery_pytest():
         tests_project_rootdir=tests_path.parent,
     )
     tests = discover_unit_tests(test_config)
+    print(tests)
     assert len(tests) > 0
-    # print(tests)
+    assert 'bubble_sort.sorter' in tests
+    benchmark_tests = sum(1 for test in tests['bubble_sort.sorter'] if test.tests_in_file.test_type == TestType.BENCHMARK_TEST)
+    assert benchmark_tests == 1

 def test_unit_test_discovery_unittest():