initial implementation for tracing benchmarks using a plugin, and projecting speedup

This commit is contained in:
Alvin Ryanputra 2025-02-27 16:25:07 -08:00
parent a73b541159
commit 965e2c818c
22 changed files with 790 additions and 571 deletions

View file

@ -0,0 +1,28 @@
from code_to_optimize.bubble_sort import sorter
def calculate_pairwise_products(arr):
"""
Calculate the average of all pairwise products in the array.
"""
sum_of_products = 0
count = 0
for i in range(len(arr)):
for j in range(len(arr)):
if i != j:
sum_of_products += arr[i] * arr[j]
count += 1
# The average of all pairwise products
return sum_of_products / count if count > 0 else 0
def compute_and_sort(arr):
# Compute the average of all pairwise products
pairwise_average = calculate_pairwise_products(arr)
# Sort a copy so the original input is not mutated; the sorted result is discarded
sorter(arr.copy())
return pairwise_average
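A quick, illustrative sanity check of the math above (not part of the commit): the nested loop is equivalent to the closed form (sum(arr)^2 - sum of squares) / (n^2 - n), which also explains the 6247083.5 value asserted in the benchmark test further down.

def pairwise_products_average(arr):
    # hypothetical helper, equivalent to calculate_pairwise_products above
    n = len(arr)
    s = sum(arr)
    q = sum(x * x for x in arr)
    return (s * s - q) / (n * n - n) if n > 1 else 0

# [1, 2, 3]: 1*2 + 1*3 + 2*1 + 2*3 + 3*1 + 3*2 = 22 over 6 ordered pairs
assert abs(pairwise_products_average([1, 2, 3]) - 22 / 6) < 1e-9
# order does not matter, so reversed(range(5000)) yields the same value
assert pairwise_products_average(list(range(5000))) == 6247083.5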

View file

@ -1,6 +1,13 @@
import pytest
from code_to_optimize.bubble_sort import sorter
def test_sort(benchmark):
result = benchmark(sorter, list(reversed(range(5000))))
assert result == list(range(5000))
# This should not be picked up as a benchmark test
def test_sort2():
result = sorter(list(reversed(range(5000))))
assert result == list(range(5000))

View file

@ -0,0 +1,8 @@
from code_to_optimize.process_and_bubble_sort import compute_and_sort
from code_to_optimize.bubble_sort2 import sorter
def test_compute_and_sort(benchmark):
result = benchmark(compute_and_sort, list(reversed(range(5000))))
assert result == 6247083.5
def test_no_func(benchmark):
benchmark(sorter, list(reversed(range(5000))))

View file

View file

@ -0,0 +1,112 @@
import sqlite3
from pathlib import Path
from typing import Dict, Set
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
def get_function_benchmark_timings(trace_dir: Path, all_functions_to_optimize: list[FunctionToOptimize]) -> dict[str, dict[str, float]]:
"""Process all trace files in the given directory and extract timing data for the specified functions.
Args:
trace_dir: Path to the directory containing .trace files
all_functions_to_optimize: List of FunctionToOptimize objects representing the functions to include
Returns:
A nested dictionary where:
- Outer keys are function qualified names with file name
- Inner keys are benchmark names (trace filename without .trace extension)
- Values are function timing in milliseconds
"""
# Create a mapping of (filename, function_name, class_name) -> qualified_name for efficient lookups
function_lookup = {}
function_benchmark_timings = {}
for func in all_functions_to_optimize:
qualified_name = func.qualified_name_with_file_name
# Full file path; the pstats queries below match it as a suffix via LIKE
filename = func.file_path
function_name = func.function_name
# Get class name if there's a parent
class_name = func.parents[0].name if func.parents else None
# Store in lookup dictionary
key = (filename, function_name, class_name)
function_lookup[key] = qualified_name
function_benchmark_timings[qualified_name] = {}
# Find all .trace files in the directory
trace_files = list(trace_dir.glob("*.trace"))
for trace_file in trace_files:
# Extract benchmark name from filename (without .trace)
benchmark_name = trace_file.stem
# Connect to the trace database
conn = sqlite3.connect(trace_file)
cursor = conn.cursor()
# For each function we're interested in, query the database directly
for (filename, function_name, class_name), qualified_name in function_lookup.items():
# Adjust query based on whether we have a class name
if class_name:
cursor.execute(
"SELECT total_time_ns FROM pstats WHERE filename LIKE ? AND function = ? AND class_name = ?",
(f"%{filename}", function_name, class_name)
)
else:
cursor.execute(
"SELECT total_time_ns FROM pstats WHERE filename LIKE ? AND function = ? AND (class_name IS NULL OR class_name = '')",
(f"%{filename}", function_name)
)
result = cursor.fetchone()
if result:
time_ns = result[0]
function_benchmark_timings[qualified_name][benchmark_name] = time_ns / 1e6 # Convert to milliseconds
conn.close()
return function_benchmark_timings
def get_benchmark_timings(trace_dir: Path) -> dict[str, float]:
"""Extract total benchmark timings from trace files.
Args:
trace_dir: Path to the directory containing .trace files
Returns:
A dictionary mapping benchmark names to their total execution time in milliseconds.
"""
benchmark_timings = {}
# Find all .trace files in the directory
trace_files = list(trace_dir.glob("*.trace"))
for trace_file in trace_files:
# Extract benchmark name from filename (without .trace extension)
benchmark_name = trace_file.stem
# Connect to the trace database
conn = sqlite3.connect(trace_file)
cursor = conn.cursor()
# Query the total_time table for the benchmark's total execution time
try:
cursor.execute("SELECT time_ns FROM total_time")
result = cursor.fetchone()
if result:
time_ns = result[0]
# Convert nanoseconds to milliseconds
benchmark_timings[benchmark_name] = time_ns / 1e6
except sqlite3.OperationalError:
# Handle case where total_time table might not exist
print(f"Warning: Could not get total time for benchmark {benchmark_name}")
conn.close()
return benchmark_timings
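A minimal usage sketch of these two helpers together (the module path matches the import added later in this commit; the trace directory and the empty function list are assumptions so the sketch stays self-contained):

from pathlib import Path

from codeflash.benchmarking.get_trace_info import get_benchmark_timings, get_function_benchmark_timings

# Assumed inputs: the .codeflash_trace directory written during benchmark tracing,
# and the flattened list of FunctionToOptimize objects built by the optimizer.
trace_dir = Path("code_to_optimize/tests/pytest/benchmarks/.codeflash_trace")
functions_to_optimize = []  # list[FunctionToOptimize]; left empty here for illustration

per_function_ms = get_function_benchmark_timings(trace_dir, functions_to_optimize)
total_ms = get_benchmark_timings(trace_dir)

for qualified_name, benchmarks in per_function_ms.items():
    for benchmark_name, function_ms in benchmarks.items():
        # fraction of the benchmark's total runtime spent inside this function
        share = function_ms / total_ms[benchmark_name]
        print(f"{qualified_name}: {share:.1%} of {benchmark_name}")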

View file

@ -0,0 +1,79 @@
import pytest
from codeflash.tracer import Tracer
from pathlib import Path
class CodeFlashPlugin:
@staticmethod
def pytest_addoption(parser):
parser.addoption(
"--codeflash-trace",
action="store_true",
default=False,
help="Enable CodeFlash tracing"
)
parser.addoption(
"--functions",
action="store",
default="",
help="Comma-separated list of additional functions to trace"
)
parser.addoption(
"--benchmarks-root",
action="store",
default=".",
help="Root directory for benchmarks"
)
@staticmethod
def pytest_plugin_registered(plugin, manager):
if hasattr(plugin, "name") and plugin.name == "pytest-benchmark":
manager.unregister(plugin)
@staticmethod
def pytest_collection_modifyitems(config, items):
if not config.getoption("--codeflash-trace"):
return
skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
for item in items:
if hasattr(item, "fixturenames") and "benchmark" in item.fixturenames:
continue
item.add_marker(skip_no_benchmark)
@staticmethod
@pytest.fixture
def benchmark(request):
if not request.config.getoption("--codeflash-trace"):
return None
class Benchmark:
def __call__(self, func, *args, **kwargs):
func_name = func.__name__
test_name = request.node.name
additional_functions = request.config.getoption("--functions").split(",")
trace_functions = [f for f in additional_functions if f]
print("Tracing functions: ", trace_functions)
# Get benchmarks root directory from command line option
benchmarks_root = Path(request.config.getoption("--benchmarks-root"))
# Create the .codeflash_trace directory if it doesn't exist
trace_dir = benchmarks_root / '.codeflash_trace'
trace_dir.mkdir(exist_ok=True)
# Write this test's .trace file into that directory
output_path = trace_dir / f"{test_name}.trace"
tracer = Tracer(
output=str(output_path), # Convert Path to string for Tracer
functions=trace_functions,
max_function_count=256
)
with tracer:
result = func(*args, **kwargs)
return result
return Benchmark()
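In effect, when --codeflash-trace is passed this shim replaces pytest-benchmark's fixture: it runs the benchmarked callable once under the Tracer and writes one .trace file per test. A rough hand-written equivalent of the test_sort benchmark above (paths and input size are illustrative) would be:

from pathlib import Path

from code_to_optimize.bubble_sort import sorter
from codeflash.tracer import Tracer

benchmarks_root = Path("code_to_optimize/tests/pytest/benchmarks")
trace_dir = benchmarks_root / ".codeflash_trace"
trace_dir.mkdir(exist_ok=True)

# one trace database per benchmark test, named after the test
with Tracer(output=str(trace_dir / "test_sort.trace"), functions=[], max_function_count=256):
    result = sorter(list(reversed(range(5000))))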

View file

@ -0,0 +1,15 @@
import sys
from plugin.plugin import CodeFlashPlugin
benchmarks_root = sys.argv[1]
function_list = sys.argv[2]
if __name__ == "__main__":
import pytest
try:
exitcode = pytest.main(
[benchmarks_root, "--benchmarks-root", benchmarks_root, "--codeflash-trace", "-p", "no:benchmark", "-s", "--functions", function_list], plugins=[CodeFlashPlugin()]
)
except Exception as e:
print(f"Failed to collect tests: {e!s}")
exitcode = -1

View file

@ -0,0 +1,20 @@
from __future__ import annotations
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
from pathlib import Path
import subprocess
def trace_benchmarks_pytest(benchmarks_root: Path, project_root: Path, function_list: list[str] = []) -> None:
result = subprocess.run(
[
SAFE_SYS_EXECUTABLE,
Path(__file__).parent / "pytest_new_process_trace_benchmarks.py",
str(benchmarks_root),
",".join(function_list)
],
cwd=project_root,
check=False,
capture_output=True,
text=True,
)
print("stdout:", result.stdout)
print("stderr:", result.stderr)

View file

@ -62,6 +62,10 @@ def parse_args() -> Namespace:
)
parser.add_argument("-v", "--verbose", action="store_true", help="Print verbose debug logs")
parser.add_argument("--version", action="store_true", help="Print the version of codeflash")
parser.add_argument("--benchmark", action="store_true", help="Trace benchmark tests and calculate optimization impact on benchmarks")
parser.add_argument(
"--benchmarks-root", type=str, help="Path to the directory of the project, where all the pytest-benchmark tests are located."
)
args: Namespace = parser.parse_args()
return process_and_validate_cmd_args(args)
@ -116,6 +120,7 @@ def process_pyproject_config(args: Namespace) -> Namespace:
"disable_telemetry",
"disable_imports_sorting",
"git_remote",
"benchmarks_root"
]
for key in supported_keys:
if key in pyproject_config and (
@ -127,9 +132,10 @@ def process_pyproject_config(args: Namespace) -> Namespace:
assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory"
assert args.tests_root is not None, "--tests-root must be specified"
assert Path(args.tests_root).is_dir(), f"--tests-root {args.tests_root} must be a valid directory"
if env_utils.get_pr_number() is not None:
assert env_utils.ensure_codeflash_api_key(), (
if args.benchmark:
assert args.benchmarks_root is not None, "--benchmarks-root must be specified when running with --benchmark"
assert Path(args.benchmarks_root).is_dir(), f"--benchmarks-root {args.benchmarks_root} must be a valid directory"
assert not (env_utils.get_pr_number() is not None and not env_utils.ensure_codeflash_api_key()), (
"Codeflash API key not found. When running in a Github Actions Context, provide the "
"'CODEFLASH_API_KEY' environment variable as a secret.\n"
"You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n"
@ -137,13 +143,6 @@ def process_pyproject_config(args: Namespace) -> Namespace:
f"Here's a direct link: {get_github_secrets_page_url()}\n"
"Exiting..."
)
repo = git.Repo(search_parent_directories=True)
owner, repo_name = get_repo_owner_and_name(repo)
require_github_app_or_exit(owner, repo_name)
if hasattr(args, "ignore_paths") and args.ignore_paths is not None:
normalized_ignore_paths = []
for path in args.ignore_paths:

View file

@ -107,8 +107,6 @@ def discover_tests_pytest(
test_type = TestType.REPLAY_TEST
elif "test_concolic_coverage" in test["test_file"]:
test_type = TestType.CONCOLIC_COVERAGE_TEST
elif test["test_type"] == "benchmark": # New condition for benchmark tests
test_type = TestType.BENCHMARK_TEST
else:
test_type = TestType.EXISTING_UNIT_TEST

View file

@ -121,7 +121,6 @@ class FunctionToOptimize:
method extends this with the module name from the project root.
"""
function_name: str
file_path: Path
parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef]
@ -145,6 +144,11 @@ class FunctionToOptimize:
def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
@property
def qualified_name_with_file_name(self) -> str:
class_name = self.parents[0].name if self.parents else None
return f"{self.file_path}:{(class_name + ':' if class_name else '')}{self.function_name}"
def get_functions_to_optimize(
optimize_all: str | None,

View file

@ -0,0 +1,54 @@
import sys
from typing import Any
# This script should not have any relation to the codeflash package, be careful with imports
cwd = sys.argv[1]
tests_root = sys.argv[2]
pickle_path = sys.argv[3]
collected_tests = []
pytest_rootdir = None
sys.path.insert(1, str(cwd))
class PytestCollectionPlugin:
def pytest_collection_finish(self, session) -> None:
global pytest_rootdir
collected_tests.extend(session.items)
pytest_rootdir = session.config.rootdir
def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, str]]:
test_results = []
for test in pytest_tests:
test_class = None
if test.cls:
test_class = test.parent.name
# Determine if this is a benchmark test by checking for the benchmark fixture
is_benchmark = hasattr(test, 'fixturenames') and 'benchmark' in test.fixturenames
test_type = 'benchmark' if is_benchmark else 'regular'
test_results.append({
"test_file": str(test.path),
"test_class": test_class,
"test_function": test.name,
"test_type": test_type
})
return test_results
if __name__ == "__main__":
import pytest
try:
exitcode = pytest.main(
[tests_root, "-pno:logging", "--collect-only", "-m", "not skip"], plugins=[PytestCollectionPlugin()]
)
except Exception as e:
print(f"Failed to collect tests: {e!s}")
exitcode = -1
tests = parse_pytest_collection_results(collected_tests)
import pickle
with open(pickle_path, "wb") as f:
pickle.dump((exitcode, tests, pytest_rootdir), f, protocol=pickle.HIGHEST_PROTOCOL)

View file

@ -29,17 +29,7 @@ def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, s
test_class = None
if test.cls:
test_class = test.parent.name
# Determine if this is a benchmark test by checking for the benchmark fixture
is_benchmark = hasattr(test, 'fixturenames') and 'benchmark' in test.fixturenames
test_type = 'benchmark' if is_benchmark else 'regular'
test_results.append({
"test_file": str(test.path),
"test_class": test_class,
"test_function": test.name,
"test_type": test_type
})
test_results.append({"test_file": str(test.path), "test_class": test_class, "test_function": test.name})
return test_results

View file

@ -7,7 +7,7 @@ import shutil
import subprocess
import time
import uuid
from collections import defaultdict, deque
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING
@ -21,12 +21,12 @@ from rich.tree import Tree
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.code_utils.code_utils import (
cleanup_paths,
file_name_from_test_module_name,
get_run_tmp_file,
has_any_async_functions,
module_name_from_file_path,
)
from codeflash.code_utils.config_consts import (
@ -37,7 +37,6 @@ from codeflash.code_utils.config_consts import (
)
from codeflash.code_utils.formatter import format_code, sort_imports
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
from codeflash.code_utils.line_profile_utils import add_decorator_imports
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
from codeflash.code_utils.time_utils import humanize_runtime
@ -49,6 +48,7 @@ from codeflash.models.models import (
BestOptimization,
CodeOptimizationContext,
FunctionCalledInTest,
FunctionParent,
GeneratedTests,
GeneratedTestsList,
OptimizationSet,
@ -57,9 +57,8 @@ from codeflash.models.models import (
TestFile,
TestFiles,
TestingMode,
TestResults,
TestType,
)
from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
from codeflash.result.explanation import Explanation
@ -67,15 +66,18 @@ from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.concolic_testing import generate_concolic_tests
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
from codeflash.verification.parse_test_output import parse_test_results
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_line_profile_tests
from codeflash.verification.test_results import TestResults, TestType
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests
from codeflash.verification.verification_utils import get_test_file_path
from codeflash.verification.verifier import generate_tests
if TYPE_CHECKING:
from argparse import Namespace
import numpy as np
import numpy.typing as npt
from codeflash.either import Result
from codeflash.models.models import CoverageData, FunctionSource, OptimizedCandidate
from codeflash.verification.verification_utils import TestConfig
@ -90,6 +92,8 @@ class FunctionOptimizer:
function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
function_to_optimize_ast: ast.FunctionDef | None = None,
aiservice_client: AiServiceClient | None = None,
function_benchmark_timings: dict[str, dict[str, float]] | None = None,
total_benchmark_timings: dict[str, float] | None = None,
args: Namespace | None = None,
) -> None:
self.project_root = test_cfg.project_root_path
@ -118,6 +122,9 @@ class FunctionOptimizer:
self.function_trace_id: str = str(uuid.uuid4())
self.original_module_path = module_name_from_file_path(self.function_to_optimize.file_path, self.project_root)
self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {}
self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {}
def optimize_function(self) -> Result[BestOptimization, str]:
should_run_experiment = self.experiment_id is not None
logger.debug(f"Function Trace ID: {self.function_trace_id}")
@ -134,10 +141,19 @@ class FunctionOptimizer:
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
if has_any_async_functions(code_context.read_writable_code):
return Failure("Codeflash does not support async functions in the code to optimize.")
logger.info("Code to be optimized:")
code_print(code_context.read_writable_code)
for module_abspath, helper_code_source in original_helper_code.items():
code_context.code_to_optimize_with_helpers = add_needed_imports_from_module(
helper_code_source,
code_context.code_to_optimize_with_helpers,
module_abspath,
self.function_to_optimize.file_path,
self.args.project_root,
)
generated_test_paths = [
get_test_file_path(
self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="unit"
@ -156,7 +172,7 @@ class FunctionOptimizer:
transient=True,
):
generated_results = self.generate_tests_and_optimizations(
testgen_context_code=code_context.testgen_context_code,
code_to_optimize_with_helpers=code_context.code_to_optimize_with_helpers,
read_writable_code=code_context.read_writable_code,
read_only_context_code=code_context.read_only_context_code,
helper_functions=code_context.helper_functions,
@ -232,11 +248,10 @@ class FunctionOptimizer:
):
cleanup_paths(paths_to_cleanup)
return Failure("The threshold for test coverage was not met.")
# request for new optimizations but don't block execution, check for completion later
# adding to control and experiment set but with same traceid
best_optimization = None
for _u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]):
for u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]):
if candidates is None:
continue
@ -270,6 +285,20 @@ class FunctionOptimizer:
function_name=function_to_optimize_qualified_name,
file_path=self.function_to_optimize.file_path,
)
speedup = explanation.speedup  # e.g. 1.2 means 1.2x faster
if self.args.benchmark:
fto_benchmark_timings = self.function_benchmark_timings[self.function_to_optimize.qualified_name_with_file_name]
for benchmark_name, og_benchmark_timing in fto_benchmark_timings.items():
print(f"Calculating speedup for benchmark {benchmark_name}")
total_benchmark_timing = self.total_benchmark_timings[benchmark_name]
# Project the expected new benchmark timing, then calculate how much the total benchmark was sped up; print intermediate values
expected_new_benchmark_timing = total_benchmark_timing - og_benchmark_timing + og_benchmark_timing / speedup
print(f"Expected new benchmark timing: {expected_new_benchmark_timing}")
print(f"Original benchmark timing: {total_benchmark_timing}")
print(f"Benchmark speedup: {total_benchmark_timing / expected_new_benchmark_timing}")
speedup = total_benchmark_timing / expected_new_benchmark_timing
print(f"Speedup: {speedup}")
self.log_successful_optimization(explanation, generated_tests)
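A worked example of the projection above, with made-up numbers: if a benchmark takes 500 ms in total, 200 ms of that is spent inside the optimized function, and the function alone becomes 4x faster, the projected benchmark speedup is 500 / (500 - 200 + 200/4) ≈ 1.43x.

# illustrative numbers only; same formula as the projection above
total_benchmark_timing = 500.0  # ms spent in the whole benchmark
og_benchmark_timing = 200.0     # ms of that spent inside the optimized function
speedup = 4.0                   # the optimized function alone is 4x faster

expected_new_benchmark_timing = (
    total_benchmark_timing - og_benchmark_timing + og_benchmark_timing / speedup
)  # 500 - 200 + 50 = 350 ms
benchmark_speedup = total_benchmark_timing / expected_new_benchmark_timing  # ~1.43x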
@ -359,37 +388,11 @@ class FunctionOptimizer:
f"{self.function_to_optimize.qualified_name}"
)
console.rule()
candidates = deque(candidates)
# Start a new thread for AI service request, start loop in main thread
# check if aiservice request is complete, when it is complete, append result to the candidates list
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
future_line_profile_results = executor.submit(
self.aiservice_client.optimize_python_code_line_profiler,
source_code=code_context.read_writable_code,
dependency_code=code_context.read_only_context_code,
trace_id=self.function_trace_id,
line_profiler_results=original_code_baseline.line_profile_results["str_out"],
num_candidates=10,
experiment_metadata=None,
)
try:
candidate_index = 0
done = False
original_len = len(candidates)
while candidates:
# for candidate_index, candidate in enumerate(candidates, start=1):
done = True if future_line_profile_results is None else future_line_profile_results.done()
if done and (future_line_profile_results is not None):
line_profile_results = future_line_profile_results.result()
candidates.extend(line_profile_results)
original_len+= len(line_profile_results)
logger.info(f"Added {len(line_profile_results)} results from line profiler to candidates, total candidates now: {original_len}")
future_line_profile_results = None
candidate_index += 1
candidate = candidates.popleft()
for candidate_index, candidate in enumerate(candidates, start=1):
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
logger.info(f"Optimization candidate {candidate_index}/{len(candidates)}:")
code_print(candidate.source_code)
try:
did_update = self.replace_function_and_helpers_with_optimized_code(
@ -404,9 +407,7 @@ class FunctionOptimizer:
except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
logger.error(e)
self.write_code_and_helpers(
self.function_to_optimize_source_code,
original_helper_code,
self.function_to_optimize.file_path,
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
continue
@ -469,7 +470,6 @@ class FunctionOptimizer:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
except KeyboardInterrupt as e:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
@ -575,6 +575,50 @@ class FunctionOptimizer:
return did_update
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
code_to_optimize, contextual_dunder_methods = extract_code([self.function_to_optimize])
if code_to_optimize is None:
return Failure("Could not find function to optimize.")
(helper_code, helper_functions, helper_dunder_methods) = get_constrained_function_context_and_helper_functions(
self.function_to_optimize, self.project_root, code_to_optimize
)
if self.function_to_optimize.parents:
function_class = self.function_to_optimize.parents[0].name
same_class_helper_methods = [
df
for df in helper_functions
if df.qualified_name.count(".") > 0 and df.qualified_name.split(".")[0] == function_class
]
optimizable_methods = [
FunctionToOptimize(
df.qualified_name.split(".")[-1],
df.file_path,
[FunctionParent(df.qualified_name.split(".")[0], "ClassDef")],
None,
None,
)
for df in same_class_helper_methods
] + [self.function_to_optimize]
dedup_optimizable_methods = []
added_methods = set()
for method in reversed(optimizable_methods):
if f"{method.file_path}.{method.qualified_name}" not in added_methods:
dedup_optimizable_methods.append(method)
added_methods.add(f"{method.file_path}.{method.qualified_name}")
if len(dedup_optimizable_methods) > 1:
code_to_optimize, contextual_dunder_methods = extract_code(list(reversed(dedup_optimizable_methods)))
if code_to_optimize is None:
return Failure("Could not find function to optimize.")
code_to_optimize_with_helpers = helper_code + "\n" + code_to_optimize
code_to_optimize_with_helpers_and_imports = add_needed_imports_from_module(
self.function_to_optimize_source_code,
code_to_optimize_with_helpers,
self.function_to_optimize.file_path,
self.function_to_optimize.file_path,
self.project_root,
helper_functions,
)
try:
new_code_ctx = code_context_extractor.get_code_optimization_context(
self.function_to_optimize, self.project_root
@ -584,7 +628,7 @@ class FunctionOptimizer:
return Success(
CodeOptimizationContext(
testgen_context_code=new_code_ctx.testgen_context_code,
code_to_optimize_with_helpers=code_to_optimize_with_helpers_and_imports,
read_writable_code=new_code_ctx.read_writable_code,
read_only_context_code=new_code_ctx.read_only_context_code,
helper_functions=new_code_ctx.helper_functions, # only functions that are read writable
@ -686,7 +730,7 @@ class FunctionOptimizer:
def generate_tests_and_optimizations(
self,
testgen_context_code: str,
code_to_optimize_with_helpers: str,
read_writable_code: str,
read_only_context_code: str,
helper_functions: list[FunctionSource],
@ -701,7 +745,7 @@ class FunctionOptimizer:
# Submit the test generation task as future
future_tests = self.generate_and_instrument_tests(
executor,
testgen_context_code,
code_to_optimize_with_helpers,
[definition.fully_qualified_name for definition in helper_functions],
generated_test_paths,
generated_perf_test_paths,
@ -790,7 +834,6 @@ class FunctionOptimizer:
original_helper_code: dict[Path, str],
file_path_to_helper_classes: dict[Path, set[str]],
) -> Result[tuple[OriginalCodeBaseline, list[str]], str]:
line_profile_results = {"timings": {}, "unit": 0, "str_out": ""}
# For the original function - run the tests and get the runtime, plus coverage
with progress_bar(f"Establishing original code baseline for {self.function_to_optimize.function_name}"):
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
@ -831,31 +874,11 @@ class FunctionOptimizer:
)
console.rule()
return Failure("Failed to establish a baseline for the original code - bevhavioral tests failed.")
if not coverage_critic(coverage_results, self.args.test_framework):
if not coverage_critic(
coverage_results, self.args.test_framework
):
return Failure("The threshold for test coverage was not met.")
if test_framework == "pytest":
try:
line_profiler_output_file = add_decorator_imports(self.function_to_optimize, code_context)
line_profile_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.LINE_PROFILE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=TOTAL_LOOPING_TIME,
enable_coverage=False,
code_context=code_context,
line_profiler_output_file=line_profiler_output_file,
)
finally:
# Remove codeflash capture
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
if line_profile_results["str_out"] == "":
logger.warning(
f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
)
console.rule()
benchmarking_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,
@ -894,6 +917,7 @@ class FunctionOptimizer:
)
console.rule()
total_timing = benchmarking_results.total_passed_runtime() # caution: doesn't handle the loop index
functions_to_remove = [
result.id.test_function_name
@ -927,7 +951,6 @@ class FunctionOptimizer:
benchmarking_test_results=benchmarking_results,
runtime=total_timing,
coverage_results=coverage_results,
line_profile_results=line_profile_results,
),
functions_to_remove,
)
@ -1063,62 +1086,47 @@ class FunctionOptimizer:
pytest_max_loops: int = 100_000,
code_context: CodeOptimizationContext | None = None,
unittest_loop_index: int | None = None,
line_profiler_output_file: Path | None = None,
) -> tuple[TestResults | dict, CoverageData | None]:
) -> tuple[TestResults, CoverageData | None]:
coverage_database_file = None
coverage_config_file = None
try:
if testing_type == TestingMode.BEHAVIOR:
result_file_path, run_result, coverage_database_file, coverage_config_file = run_behavioral_tests(
result_file_path, run_result, coverage_database_file = run_behavioral_tests(
test_files,
test_framework=self.test_cfg.test_framework,
cwd=self.project_root,
test_env=test_env,
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
pytest_cmd=self.test_cfg.pytest_cmd,
verbose=True,
enable_coverage=enable_coverage,
)
elif testing_type == TestingMode.LINE_PROFILE:
result_file_path, run_result = run_line_profile_tests(
test_files,
cwd=self.project_root,
test_env=test_env,
pytest_cmd=self.test_cfg.pytest_cmd,
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
pytest_target_runtime_seconds=testing_time,
pytest_min_loops=pytest_min_loops,
pytest_max_loops=pytest_min_loops,
test_framework=self.test_cfg.test_framework,
line_profiler_output_file=line_profiler_output_file,
)
elif testing_type == TestingMode.PERFORMANCE:
result_file_path, run_result = run_benchmarking_tests(
test_files,
cwd=self.project_root,
test_env=test_env,
pytest_cmd=self.test_cfg.pytest_cmd,
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
pytest_cmd=self.test_cfg.pytest_cmd,
pytest_target_runtime_seconds=testing_time,
pytest_min_loops=pytest_min_loops,
pytest_max_loops=pytest_max_loops,
test_framework=self.test_cfg.test_framework,
)
else:
msg = f"Unexpected testing type: {testing_type}"
raise ValueError(msg)
raise ValueError(f"Unexpected testing type: {testing_type}")
except subprocess.TimeoutExpired:
logger.exception(
f"Error running tests in {', '.join(str(f) for f in test_files.test_files)}.\nTimeout Error"
f'Error running tests in {", ".join(str(f) for f in test_files.test_files)}.\nTimeout Error'
)
return TestResults(), None
if run_result.returncode != 0 and testing_type == TestingMode.BEHAVIOR:
logger.debug(
f"Nonzero return code {run_result.returncode} when running tests in "
f"{', '.join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n"
f'Nonzero return code {run_result.returncode} when running tests in '
f'{", ".join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n'
f"stdout: {run_result.stdout}\n"
f"stderr: {run_result.stderr}\n"
)
if testing_type in [TestingMode.BEHAVIOR, TestingMode.PERFORMANCE]:
# print(test_files)
results, coverage_results = parse_test_results(
test_xml_path=result_file_path,
test_files=test_files,
@ -1130,10 +1138,7 @@ class FunctionOptimizer:
source_file=self.function_to_optimize.file_path,
code_context=code_context,
coverage_database_file=coverage_database_file,
coverage_config_file=coverage_config_file,
)
else:
results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
return results, coverage_results
def generate_and_instrument_tests(
@ -1163,3 +1168,4 @@ class FunctionOptimizer:
zip(generated_test_paths, generated_perf_test_paths)
)
]

View file

@ -8,7 +8,8 @@ from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
from codeflash.cli_cmds.console import console, logger, progress_bar
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_replacer import normalize_code, normalize_node
from codeflash.code_utils.code_utils import get_run_tmp_file
@ -16,10 +17,12 @@ from codeflash.code_utils.static_analysis import analyze_imported_modules, get_f
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
from codeflash.either import is_successful
from codeflash.models.models import TestType, ValidCode
from codeflash.models.models import TestFiles, ValidCode
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.test_results import TestType
from codeflash.verification.verification_utils import TestConfig
from codeflash.benchmarking.get_trace_info import get_function_benchmark_timings, get_benchmark_timings
if TYPE_CHECKING:
from argparse import Namespace
@ -50,6 +53,8 @@ class Optimizer:
function_to_optimize_ast: ast.FunctionDef | None = None,
function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
function_to_optimize_source_code: str | None = "",
function_benchmark_timings: dict[str, dict[str, float]] | None = None,
total_benchmark_timings: dict[str, float] | None = None,
) -> FunctionOptimizer:
return FunctionOptimizer(
function_to_optimize=function_to_optimize,
@ -59,6 +64,8 @@ class Optimizer:
function_to_optimize_ast=function_to_optimize_ast,
aiservice_client=self.aiservice_client,
args=self.args,
function_benchmark_timings=function_benchmark_timings if function_benchmark_timings else None,
total_benchmark_timings=total_benchmark_timings if total_benchmark_timings else None,
)
def run(self) -> None:
@ -80,6 +87,23 @@ class Optimizer:
project_root=self.args.project_root,
module_root=self.args.module_root,
)
if self.args.benchmark:
all_functions_to_optimize = [
function
for functions_list in file_to_funcs_to_optimize.values()
for function in functions_list
]
logger.info(f"Tracing existing benchmarks for {len(all_functions_to_optimize)} functions")
trace_benchmarks_pytest(self.args.benchmarks_root, self.args.project_root, [fto.qualified_name_with_file_name for fto in all_functions_to_optimize])
logger.info("Finished tracing existing benchmarks")
trace_dir = Path(self.args.benchmarks_root) / ".codeflash_trace"
function_benchmark_timings = get_function_benchmark_timings(trace_dir, all_functions_to_optimize)
print(function_benchmark_timings)
total_benchmark_timings = get_benchmark_timings(trace_dir)
print("Total benchmark timings:")
print(total_benchmark_timings)
# for function in fully_qualified_function_names:
optimizations_found: int = 0
function_iterator_count: int = 0
@ -93,6 +117,8 @@ class Optimizer:
logger.info("No functions found to optimize. Exiting…")
return
console.rule()
logger.info(f"Discovering existing unit tests in {self.test_cfg.tests_root}")
console.rule()
function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg)
num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])
@ -136,7 +162,6 @@ class Optimizer:
validated_original_code[analysis.file_path] = ValidCode(
source_code=callee_original_code, normalized_code=normalized_callee_original_code
)
if has_syntax_error:
continue
@ -146,7 +171,7 @@ class Optimizer:
f"Optimizing function {function_iterator_count} of {num_optimizable_functions}: "
f"{function_to_optimize.qualified_name}"
)
console.rule()
if not (
function_to_optimize_ast := get_first_top_level_function_or_method_ast(
function_to_optimize.function_name, function_to_optimize.parents, original_module_ast
@ -157,12 +182,17 @@ class Optimizer:
f"Skipping optimization."
)
continue
if self.args.benchmark:
function_optimizer = self.create_function_optimizer(
function_to_optimize,
function_to_optimize_ast,
function_to_tests,
validated_original_code[original_module_path].source_code,
function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code, function_benchmark_timings, total_benchmark_timings
)
else:
function_optimizer = self.create_function_optimizer(
function_to_optimize, function_to_optimize_ast, function_to_tests,
validated_original_code[original_module_path].source_code
)
best_optimization = function_optimizer.optimize_function()
if is_successful(best_optimization):
optimizations_found += 1
@ -191,6 +221,7 @@ class Optimizer:
get_run_tmp_file.tmpdir.cleanup()
def run_with_args(args: Namespace) -> None:
optimizer = Optimizer(args)
optimizer.run()

View file

@ -18,21 +18,19 @@ import marshal
import os
import pathlib
import pickle
import re
import sqlite3
import sys
import threading
import time
from argparse import ArgumentParser
from collections import defaultdict
from copy import copy
from io import StringIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ClassVar
from types import FrameType
from typing import Any, ClassVar, List
import dill
import isort
from rich.align import Align
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from codeflash.cli_cmds.cli import project_root_from_module_root
from codeflash.cli_cmds.console import console
@ -42,34 +40,14 @@ from codeflash.discovery.functions_to_optimize import filter_files_optimized
from codeflash.tracing.replay_test import create_trace_replay_test
from codeflash.tracing.tracing_utils import FunctionModules
from codeflash.verification.verification_utils import get_test_file_path
if TYPE_CHECKING:
from types import FrameType, TracebackType
class FakeCode:
def __init__(self, filename: str, line: int, name: str) -> None:
self.co_filename = filename
self.co_line = line
self.co_name = name
self.co_firstlineno = 0
def __repr__(self) -> str:
return repr((self.co_filename, self.co_line, self.co_name, None))
class FakeFrame:
def __init__(self, code: FakeCode, prior: FakeFrame | None) -> None:
self.f_code = code
self.f_back = prior
self.f_locals: dict = {}
# import warnings
# warnings.filterwarnings("ignore", category=dill.PickleWarning)
# warnings.filterwarnings("ignore", category=DeprecationWarning)
# Debug this file by simply adding print statements. This file is not meant to be debugged by the debugger.
class Tracer:
"""Use this class as a 'with' context manager to trace a function call.
Traces function calls, input arguments, and profiling info.
"""Use this class as a 'with' context manager to trace a function call,
input arguments, and profiling info.
"""
def __init__(
@ -81,9 +59,7 @@ class Tracer:
max_function_count: int = 256,
timeout: int | None = None, # seconds
) -> None:
"""Use this class to trace function calls.
:param output: The path to the output trace file
""":param output: The path to the output trace file
:param functions: List of functions to trace. If None, trace all functions
:param disable: Disable the tracer if True
:param config_file_path: Path to the pyproject.toml file, if None then it will be auto-discovered
@ -94,9 +70,7 @@ class Tracer:
if functions is None:
functions = []
if os.environ.get("CODEFLASH_TRACER_DISABLE", "0") == "1":
console.rule(
"Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE", style="bold red"
)
console.print("Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE")
disable = True
self.disable = disable
if self.disable:
@ -111,7 +85,7 @@ class Tracer:
self.con = None
self.output_file = Path(output).resolve()
self.functions = functions
self.function_modules: list[FunctionModules] = []
self.function_modules: List[FunctionModules] = []
self.function_count = defaultdict(int)
self.current_file_path = Path(__file__).resolve()
self.ignored_qualified_functions = {
@ -121,10 +95,10 @@ class Tracer:
self.max_function_count = max_function_count
self.config, found_config_path = parse_config_file(config_file_path)
self.project_root = project_root_from_module_root(Path(self.config["module_root"]), found_config_path)
console.rule(f"Project Root: {self.project_root}", style="bold blue")
print("project_root", self.project_root)
self.ignored_functions = {"<listcomp>", "<genexpr>", "<dictcomp>", "<setcomp>", "<lambda>", "<module>"}
self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_") # noqa: SLF001
self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_")
assert timeout is None or timeout > 0, "Timeout should be greater than 0"
self.timeout = timeout
@ -145,44 +119,48 @@ class Tracer:
def __enter__(self) -> None:
if self.disable:
return
if getattr(Tracer, "used_once", False):
console.print(
"Codeflash: Tracer can only be used once per program run. "
"Please only enable the Tracer once. Skipping tracing this section."
)
self.disable = True
return
Tracer.used_once = True
# if getattr(Tracer, "used_once", False):
# console.print(
# "Codeflash: Tracer can only be used once per program run. "
# "Please only enable the Tracer once. Skipping tracing this section."
# )
# self.disable = True
# return
# Tracer.used_once = True
if pathlib.Path(self.output_file).exists():
console.rule("Removing existing trace file", style="bold red")
console.rule()
console.print("Codeflash: Removing existing trace file")
pathlib.Path(self.output_file).unlink(missing_ok=True)
self.con = sqlite3.connect(self.output_file, check_same_thread=False)
self.con = sqlite3.connect(self.output_file)
cur = self.con.cursor()
cur.execute("""PRAGMA synchronous = OFF""")
cur.execute("""PRAGMA journal_mode = WAL""")
# TODO: Check out if we need to export the function test name as well
cur.execute(
"CREATE TABLE function_calls(type TEXT, function TEXT, classname TEXT, filename TEXT, "
"line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)"
)
console.rule("Codeflash: Traced Program Output Begin", style="bold blue")
frame = sys._getframe(0) # Get this frame and simulate a call to it # noqa: SLF001
console.print("Codeflash: Tracing started!")
frame = sys._getframe(0) # Get this frame and simulate a call to it
self.dispatch["call"](self, frame, 0)
self.start_time = time.time()
sys.setprofile(self.trace_callback)
threading.setprofile(self.trace_callback)
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None:
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self.disable:
return
sys.setprofile(None)
self.con.commit()
console.rule("Codeflash: Traced Program Output End", style="bold blue")
# Check if any functions were actually traced
if self.trace_count == 0:
self.con.close()
# Delete the trace file if no functions were traced
if self.output_file.exists():
self.output_file.unlink()
console.print("Codeflash: No functions were traced. Removing trace database.")
return
self.create_stats()
cur = self.con.cursor()
@ -226,13 +204,14 @@ class Tracer:
test_framework=self.config["test_framework"],
max_run_count=self.max_function_count,
)
function_path = "_".join(self.functions) if self.functions else self.file_being_called_from
# Need a better way to store the replay test
# function_path = "_".join(self.functions) if self.functions else self.file_being_called_from
function_path = self.file_being_called_from
test_file_path = get_test_file_path(
test_dir=Path(self.config["tests_root"]), function_name=function_path, test_type="replay"
)
replay_test = isort.code(replay_test)
with Path(test_file_path).open("w", encoding="utf8") as file:
with open(test_file_path, "w", encoding="utf8") as file:
file.write(replay_test)
console.print(
@ -242,27 +221,25 @@ class Tracer:
overflow="ignore",
)
def tracer_logic(self, frame: FrameType, event: str) -> None:
def tracer_logic(self, frame: FrameType, event: str):
if event != "call":
return
if self.timeout is not None and (time.time() - self.start_time) > self.timeout:
if self.timeout is not None:
if (time.time() - self.start_time) > self.timeout:
sys.setprofile(None)
threading.setprofile(None)
console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.")
return
code = frame.f_code
file_name = Path(code.co_filename).resolve()
# TODO : It currently doesn't log the last return call from the first function
if code.co_name in self.ignored_functions:
return
if not file_name.is_relative_to(self.project_root):
return
if not file_name.exists():
return
if self.functions and code.co_name not in self.functions:
return
# if self.functions:
# if code.co_name not in self.functions:
# return
class_name = None
arguments = frame.f_locals
try:
@ -274,12 +251,16 @@ class Tracer:
class_name = arguments["self"].__class__.__name__
elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
class_name = arguments["cls"].__name__
except: # noqa: E722
except:
# someone can override the getattr method and raise an exception. I'm looking at you wrapt
return
function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}"
if function_qualified_name in self.ignored_qualified_functions:
return
if self.functions and function_qualified_name not in self.functions:
return
if function_qualified_name not in self.function_count:
# seeing this function for the first time
self.function_count[function_qualified_name] = 0
@ -354,14 +335,17 @@ class Tracer:
self.next_insert = 1000
self.con.commit()
def trace_callback(self, frame: FrameType, event: str, arg: str | None) -> None:
def trace_callback(self, frame: FrameType, event: str, arg: Any) -> None:
# profiler section
timer = self.timer
t = timer() - self.t - self.bias
if event == "c_call":
self.c_func_name = arg.__name__
prof_success = bool(self.dispatch[event](self, frame, t))
if self.dispatch[event](self, frame, t):
prof_success = True
else:
prof_success = False
# tracer section
self.tracer_logic(frame, event)
# measure the time as the last thing before return
@ -370,24 +354,13 @@ class Tracer:
else:
self.t = timer() - t # put back unrecorded delta
def trace_dispatch_call(self, frame: FrameType, t: int) -> int:
"""Handle call events in the profiler."""
try:
# In multi-threaded contexts, we need to be more careful about frame comparisons
def trace_dispatch_call(self, frame, t):
if self.cur and frame.f_back is not self.cur[-2]:
# This happens when we're in a different thread
rpt, rit, ret, rfn, rframe, rcur = self.cur
# Only attempt to handle the frame mismatch if we have a valid rframe
if (
not isinstance(rframe, FakeFrame)
and hasattr(rframe, "f_back")
and hasattr(frame, "f_back")
and rframe.f_back is frame.f_back
):
if not isinstance(rframe, Tracer.fake_frame):
assert rframe.f_back is frame.f_back, ("Bad call", rfn, rframe, rframe.f_back, frame, frame.f_back)
self.trace_dispatch_return(rframe, 0)
# Get function information
assert self.cur is None or frame.f_back is self.cur[-2], ("Bad call", self.cur[-3])
fcode = frame.f_code
arguments = frame.f_locals
class_name = None
@ -400,9 +373,8 @@ class Tracer:
class_name = arguments["self"].__class__.__name__
elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
class_name = arguments["cls"].__name__
except Exception: # noqa: BLE001, S110
except:
pass
fn = (fcode.co_filename, fcode.co_firstlineno, fcode.co_name, class_name)
self.cur = (t, 0, 0, fn, frame, self.cur)
timings = self.timings
@ -411,19 +383,16 @@ class Tracer:
timings[fn] = cc, ns + 1, tt, ct, callers
else:
timings[fn] = 0, 0, 0, 0, {}
return 1 # noqa: TRY300
except Exception: # noqa: BLE001
# Handle any errors gracefully
return 0
return 1
def trace_dispatch_exception(self, frame: FrameType, t: int) -> int:
def trace_dispatch_exception(self, frame, t):
rpt, rit, ret, rfn, rframe, rcur = self.cur
if (rframe is not frame) and rcur:
return self.trace_dispatch_return(rframe, t)
self.cur = rpt, rit + t, ret, rfn, rframe, rcur
return 1
def trace_dispatch_c_call(self, frame: FrameType, t: int) -> int:
def trace_dispatch_c_call(self, frame, t):
fn = ("", 0, self.c_func_name, None)
self.cur = (t, 0, 0, fn, frame, self.cur)
timings = self.timings
@ -434,27 +403,15 @@ class Tracer:
timings[fn] = 0, 0, 0, 0, {}
return 1
def trace_dispatch_return(self, frame: FrameType, t: int) -> int:
if not self.cur or not self.cur[-2]:
return 0
# In multi-threaded environments, frames can get mismatched
def trace_dispatch_return(self, frame, t):
if frame is not self.cur[-2]:
# Don't assert in threaded environments - frames can legitimately differ
if hasattr(frame, "f_back") and hasattr(self.cur[-2], "f_back") and frame is self.cur[-2].f_back:
assert frame is self.cur[-2].f_back, ("Bad return", self.cur[-3])
self.trace_dispatch_return(self.cur[-2], 0)
else:
# We're in a different thread or context, can't continue with this frame
return 0
# Prefix "r" means part of the Returning or exiting frame.
# Prefix "p" means part of the Previous or Parent or older frame.
rpt, rit, ret, rfn, frame, rcur = self.cur
# Guard against invalid rcur (w threading)
if not rcur:
return 0
rit = rit + t
frame_total = rit + ret
@ -462,9 +419,6 @@ class Tracer:
self.cur = ppt, pit + rpt, pet + frame_total, pfn, pframe, pcur
timings = self.timings
if rfn not in timings:
# w threading, rfn can be missing
timings[rfn] = 0, 0, 0, 0, {}
cc, ns, tt, ct, callers = timings[rfn]
if not ns:
# This is the only occurrence of the function on the stack.
@ -486,7 +440,7 @@ class Tracer:
return 1
dispatch: ClassVar[dict[str, Callable[[Tracer, FrameType, int], int]]] = {
dispatch: ClassVar[dict[str, callable]] = {
"call": trace_dispatch_call,
"exception": trace_dispatch_exception,
"return": trace_dispatch_return,
@ -495,13 +449,32 @@ class Tracer:
"c_return": trace_dispatch_return,
}
def simulate_call(self, name: str) -> None:
code = FakeCode("profiler", 0, name)
pframe = self.cur[-2] if self.cur else None
frame = FakeFrame(code, pframe)
class fake_code:
def __init__(self, filename, line, name):
self.co_filename = filename
self.co_line = line
self.co_name = name
self.co_firstlineno = 0
def __repr__(self):
return repr((self.co_filename, self.co_line, self.co_name, None))
class fake_frame:
def __init__(self, code, prior):
self.f_code = code
self.f_back = prior
self.f_locals = {}
def simulate_call(self, name):
code = self.fake_code("profiler", 0, name)
if self.cur:
pframe = self.cur[-2]
else:
pframe = None
frame = self.fake_frame(code, pframe)
self.dispatch["call"](self, frame, 0)
def simulate_cmd_complete(self) -> None:
def simulate_cmd_complete(self):
get_time = self.timer
t = get_time() - self.t
while self.cur[-1]:
@ -511,174 +484,60 @@ class Tracer:
t = 0
self.t = get_time() - t
def print_stats(self, sort: str | int | tuple = -1) -> None:
if not self.stats:
console.print("Codeflash: No stats available to print")
self.total_tt = 0
return
def print_stats(self, sort=-1):
import pstats
if not isinstance(sort, tuple):
sort = (sort,)
# First, convert stats to make them pstats-compatible
try:
# Initialize empty collections for pstats
self.files = []
self.top_level = []
# Create entirely new dictionaries instead of modifying existing ones
new_stats = {}
new_timings = {}
# Convert stats dictionary
stats_items = list(self.stats.items())
for func, stats_data in stats_items:
try:
# Make sure we have 5 elements in stats_data
if len(stats_data) != 5:
console.print(f"Skipping malformed stats data for {func}: {stats_data}")
continue
cc, nc, tt, ct, callers = stats_data
if len(func) == 4:
file_name, line_num, func_name, class_name = func
new_func_name = f"{class_name}.{func_name}" if class_name else func_name
new_func = (file_name, line_num, new_func_name)
else:
new_func = func # Keep as is if already in correct format
new_callers = {}
callers_items = list(callers.items())
for caller_func, count in callers_items:
if isinstance(caller_func, tuple):
if len(caller_func) == 4:
caller_file, caller_line, caller_name, caller_class = caller_func
caller_new_name = f"{caller_class}.{caller_name}" if caller_class else caller_name
new_caller_func = (caller_file, caller_line, caller_new_name)
else:
new_caller_func = caller_func
else:
console.print(f"Unexpected caller format: {caller_func}")
new_caller_func = str(caller_func)
new_callers[new_caller_func] = count
# Store with new format
new_stats[new_func] = (cc, nc, tt, ct, new_callers)
except Exception as e: # noqa: BLE001
console.print(f"Error converting stats for {func}: {e}")
continue
timings_items = list(self.timings.items())
for func, timing_data in timings_items:
try:
if len(timing_data) != 5:
console.print(f"Skipping malformed timing data for {func}: {timing_data}")
continue
cc, ns, tt, ct, callers = timing_data
if len(func) == 4:
file_name, line_num, func_name, class_name = func
new_func_name = f"{class_name}.{func_name}" if class_name else func_name
new_func = (file_name, line_num, new_func_name)
else:
new_func = func
new_callers = {}
callers_items = list(callers.items())
for caller_func, count in callers_items:
if isinstance(caller_func, tuple):
if len(caller_func) == 4:
caller_file, caller_line, caller_name, caller_class = caller_func
caller_new_name = f"{caller_class}.{caller_name}" if caller_class else caller_name
new_caller_func = (caller_file, caller_line, caller_new_name)
else:
new_caller_func = caller_func
else:
console.print(f"Unexpected caller format: {caller_func}")
new_caller_func = str(caller_func)
new_callers[new_caller_func] = count
new_timings[new_func] = (cc, ns, tt, ct, new_callers)
except Exception as e: # noqa: BLE001
console.print(f"Error converting timings for {func}: {e}")
continue
self.stats = new_stats
self.timings = new_timings
self.total_tt = sum(tt for _, _, tt, _, _ in self.stats.values())
total_calls = sum(cc for cc, _, _, _, _ in self.stats.values())
total_primitive = sum(nc for _, nc, _, _, _ in self.stats.values())
summary = Text.assemble(
f"{total_calls:,} function calls ",
("(" + f"{total_primitive:,} primitive calls" + ")", "dim"),
f" in {self.total_tt / 1e6:.3f}milliseconds",
# The following code customizes the default printing behavior to
# print in milliseconds.
s = StringIO()
stats_obj = pstats.Stats(copy(self), stream=s)
stats_obj.strip_dirs().sort_stats(*sort).print_stats(100)
self.total_tt = stats_obj.total_tt
console.print("total_tt", self.total_tt)
raw_stats = s.getvalue()
m = re.search(r"function calls?.*in (\d+)\.\d+ (seconds?)", raw_stats)
total_time = None
if m:
total_time = int(m.group(1))
if total_time is None:
console.print("Failed to get total time from stats")
total_time_ms = total_time / 1e6
raw_stats = re.sub(
r"(function calls?.*)in (\d+)\.\d+ (seconds?)", rf"\1 in {total_time_ms:.3f} milliseconds", raw_stats
)
match_pattern = r"^ *[\d\/]+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +"
m = re.findall(match_pattern, raw_stats, re.MULTILINE)
ms_times = []
for tottime, percall, cumtime, percall_cum in m:
tottime_ms = int(tottime) / 1e6
percall_ms = int(percall) / 1e6
cumtime_ms = int(cumtime) / 1e6
percall_cum_ms = int(percall_cum) / 1e6
ms_times.append([tottime_ms, percall_ms, cumtime_ms, percall_cum_ms])
split_stats = raw_stats.split("\n")
new_stats = []
console.print(Align.center(Panel(summary, border_style="blue", width=80, padding=(0, 2), expand=False)))
table = Table(
show_header=True,
header_style="bold magenta",
border_style="blue",
title="[bold]Function Profile[/bold] (ordered by internal time)",
title_style="cyan",
caption=f"Showing top 25 of {len(self.stats)} functions",
replace_pattern = r"^( *[\d\/]+) +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(.*)"
times_index = 0
for line in split_stats:
if times_index >= len(ms_times):
replaced = line
else:
replaced, n = re.subn(
replace_pattern,
rf"\g<1>{ms_times[times_index][0]:8.3f} {ms_times[times_index][1]:8.3f} {ms_times[times_index][2]:8.3f} {ms_times[times_index][3]:8.3f} \g<6>",
line,
count=1,
)
if n > 0:
times_index += 1
new_stats.append(replaced)
table.add_column("Calls", justify="right", style="green", width=10)
table.add_column("Time (ms)", justify="right", style="cyan", width=10)
table.add_column("Per Call", justify="right", style="cyan", width=10)
table.add_column("Cum (ms)", justify="right", style="yellow", width=10)
table.add_column("Cum/Call", justify="right", style="yellow", width=10)
table.add_column("Function", style="blue")
console.print("\n".join(new_stats))
sorted_stats = sorted(
((func, stats) for func, stats in self.stats.items() if isinstance(func, tuple) and len(func) == 3),
key=lambda x: x[1][2], # Sort by tt (internal time)
reverse=True,
)[:25] # Limit to top 25
# Format and add each row to the table
for func, (cc, nc, tt, ct, _) in sorted_stats:
filename, lineno, funcname = func
# Format calls - show recursive format if different
calls_str = f"{cc}/{nc}" if cc != nc else f"{cc:,}"
# Convert to milliseconds
tt_ms = tt / 1e6
ct_ms = ct / 1e6
# Calculate per-call times
per_call = tt_ms / cc if cc > 0 else 0
cum_per_call = ct_ms / nc if nc > 0 else 0
base_filename = Path(filename).name
file_link = f"[link=file://{filename}]{base_filename}[/link]"
table.add_row(
calls_str,
f"{tt_ms:.3f}",
f"{per_call:.3f}",
f"{ct_ms:.3f}",
f"{cum_per_call:.3f}",
f"{funcname} [dim]({file_link}:{lineno})[/dim]",
)
console.print(Align.center(table))
except Exception as e: # noqa: BLE001
console.print(f"[bold red]Error in stats processing:[/bold red] {e}")
console.print(f"Traced {self.trace_count:,} function calls")
self.total_tt = 0
def make_pstats_compatible(self) -> None:
def make_pstats_compatible(self):
# delete the extra class_name item from the function tuple
self.files = []
self.top_level = []
@ -693,33 +552,36 @@ class Tracer:
self.stats = new_stats
self.timings = new_timings
def dump_stats(self, file: str) -> None:
with Path(file).open("wb") as f:
def dump_stats(self, file):
with open(file, "wb") as f:
self.create_stats()
marshal.dump(self.stats, f)
def create_stats(self) -> None:
def create_stats(self):
self.simulate_cmd_complete()
self.snapshot_stats()
def snapshot_stats(self) -> None:
def snapshot_stats(self):
self.stats = {}
for func, (cc, _ns, tt, ct, caller_dict) in self.timings.items():
callers = caller_dict.copy()
for func, (cc, ns, tt, ct, callers) in self.timings.items():
callers = callers.copy()
nc = 0
for callcnt in callers.values():
nc += callcnt
self.stats[func] = cc, nc, tt, ct, callers
def runctx(self, cmd: str, global_vars: dict[str, Any], local_vars: dict[str, Any]) -> Tracer | None:
def runctx(self, cmd, globals, locals):
self.__enter__()
try:
exec(cmd, global_vars, local_vars) # noqa: S102
exec(cmd, globals, locals)
finally:
self.__exit__(None, None, None)
return self
def main() -> ArgumentParser:
def main():
from argparse import ArgumentParser
parser = ArgumentParser(allow_abbrev=False)
parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to <outfile>", required=True)
parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None)
@ -776,13 +638,16 @@ def main() -> ArgumentParser:
"__cached__": None,
}
try:
Tracer(
tracer = Tracer(
output=args.outfile,
functions=args.only_functions,
max_function_count=args.max_function_count,
timeout=args.tracer_timeout,
config_file_path=args.codeflash_config,
).runctx(code, globs, None)
)
tracer.runctx(code, globs, None)
print(tracer.functions)
except BrokenPipeError as exc:
# Prevent "Exception ignored" during interpreter shutdown.

View file

@ -29,7 +29,6 @@ class TestType(Enum):
REPLAY_TEST = 4
CONCOLIC_COVERAGE_TEST = 5
INIT_STATE_TEST = 6
BENCHMARK_TEST = 7
def to_name(self) -> str:
if self == TestType.INIT_STATE_TEST:
@ -40,7 +39,6 @@ class TestType(Enum):
TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
TestType.REPLAY_TEST: "⏪ Replay Tests",
TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests",
TestType.BENCHMARK_TEST: "📏 Benchmark Tests",
}
return names[self]

View file

@ -75,3 +75,4 @@ class TestConfig:
# or for unittest - project_root_from_module_root(args.tests_root, pyproject_file_path)
concolic_test_root_dir: Optional[Path] = None
pytest_cmd: str = "pytest"
benchmark_tests_root: Optional[Path] = None

View file

@ -69,7 +69,7 @@ exclude = [
[tool.poetry.dependencies]
python = ">=3.9"
unidiff = ">=0.7.4"
pytest = ">=7.0.0,<8.3.4"
pytest = ">=7.0.0"
gitpython = ">=3.1.31"
libcst = ">=1.0.1"
jedi = ">=0.19.1"
@ -92,7 +92,6 @@ rich = ">=13.8.1"
lxml = ">=5.3.0"
crosshair-tool = ">=0.0.78"
coverage = ">=7.6.4"
line_profiler=">=4.2.0" #this is the minimum version which supports python 3.13
[tool.poetry.group.dev]
optional = true
@ -120,7 +119,7 @@ types-gevent = "^24.11.0.20241230"
types-greenlet = "^3.1.0.20241221"
types-pexpect = "^4.9.0.20241208"
types-unidiff = "^0.7.0.20240505"
uv = ">=0.6.2"
sqlalchemy = "^2.0.38"
[tool.poetry.build]
script = "codeflash/update_license_version.py"
@ -152,7 +151,7 @@ warn_required_dynamic_aliases = true
line-length = 120
fix = true
show-fixes = true
exclude = ["code_to_optimize/", "pie_test_set/", "tests/"]
exclude = ["code_to_optimize/", "pie_test_set/"]
[tool.ruff.lint]
select = ["ALL"]
@ -164,11 +163,10 @@ ignore = [
"D103",
"D105",
"D107",
"D203", # incorrect-blank-line-before-class (incompatible with D211)
"D213", # multi-line-summary-second-line (incompatible with D212)
"S101",
"S603",
"S607",
"ANN101",
"COM812",
"FIX002",
"PLR0912",
@ -177,14 +175,13 @@ ignore = [
"TD002",
"TD003",
"TD004",
"PLR2004",
"UP007" # remove once we drop 3.9 support.
"PLR2004"
]
[tool.ruff.lint.flake8-type-checking]
strict = true
runtime-evaluated-base-classes = ["pydantic.BaseModel"]
runtime-evaluated-decorators = ["pydantic.validate_call", "pydantic.dataclasses.dataclass"]
runtime-evaluated-decorators = ["pydantic.validate_call"]
[tool.ruff.lint.pep8-naming]
classmethod-decorators = [
@ -192,9 +189,6 @@ classmethod-decorators = [
"pydantic.validator",
]
[tool.ruff.lint.isort]
split-on-trailing-comma = false
[tool.ruff.format]
docstring-code-format = true
skip-magic-trailing-comma = true
@ -217,13 +211,13 @@ initial-content = """
[tool.codeflash]
module-root = "codeflash"
tests-root = "tests"
# All paths are relative to this pyproject.toml's directory.
module-root = "code_to_optimize"
tests-root = "code_to_optimize/tests"
benchmarks-root = "code_to_optimize/tests/pytest/benchmarks"
test-framework = "pytest"
formatter-cmds = [
"uvx ruff check --exit-zero --fix $file",
"uvx ruff format $file",
]
ignore-paths = []
formatter-cmds = ["ruff check --exit-zero --fix $file", "ruff format $file"]
[build-system]

View file

@ -0,0 +1,8 @@
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
from pathlib import Path
def test_trace_benchmarks():
# Test the trace_benchmarks function
project_root = Path(__file__).parent.parent / "code_to_optimize"
benchmarks_root = project_root / "tests" / "pytest" / "benchmarks"
trace_benchmarks_pytest(benchmarks_root, project_root, ["sorter"])

View file

@ -3,6 +3,7 @@ import tempfile
from pathlib import Path
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.verification.test_results import TestType
from codeflash.verification.verification_utils import TestConfig
@ -21,7 +22,7 @@ def test_unit_test_discovery_pytest():
def test_benchmark_test_discovery_pytest():
project_path = Path(__file__).parent.parent.resolve() / "code_to_optimize"
tests_path = project_path / "tests" / "pytest"
tests_path = project_path / "tests" / "pytest" / "benchmarks" / "test_benchmark_bubble_sort.py"
test_config = TestConfig(
tests_root=tests_path,
project_root_path=project_path,
@ -29,9 +30,10 @@ def test_benchmark_test_discovery_pytest():
tests_project_rootdir=tests_path.parent,
)
tests = discover_unit_tests(test_config)
print(tests)
assert len(tests) > 0
# print(tests)
assert 'bubble_sort.sorter' in tests
benchmark_tests = sum(1 for test in tests['bubble_sort.sorter'] if test.tests_in_file.test_type == TestType.BENCHMARK_TEST)
assert benchmark_tests == 1
def test_unit_test_discovery_unittest():