Mirror of https://github.com/codeflash-ai/codeflash.git (synced 2026-05-04 18:25:17 +00:00)

Commit 965e2c818c (parent a73b541159): initial implementation for tracing benchmarks using a plugin, and projecting speedup

22 changed files with 790 additions and 571 deletions
code_to_optimize/process_and_bubble_sort.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from code_to_optimize.bubble_sort import sorter


def calculate_pairwise_products(arr):
    """
    Calculate the average of all pairwise products in the array.
    """
    sum_of_products = 0
    count = 0

    for i in range(len(arr)):
        for j in range(len(arr)):
            if i != j:
                sum_of_products += arr[i] * arr[j]
                count += 1

    # The average of all pairwise products
    return sum_of_products / count if count > 0 else 0


def compute_and_sort(arr):
    # Compute the average of all pairwise products
    pairwise_average = calculate_pairwise_products(arr)

    # Call the sorter function
    sorter(arr.copy())

    return pairwise_average
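A quick worked check of calculate_pairwise_products (values chosen purely for illustration): for arr = [1, 2, 3], the ordered pairs with i != j contribute 1*2 + 1*3 + 2*1 + 2*3 + 3*1 + 3*2 = 22 over count = 6, so the function returns 22 / 6 ≈ 3.667.

# Illustrative sanity check, not part of the commit:
assert abs(calculate_pairwise_products([1, 2, 3]) - 22 / 6) < 1e-9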
@@ -1,6 +1,13 @@
import pytest

from code_to_optimize.bubble_sort import sorter


def test_sort(benchmark):
    result = benchmark(sorter, list(reversed(range(5000))))
    assert result == list(range(5000))

# This should not be picked up as a benchmark test
def test_sort2():
    result = sorter(list(reversed(range(5000))))
    assert result == list(range(5000))
@@ -0,0 +1,8 @@
from code_to_optimize.process_and_bubble_sort import compute_and_sort
from code_to_optimize.bubble_sort2 import sorter
def test_compute_and_sort(benchmark):
    result = benchmark(compute_and_sort, list(reversed(range(5000))))
    assert result == 6247083.5

def test_no_func(benchmark):
    benchmark(sorter, list(reversed(range(5000))))
codeflash/benchmarking/__init__.py (new file, 0 lines)
codeflash/benchmarking/get_trace_info.py (new file, 112 lines)
@@ -0,0 +1,112 @@
import sqlite3
from pathlib import Path
from typing import Dict, Set

from codeflash.discovery.functions_to_optimize import FunctionToOptimize


def get_function_benchmark_timings(trace_dir: Path, all_functions_to_optimize: list[FunctionToOptimize]) -> dict[str, dict[str, float]]:
    """Process all trace files in the given directory and extract timing data for the specified functions.

    Args:
        trace_dir: Path to the directory containing .trace files
        all_functions_to_optimize: List of FunctionToOptimize objects representing functions to include

    Returns:
        A nested dictionary where:
        - Outer keys are function qualified names with file name
        - Inner keys are benchmark names (trace filename without the .trace extension)
        - Values are function timings in milliseconds

    """
    # Create a mapping of (filename, function_name, class_name) -> qualified_name for efficient lookups
    function_lookup = {}
    function_benchmark_timings = {}

    for func in all_functions_to_optimize:
        qualified_name = func.qualified_name_with_file_name

        # Extract the components used for matching against trace rows
        filename = func.file_path
        function_name = func.function_name

        # Get the class name if there's a parent
        class_name = func.parents[0].name if func.parents else None

        # Store in the lookup dictionary
        key = (filename, function_name, class_name)
        function_lookup[key] = qualified_name
        function_benchmark_timings[qualified_name] = {}

    # Find all .trace files in the directory
    trace_files = list(trace_dir.glob("*.trace"))

    for trace_file in trace_files:
        # Extract the benchmark name from the filename (without the .trace extension)
        benchmark_name = trace_file.stem

        # Connect to the trace database
        conn = sqlite3.connect(trace_file)
        cursor = conn.cursor()

        # For each function we're interested in, query the database directly
        for (filename, function_name, class_name), qualified_name in function_lookup.items():
            # Adjust the query based on whether we have a class name
            if class_name:
                cursor.execute(
                    "SELECT total_time_ns FROM pstats WHERE filename LIKE ? AND function = ? AND class_name = ?",
                    (f"%{filename}", function_name, class_name)
                )
            else:
                cursor.execute(
                    "SELECT total_time_ns FROM pstats WHERE filename LIKE ? AND function = ? AND (class_name IS NULL OR class_name = '')",
                    (f"%{filename}", function_name)
                )

            result = cursor.fetchone()
            if result:
                time_ns = result[0]
                function_benchmark_timings[qualified_name][benchmark_name] = time_ns / 1e6  # Convert to milliseconds

        conn.close()

    return function_benchmark_timings


def get_benchmark_timings(trace_dir: Path) -> dict[str, float]:
    """Extract total benchmark timings from trace files.

    Args:
        trace_dir: Path to the directory containing .trace files

    Returns:
        A dictionary mapping benchmark names to their total execution time in milliseconds.

    """
    benchmark_timings = {}

    # Find all .trace files in the directory
    trace_files = list(trace_dir.glob("*.trace"))

    for trace_file in trace_files:
        # Extract the benchmark name from the filename (without the .trace extension)
        benchmark_name = trace_file.stem

        # Connect to the trace database
        conn = sqlite3.connect(trace_file)
        cursor = conn.cursor()

        # Query the total_time table for the benchmark's total execution time
        try:
            cursor.execute("SELECT time_ns FROM total_time")
            result = cursor.fetchone()
            if result:
                time_ns = result[0]
                # Convert nanoseconds to milliseconds
                benchmark_timings[benchmark_name] = time_ns / 1e6
        except sqlite3.OperationalError:
            # Handle the case where the total_time table might not exist
            print(f"Warning: Could not get total time for benchmark {benchmark_name}")

        conn.close()

    return benchmark_timings
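For orientation, a hedged sketch of the shapes these two helpers return (the path, benchmark name, and timings below are hypothetical):

# Illustrative only — not produced by this commit:
function_benchmark_timings = {
    "code_to_optimize/process_and_bubble_sort.py:compute_and_sort": {
        "test_compute_and_sort": 812.5,  # ms spent inside the function during that benchmark
    },
}
total_benchmark_timings = {"test_compute_and_sort": 950.0}  # ms for the whole benchmark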
codeflash/benchmarking/plugin/__init__.py (new file, 0 lines)
codeflash/benchmarking/plugin/plugin.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import pytest

from codeflash.tracer import Tracer
from pathlib import Path

class CodeFlashPlugin:
    @staticmethod
    def pytest_addoption(parser):
        parser.addoption(
            "--codeflash-trace",
            action="store_true",
            default=False,
            help="Enable CodeFlash tracing"
        )
        parser.addoption(
            "--functions",
            action="store",
            default="",
            help="Comma-separated list of additional functions to trace"
        )
        parser.addoption(
            "--benchmarks-root",
            action="store",
            default=".",
            help="Root directory for benchmarks"
        )

    @staticmethod
    def pytest_plugin_registered(plugin, manager):
        if hasattr(plugin, "name") and plugin.name == "pytest-benchmark":
            manager.unregister(plugin)

    @staticmethod
    def pytest_collection_modifyitems(config, items):
        if not config.getoption("--codeflash-trace"):
            return

        skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
        for item in items:
            if hasattr(item, "fixturenames") and "benchmark" in item.fixturenames:
                continue
            item.add_marker(skip_no_benchmark)

    @staticmethod
    @pytest.fixture
    def benchmark(request):
        if not request.config.getoption("--codeflash-trace"):
            return None

        class Benchmark:
            def __call__(self, func, *args, **kwargs):
                func_name = func.__name__
                test_name = request.node.name
                additional_functions = request.config.getoption("--functions").split(",")
                trace_functions = [f for f in additional_functions if f]
                print("Tracing functions: ", trace_functions)

                # Get the benchmarks root directory from the command-line option
                benchmarks_root = Path(request.config.getoption("--benchmarks-root"))

                # Create the .codeflash_trace directory if it doesn't exist
                trace_dir = benchmarks_root / '.codeflash_trace'
                trace_dir.mkdir(exist_ok=True)

                # Set the output path inside the trace directory
                output_path = trace_dir / f"{test_name}.trace"

                tracer = Tracer(
                    output=str(output_path),  # Convert Path to string for Tracer
                    functions=trace_functions,
                    max_function_count=256
                )

                with tracer:
                    result = func(*args, **kwargs)

                return result

        return Benchmark()
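Net effect per benchmark test, as a hedged sketch (paths hypothetical): tests without the benchmark fixture are skipped, and each benchmark test leaves one SQLite trace behind.

# Illustrative only:
from pathlib import Path

trace_dir = Path("benchmarks") / ".codeflash_trace"
expected_trace = trace_dir / "test_sort.trace"  # one trace DB per benchmark test; read later by get_trace_info.py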
codeflash/benchmarking/pytest_new_process_trace_benchmarks.py (new file, 15 lines)
@@ -0,0 +1,15 @@
import sys
from plugin.plugin import CodeFlashPlugin

benchmarks_root = sys.argv[1]
function_list = sys.argv[2]
if __name__ == "__main__":
    import pytest

    try:
        exitcode = pytest.main(
            [benchmarks_root, "--benchmarks-root", benchmarks_root, "--codeflash-trace", "-p", "no:benchmark", "-s", "--functions", function_list], plugins=[CodeFlashPlugin()]
        )
    except Exception as e:
        print(f"Failed to collect tests: {e!s}")
        exitcode = -1
codeflash/benchmarking/trace_benchmarks.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from __future__ import annotations
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
from pathlib import Path
import subprocess

def trace_benchmarks_pytest(benchmarks_root: Path, project_root: Path, function_list: list[str] = []) -> None:
    result = subprocess.run(
        [
            SAFE_SYS_EXECUTABLE,
            Path(__file__).parent / "pytest_new_process_trace_benchmarks.py",
            str(benchmarks_root),
            ",".join(function_list)
        ],
        cwd=project_root,
        check=False,
        capture_output=True,
        text=True,
    )
    print("stdout:", result.stdout)
    print("stderr:", result.stderr)
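A hedged usage sketch of the new entry point (paths and the function key are hypothetical):

# Illustrative only:
from pathlib import Path
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest

trace_benchmarks_pytest(
    benchmarks_root=Path("benchmarks"),
    project_root=Path("."),
    function_list=["code_to_optimize/process_and_bubble_sort.py:compute_and_sort"],
)
# Traces land in benchmarks/.codeflash_trace/<test_name>.trace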
@@ -62,6 +62,10 @@ def parse_args() -> Namespace:
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="Print verbose debug logs")
    parser.add_argument("--version", action="store_true", help="Print the version of codeflash")
    parser.add_argument("--benchmark", action="store_true", help="Trace benchmark tests and calculate optimization impact on benchmarks")
    parser.add_argument(
        "--benchmarks-root", type=str, help="Path to the directory containing the project's pytest-benchmark tests."
    )
    args: Namespace = parser.parse_args()
    return process_and_validate_cmd_args(args)
@@ -116,6 +120,7 @@ def process_pyproject_config(args: Namespace) -> Namespace:
        "disable_telemetry",
        "disable_imports_sorting",
        "git_remote",
        "benchmarks_root"
    ]
    for key in supported_keys:
        if key in pyproject_config and (
@@ -127,23 +132,17 @@ def process_pyproject_config(args: Namespace) -> Namespace:
    assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory"
    assert args.tests_root is not None, "--tests-root must be specified"
    assert Path(args.tests_root).is_dir(), f"--tests-root {args.tests_root} must be a valid directory"

    if env_utils.get_pr_number() is not None:
        assert env_utils.ensure_codeflash_api_key(), (
            "Codeflash API key not found. When running in a GitHub Actions context, provide the "
            "'CODEFLASH_API_KEY' environment variable as a secret.\n"
            "You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n"
            "Then, click 'New repository secret' and add your api key with the variable name CODEFLASH_API_KEY.\n"
            f"Here's a direct link: {get_github_secrets_page_url()}\n"
            "Exiting..."
        )

        repo = git.Repo(search_parent_directories=True)

        owner, repo_name = get_repo_owner_and_name(repo)

        require_github_app_or_exit(owner, repo_name)

    if args.benchmark:
        assert args.benchmarks_root is not None, "--benchmarks-root must be specified when running with --benchmark"
        assert Path(args.benchmarks_root).is_dir(), f"--benchmarks-root {args.benchmarks_root} must be a valid directory"
    assert not (env_utils.get_pr_number() is not None and not env_utils.ensure_codeflash_api_key()), (
        "Codeflash API key not found. When running in a GitHub Actions context, provide the "
        "'CODEFLASH_API_KEY' environment variable as a secret.\n"
        "You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n"
        "Then, click 'New repository secret' and add your api key with the variable name CODEFLASH_API_KEY.\n"
        f"Here's a direct link: {get_github_secrets_page_url()}\n"
        "Exiting..."
    )
    if hasattr(args, "ignore_paths") and args.ignore_paths is not None:
        normalized_ignore_paths = []
        for path in args.ignore_paths:
@@ -107,8 +107,6 @@ def discover_tests_pytest(
            test_type = TestType.REPLAY_TEST
        elif "test_concolic_coverage" in test["test_file"]:
            test_type = TestType.CONCOLIC_COVERAGE_TEST
        elif test["test_type"] == "benchmark":  # New condition for benchmark tests
            test_type = TestType.BENCHMARK_TEST
        else:
            test_type = TestType.EXISTING_UNIT_TEST
@@ -121,7 +121,6 @@ class FunctionToOptimize:
    method extends this with the module name from the project root.

    """

    function_name: str
    file_path: Path
    parents: list[FunctionParent]  # list[ClassDef | FunctionDef | AsyncFunctionDef]
@@ -145,6 +144,11 @@ class FunctionToOptimize:
    def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
        return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"

    @property
    def qualified_name_with_file_name(self) -> str:
        class_name = self.parents[0].name if self.parents else None
        return f"{self.file_path}:{(class_name + ':' if class_name else '')}{self.function_name}"


def get_functions_to_optimize(
    optimize_all: str | None,
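A hedged illustration of the key format this property produces (paths and names hypothetical), since it is what joins a function to its rows in the trace databases:

# Illustrative only:
# top-level function -> "code_to_optimize/process_and_bubble_sort.py:compute_and_sort"
# method             -> "src/models.py:MyClass:my_method"  (class name inserted between file and function)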
@@ -0,0 +1,54 @@
import sys
from typing import Any

# This script should not have any relation to the codeflash package, be careful with imports
cwd = sys.argv[1]
tests_root = sys.argv[2]
pickle_path = sys.argv[3]
collected_tests = []
pytest_rootdir = None
sys.path.insert(1, str(cwd))


class PytestCollectionPlugin:
    def pytest_collection_finish(self, session) -> None:
        global pytest_rootdir
        collected_tests.extend(session.items)
        pytest_rootdir = session.config.rootdir


def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, str]]:
    test_results = []
    for test in pytest_tests:
        test_class = None
        if test.cls:
            test_class = test.parent.name

        # Determine if this is a benchmark test by checking for the benchmark fixture
        is_benchmark = hasattr(test, 'fixturenames') and 'benchmark' in test.fixturenames
        test_type = 'benchmark' if is_benchmark else 'regular'

        test_results.append({
            "test_file": str(test.path),
            "test_class": test_class,
            "test_function": test.name,
            "test_type": test_type
        })
    return test_results


if __name__ == "__main__":
    import pytest

    try:
        exitcode = pytest.main(
            [tests_root, "-pno:logging", "--collect-only", "-m", "not skip"], plugins=[PytestCollectionPlugin()]
        )
    except Exception as e:
        print(f"Failed to collect tests: {e!s}")
        exitcode = -1
    tests = parse_pytest_collection_results(collected_tests)
    import pickle

    with open(pickle_path, "wb") as f:
        pickle.dump((exitcode, tests, pytest_rootdir), f, protocol=pickle.HIGHEST_PROTOCOL)
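One collected record, as a hedged example (values hypothetical):

# Illustrative only:
example_record = {
    "test_file": "benchmarks/test_benchmark_bubble_sort.py",
    "test_class": None,
    "test_function": "test_sort",
    "test_type": "benchmark",  # "regular" when the test does not request the benchmark fixture
}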
@@ -29,17 +29,7 @@ def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, str]]:
        test_class = None
        if test.cls:
            test_class = test.parent.name

        # Determine if this is a benchmark test by checking for the benchmark fixture
        is_benchmark = hasattr(test, 'fixturenames') and 'benchmark' in test.fixturenames
        test_type = 'benchmark' if is_benchmark else 'regular'

        test_results.append({
            "test_file": str(test.path),
            "test_class": test_class,
            "test_function": test.name,
            "test_type": test_type
        })
        test_results.append({"test_file": str(test.path), "test_class": test_class, "test_function": test.name})
    return test_results
@@ -7,7 +7,7 @@ import shutil
import subprocess
import time
import uuid
from collections import defaultdict, deque
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING
@@ -21,12 +21,12 @@ from rich.tree import Tree
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.code_utils.code_utils import (
    cleanup_paths,
    file_name_from_test_module_name,
    get_run_tmp_file,
    has_any_async_functions,
    module_name_from_file_path,
)
from codeflash.code_utils.config_consts import (
@@ -37,7 +37,6 @@ from codeflash.code_utils.config_consts import (
)
from codeflash.code_utils.formatter import format_code, sort_imports
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
from codeflash.code_utils.line_profile_utils import add_decorator_imports
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
from codeflash.code_utils.time_utils import humanize_runtime
@@ -49,6 +48,7 @@ from codeflash.models.models import (
    BestOptimization,
    CodeOptimizationContext,
    FunctionCalledInTest,
    FunctionParent,
    GeneratedTests,
    GeneratedTestsList,
    OptimizationSet,
@@ -57,9 +57,8 @@ from codeflash.models.models import (
    TestFile,
    TestFiles,
    TestingMode,
    TestResults,
    TestType,
)
from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
from codeflash.result.explanation import Explanation
@@ -67,15 +66,18 @@ from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.concolic_testing import generate_concolic_tests
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
from codeflash.verification.parse_test_output import parse_test_results
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_line_profile_tests
from codeflash.verification.test_results import TestResults, TestType
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests
from codeflash.verification.verification_utils import get_test_file_path
from codeflash.verification.verifier import generate_tests

if TYPE_CHECKING:
    from argparse import Namespace

    import numpy as np
    import numpy.typing as npt

    from codeflash.either import Result
    from codeflash.models.models import CoverageData, FunctionSource, OptimizedCandidate
    from codeflash.verification.verification_utils import TestConfig
@@ -90,6 +92,8 @@ class FunctionOptimizer:
        function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
        function_to_optimize_ast: ast.FunctionDef | None = None,
        aiservice_client: AiServiceClient | None = None,
        function_benchmark_timings: dict[str, dict[str, float]] | None = None,
        total_benchmark_timings: dict[str, float] | None = None,
        args: Namespace | None = None,
    ) -> None:
        self.project_root = test_cfg.project_root_path
@@ -118,6 +122,9 @@ class FunctionOptimizer:
        self.function_trace_id: str = str(uuid.uuid4())
        self.original_module_path = module_name_from_file_path(self.function_to_optimize.file_path, self.project_root)

        self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {}
        self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {}

    def optimize_function(self) -> Result[BestOptimization, str]:
        should_run_experiment = self.experiment_id is not None
        logger.debug(f"Function Trace ID: {self.function_trace_id}")
@@ -134,10 +141,19 @@ class FunctionOptimizer:
                with helper_function_path.open(encoding="utf8") as f:
                    helper_code = f.read()
                    original_helper_code[helper_function_path] = helper_code
        if has_any_async_functions(code_context.read_writable_code):
            return Failure("Codeflash does not support async functions in the code to optimize.")

        logger.info("Code to be optimized:")
        code_print(code_context.read_writable_code)

        for module_abspath, helper_code_source in original_helper_code.items():
            code_context.code_to_optimize_with_helpers = add_needed_imports_from_module(
                helper_code_source,
                code_context.code_to_optimize_with_helpers,
                module_abspath,
                self.function_to_optimize.file_path,
                self.args.project_root,
            )

        generated_test_paths = [
            get_test_file_path(
                self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="unit"
@@ -156,7 +172,7 @@ class FunctionOptimizer:
            transient=True,
        ):
            generated_results = self.generate_tests_and_optimizations(
                testgen_context_code=code_context.testgen_context_code,
                code_to_optimize_with_helpers=code_context.code_to_optimize_with_helpers,
                read_writable_code=code_context.read_writable_code,
                read_only_context_code=code_context.read_only_context_code,
                helper_functions=code_context.helper_functions,
@@ -232,11 +248,10 @@ class FunctionOptimizer:
        ):
            cleanup_paths(paths_to_cleanup)
            return Failure("The threshold for test coverage was not met.")
        # Request new optimizations but don't block execution; check for completion later.
        # Adding to both the control and experiment sets, with the same trace id.

        best_optimization = None

        for _u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]):
        for u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]):
            if candidates is None:
                continue
@@ -270,6 +285,20 @@ class FunctionOptimizer:
            function_name=function_to_optimize_qualified_name,
            file_path=self.function_to_optimize.file_path,
        )
        speedup = explanation.speedup  # e.g. 1.2 means 1.2x faster
        if self.args.benchmark:
            fto_benchmark_timings = self.function_benchmark_timings[self.function_to_optimize.qualified_name_with_file_name]
            for benchmark_name, og_benchmark_timing in fto_benchmark_timings.items():
                print(f"Calculating speedup for benchmark {benchmark_name}")
                total_benchmark_timing = self.total_benchmark_timings[benchmark_name]
                # Compute the expected new benchmark timing, then how much the whole benchmark is sped up; print intermediate values.
                expected_new_benchmark_timing = total_benchmark_timing - og_benchmark_timing + og_benchmark_timing / speedup
                print(f"Expected new benchmark timing: {expected_new_benchmark_timing}")
                print(f"Original benchmark timing: {total_benchmark_timing}")
                print(f"Benchmark speedup: {total_benchmark_timing / expected_new_benchmark_timing}")

                # Use a separate name: reusing `speedup` here would clobber the function-level
                # speedup and corrupt the projection on later loop iterations.
                benchmark_speedup = total_benchmark_timing / expected_new_benchmark_timing
                print(f"Speedup: {benchmark_speedup}")

        self.log_successful_optimization(explanation, generated_tests)
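A worked example of the projection above (numbers hypothetical): if a benchmark takes 1000 ms in total, 400 ms of which is spent in the optimized function, and the function alone got 2x faster, the benchmark as a whole is projected to speed up 1.25x.

# Illustrative arithmetic only:
total_benchmark_timing = 1000.0  # ms, whole benchmark
og_benchmark_timing = 400.0      # ms spent in the optimized function
speedup = 2.0                    # function-level speedup

expected_new_benchmark_timing = total_benchmark_timing - og_benchmark_timing + og_benchmark_timing / speedup
assert expected_new_benchmark_timing == 800.0
assert total_benchmark_timing / expected_new_benchmark_timing == 1.25  # projected whole-benchmark speedup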
@@ -359,123 +388,94 @@ class FunctionOptimizer:
            f"{self.function_to_optimize.qualified_name}…"
        )
        console.rule()
        candidates = deque(candidates)
        # Start a new thread for the AI service request; run the loop in the main thread.
        # When the AI service request completes, append its results to the candidates deque.
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
            future_line_profile_results = executor.submit(
                self.aiservice_client.optimize_python_code_line_profiler,
                source_code=code_context.read_writable_code,
                dependency_code=code_context.read_only_context_code,
                trace_id=self.function_trace_id,
                line_profiler_results=original_code_baseline.line_profile_results["str_out"],
                num_candidates=10,
                experiment_metadata=None,
            )
            try:
                candidate_index = 0
                done = False
                original_len = len(candidates)
                while candidates:
                    # for candidate_index, candidate in enumerate(candidates, start=1):
                    done = True if future_line_profile_results is None else future_line_profile_results.done()
                    if done and (future_line_profile_results is not None):
                        line_profile_results = future_line_profile_results.result()
                        candidates.extend(line_profile_results)
                        original_len += len(line_profile_results)
                        logger.info(f"Added {len(line_profile_results)} results from line profiler to candidates, total candidates now: {original_len}")
                        future_line_profile_results = None
                    candidate_index += 1
                    candidate = candidates.popleft()
                    get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
                    get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
                    logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
                    code_print(candidate.source_code)
                    try:
                        did_update = self.replace_function_and_helpers_with_optimized_code(
                            code_context=code_context, optimized_code=candidate.source_code
                        )
                        if not did_update:
                            logger.warning(
                                "No functions were replaced in the optimized code. Skipping optimization candidate."
                            )
                            console.rule()
                            continue
                    except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
                        logger.error(e)
                        self.write_code_and_helpers(
                            self.function_to_optimize_source_code,
                            original_helper_code,
                            self.function_to_optimize.file_path,
                        )
                        continue

                    # Instrument codeflash capture
                    run_results = self.run_optimized_candidate(
                        optimization_candidate_index=candidate_index,
                        baseline_results=original_code_baseline,
                        original_helper_code=original_helper_code,
                        file_path_to_helper_classes=file_path_to_helper_classes,
        try:
            for candidate_index, candidate in enumerate(candidates, start=1):
                get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
                get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
                logger.info(f"Optimization candidate {candidate_index}/{len(candidates)}:")
                code_print(candidate.source_code)
                try:
                    did_update = self.replace_function_and_helpers_with_optimized_code(
                        code_context=code_context, optimized_code=candidate.source_code
                    )
                    console.rule()

                    if not is_successful(run_results):
                        optimized_runtimes[candidate.optimization_id] = None
                        is_correct[candidate.optimization_id] = False
                        speedup_ratios[candidate.optimization_id] = None
                    else:
                        candidate_result: OptimizedCandidateResult = run_results.unwrap()
                        best_test_runtime = candidate_result.best_test_runtime
                        optimized_runtimes[candidate.optimization_id] = best_test_runtime
                        is_correct[candidate.optimization_id] = True
                        perf_gain = performance_gain(
                            original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_test_runtime
                    if not did_update:
                        logger.warning(
                            "No functions were replaced in the optimized code. Skipping optimization candidate."
                        )
                        speedup_ratios[candidate.optimization_id] = perf_gain

                        tree = Tree(f"Candidate #{candidate_index} - Runtime Information")
                        if speedup_critic(
                            candidate_result, original_code_baseline.runtime, best_runtime_until_now
                        ) and quantity_of_tests_critic(candidate_result):
                            tree.add("This candidate is faster than the previous best candidate. 🚀")
                            tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")
                            tree.add(
                                f"Best summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} "
                                f"(measured over {candidate_result.max_loop_count} "
                                f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
                            )
                            tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
                            tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X")

                            best_optimization = BestOptimization(
                                candidate=candidate,
                                helper_functions=code_context.helper_functions,
                                runtime=best_test_runtime,
                                winning_behavioral_test_results=candidate_result.behavior_test_results,
                                winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
                            )
                            best_runtime_until_now = best_test_runtime
                        else:
                            tree.add(
                                f"Summed runtime: {humanize_runtime(best_test_runtime)} "
                                f"(measured over {candidate_result.max_loop_count} "
                                f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
                            )
                            tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
                            tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
                        console.print(tree)
                        console.rule()

                        continue
                except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
                    logger.error(e)
                    self.write_code_and_helpers(
                        self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
                    )
                    continue

                # Instrument codeflash capture
                run_results = self.run_optimized_candidate(
                    optimization_candidate_index=candidate_index,
                    baseline_results=original_code_baseline,
                    original_helper_code=original_helper_code,
                    file_path_to_helper_classes=file_path_to_helper_classes,
                )
                console.rule()

                if not is_successful(run_results):
                    optimized_runtimes[candidate.optimization_id] = None
                    is_correct[candidate.optimization_id] = False
                    speedup_ratios[candidate.optimization_id] = None
                else:
                    candidate_result: OptimizedCandidateResult = run_results.unwrap()
                    best_test_runtime = candidate_result.best_test_runtime
                    optimized_runtimes[candidate.optimization_id] = best_test_runtime
                    is_correct[candidate.optimization_id] = True
                    perf_gain = performance_gain(
                        original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_test_runtime
                    )
                    speedup_ratios[candidate.optimization_id] = perf_gain

                    tree = Tree(f"Candidate #{candidate_index} - Runtime Information")
                    if speedup_critic(
                        candidate_result, original_code_baseline.runtime, best_runtime_until_now
                    ) and quantity_of_tests_critic(candidate_result):
                        tree.add("This candidate is faster than the previous best candidate. 🚀")
                        tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")
                        tree.add(
                            f"Best summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} "
                            f"(measured over {candidate_result.max_loop_count} "
                            f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
                        )
                        tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
                        tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X")

                        best_optimization = BestOptimization(
                            candidate=candidate,
                            helper_functions=code_context.helper_functions,
                            runtime=best_test_runtime,
                            winning_behavioral_test_results=candidate_result.behavior_test_results,
                            winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
                        )
                        best_runtime_until_now = best_test_runtime
                    else:
                        tree.add(
                            f"Summed runtime: {humanize_runtime(best_test_runtime)} "
                            f"(measured over {candidate_result.max_loop_count} "
                            f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
                        )
                        tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
                        tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
                    console.print(tree)
                    console.rule()

            except KeyboardInterrupt as e:
                self.write_code_and_helpers(
                    self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
                )
                logger.exception(f"Optimization interrupted: {e}")
                raise
        except KeyboardInterrupt as e:
            self.write_code_and_helpers(
                self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
            )
            logger.exception(f"Optimization interrupted: {e}")
            raise

        self.aiservice_client.log_results(
            function_trace_id=self.function_trace_id,
@@ -575,6 +575,50 @@ class FunctionOptimizer:
        return did_update

    def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
        code_to_optimize, contextual_dunder_methods = extract_code([self.function_to_optimize])
        if code_to_optimize is None:
            return Failure("Could not find function to optimize.")
        (helper_code, helper_functions, helper_dunder_methods) = get_constrained_function_context_and_helper_functions(
            self.function_to_optimize, self.project_root, code_to_optimize
        )
        if self.function_to_optimize.parents:
            function_class = self.function_to_optimize.parents[0].name
            same_class_helper_methods = [
                df
                for df in helper_functions
                if df.qualified_name.count(".") > 0 and df.qualified_name.split(".")[0] == function_class
            ]
            optimizable_methods = [
                FunctionToOptimize(
                    df.qualified_name.split(".")[-1],
                    df.file_path,
                    [FunctionParent(df.qualified_name.split(".")[0], "ClassDef")],
                    None,
                    None,
                )
                for df in same_class_helper_methods
            ] + [self.function_to_optimize]
            dedup_optimizable_methods = []
            added_methods = set()
            for method in reversed(optimizable_methods):
                if f"{method.file_path}.{method.qualified_name}" not in added_methods:
                    dedup_optimizable_methods.append(method)
                    added_methods.add(f"{method.file_path}.{method.qualified_name}")
            if len(dedup_optimizable_methods) > 1:
                code_to_optimize, contextual_dunder_methods = extract_code(list(reversed(dedup_optimizable_methods)))
                if code_to_optimize is None:
                    return Failure("Could not find function to optimize.")
        code_to_optimize_with_helpers = helper_code + "\n" + code_to_optimize

        code_to_optimize_with_helpers_and_imports = add_needed_imports_from_module(
            self.function_to_optimize_source_code,
            code_to_optimize_with_helpers,
            self.function_to_optimize.file_path,
            self.function_to_optimize.file_path,
            self.project_root,
            helper_functions,
        )

        try:
            new_code_ctx = code_context_extractor.get_code_optimization_context(
                self.function_to_optimize, self.project_root
@@ -584,7 +628,7 @@ class FunctionOptimizer:

        return Success(
            CodeOptimizationContext(
                testgen_context_code=new_code_ctx.testgen_context_code,
                code_to_optimize_with_helpers=code_to_optimize_with_helpers_and_imports,
                read_writable_code=new_code_ctx.read_writable_code,
                read_only_context_code=new_code_ctx.read_only_context_code,
                helper_functions=new_code_ctx.helper_functions,  # only functions that are read-writable
@@ -686,7 +730,7 @@ class FunctionOptimizer:

    def generate_tests_and_optimizations(
        self,
        testgen_context_code: str,
        code_to_optimize_with_helpers: str,
        read_writable_code: str,
        read_only_context_code: str,
        helper_functions: list[FunctionSource],
|
@ -701,7 +745,7 @@ class FunctionOptimizer:
|
|||
# Submit the test generation task as future
|
||||
future_tests = self.generate_and_instrument_tests(
|
||||
executor,
|
||||
testgen_context_code,
|
||||
code_to_optimize_with_helpers,
|
||||
[definition.fully_qualified_name for definition in helper_functions],
|
||||
generated_test_paths,
|
||||
generated_perf_test_paths,
|
||||
|
|
@@ -790,7 +834,6 @@ class FunctionOptimizer:
        original_helper_code: dict[Path, str],
        file_path_to_helper_classes: dict[Path, set[str]],
    ) -> Result[tuple[OriginalCodeBaseline, list[str]], str]:
        line_profile_results = {"timings": {}, "unit": 0, "str_out": ""}
        # For the original function - run the tests and get the runtime, plus coverage
        with progress_bar(f"Establishing original code baseline for {self.function_to_optimize.function_name}"):
            assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
@@ -831,31 +874,11 @@ class FunctionOptimizer:
                )
                console.rule()
                return Failure("Failed to establish a baseline for the original code - behavioral tests failed.")
            if not coverage_critic(coverage_results, self.args.test_framework):
            if not coverage_critic(
                coverage_results, self.args.test_framework
            ):
                return Failure("The threshold for test coverage was not met.")
            if test_framework == "pytest":
                try:
                    line_profiler_output_file = add_decorator_imports(self.function_to_optimize, code_context)
                    line_profile_results, _ = self.run_and_parse_tests(
                        testing_type=TestingMode.LINE_PROFILE,
                        test_env=test_env,
                        test_files=self.test_files,
                        optimization_iteration=0,
                        testing_time=TOTAL_LOOPING_TIME,
                        enable_coverage=False,
                        code_context=code_context,
                        line_profiler_output_file=line_profiler_output_file,
                    )
                finally:
                    # Remove codeflash capture
                    self.write_code_and_helpers(
                        self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
                    )
                if line_profile_results["str_out"] == "":
                    logger.warning(
                        f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
                    )
                    console.rule()
                benchmarking_results, _ = self.run_and_parse_tests(
                    testing_type=TestingMode.PERFORMANCE,
                    test_env=test_env,
@@ -894,6 +917,7 @@ class FunctionOptimizer:
                )
                console.rule()


            total_timing = benchmarking_results.total_passed_runtime()  # caution: doesn't handle the loop index
            functions_to_remove = [
                result.id.test_function_name
@@ -927,7 +951,6 @@ class FunctionOptimizer:
                    benchmarking_test_results=benchmarking_results,
                    runtime=total_timing,
                    coverage_results=coverage_results,
                    line_profile_results=line_profile_results,
                ),
                functions_to_remove,
            )
@@ -1063,77 +1086,59 @@ class FunctionOptimizer:
        pytest_max_loops: int = 100_000,
        code_context: CodeOptimizationContext | None = None,
        unittest_loop_index: int | None = None,
        line_profiler_output_file: Path | None = None,
    ) -> tuple[TestResults | dict, CoverageData | None]:
    ) -> tuple[TestResults, CoverageData | None]:
        coverage_database_file = None
        coverage_config_file = None
        try:
            if testing_type == TestingMode.BEHAVIOR:
                result_file_path, run_result, coverage_database_file, coverage_config_file = run_behavioral_tests(
                result_file_path, run_result, coverage_database_file = run_behavioral_tests(
                    test_files,
                    test_framework=self.test_cfg.test_framework,
                    cwd=self.project_root,
                    test_env=test_env,
                    pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
                    pytest_cmd=self.test_cfg.pytest_cmd,
                    verbose=True,
                    enable_coverage=enable_coverage,
                )
            elif testing_type == TestingMode.LINE_PROFILE:
                result_file_path, run_result = run_line_profile_tests(
                    test_files,
                    cwd=self.project_root,
                    test_env=test_env,
                    pytest_cmd=self.test_cfg.pytest_cmd,
                    pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
                    pytest_target_runtime_seconds=testing_time,
                    pytest_min_loops=pytest_min_loops,
                    pytest_max_loops=pytest_min_loops,
                    test_framework=self.test_cfg.test_framework,
                    line_profiler_output_file=line_profiler_output_file,
                )
            elif testing_type == TestingMode.PERFORMANCE:
                result_file_path, run_result = run_benchmarking_tests(
                    test_files,
                    cwd=self.project_root,
                    test_env=test_env,
                    pytest_cmd=self.test_cfg.pytest_cmd,
                    pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
                    pytest_cmd=self.test_cfg.pytest_cmd,
                    pytest_target_runtime_seconds=testing_time,
                    pytest_min_loops=pytest_min_loops,
                    pytest_max_loops=pytest_max_loops,
                    test_framework=self.test_cfg.test_framework,
                )
            else:
                msg = f"Unexpected testing type: {testing_type}"
                raise ValueError(msg)
                raise ValueError(f"Unexpected testing type: {testing_type}")
        except subprocess.TimeoutExpired:
            logger.exception(
                f"Error running tests in {', '.join(str(f) for f in test_files.test_files)}.\nTimeout Error"
                f'Error running tests in {", ".join(str(f) for f in test_files.test_files)}.\nTimeout Error'
            )
            return TestResults(), None
        if run_result.returncode != 0 and testing_type == TestingMode.BEHAVIOR:
            logger.debug(
                f"Nonzero return code {run_result.returncode} when running tests in "
                f"{', '.join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n"
                f'Nonzero return code {run_result.returncode} when running tests in '
                f'{", ".join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n'
                f"stdout: {run_result.stdout}\n"
                f"stderr: {run_result.stderr}\n"
            )
        if testing_type in [TestingMode.BEHAVIOR, TestingMode.PERFORMANCE]:
            results, coverage_results = parse_test_results(
                test_xml_path=result_file_path,
                test_files=test_files,
                test_config=self.test_cfg,
                optimization_iteration=optimization_iteration,
                run_result=run_result,
                unittest_loop_index=unittest_loop_index,
                function_name=self.function_to_optimize.function_name,
                source_file=self.function_to_optimize.file_path,
                code_context=code_context,
                coverage_database_file=coverage_database_file,
                coverage_config_file=coverage_config_file,
            )
        else:
            results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
        # print(test_files)
        results, coverage_results = parse_test_results(
            test_xml_path=result_file_path,
            test_files=test_files,
            test_config=self.test_cfg,
            optimization_iteration=optimization_iteration,
            run_result=run_result,
            unittest_loop_index=unittest_loop_index,
            function_name=self.function_to_optimize.function_name,
            source_file=self.function_to_optimize.file_path,
            code_context=code_context,
            coverage_database_file=coverage_database_file,
        )
        return results, coverage_results

    def generate_and_instrument_tests(
@@ -1163,3 +1168,4 @@ class FunctionOptimizer:
                zip(generated_test_paths, generated_perf_test_paths)
            )
        ]
@@ -8,7 +8,8 @@ from pathlib import Path
from typing import TYPE_CHECKING

from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
from codeflash.cli_cmds.console import console, logger, progress_bar
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_replacer import normalize_code, normalize_node
from codeflash.code_utils.code_utils import get_run_tmp_file
@@ -16,10 +17,12 @@ from codeflash.code_utils.static_analysis import analyze_imported_modules, get_f
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
from codeflash.either import is_successful
from codeflash.models.models import TestType, ValidCode
from codeflash.models.models import TestFiles, ValidCode
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.test_results import TestType
from codeflash.verification.verification_utils import TestConfig
from codeflash.benchmarking.get_trace_info import get_function_benchmark_timings, get_benchmark_timings

if TYPE_CHECKING:
    from argparse import Namespace
@@ -50,6 +53,8 @@ class Optimizer:
        function_to_optimize_ast: ast.FunctionDef | None = None,
        function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
        function_to_optimize_source_code: str | None = "",
        function_benchmark_timings: dict[str, dict[str, float]] | None = None,
        total_benchmark_timings: dict[str, float] | None = None,
    ) -> FunctionOptimizer:
        return FunctionOptimizer(
            function_to_optimize=function_to_optimize,
@@ -59,6 +64,8 @@ class Optimizer:
            function_to_optimize_ast=function_to_optimize_ast,
            aiservice_client=self.aiservice_client,
            args=self.args,
            function_benchmark_timings=function_benchmark_timings if function_benchmark_timings else None,
            total_benchmark_timings=total_benchmark_timings if total_benchmark_timings else None,
        )

    def run(self) -> None:
@@ -80,6 +87,23 @@ class Optimizer:
            project_root=self.args.project_root,
            module_root=self.args.module_root,
        )
        if self.args.benchmark:
            all_functions_to_optimize = [
                function
                for functions_list in file_to_funcs_to_optimize.values()
                for function in functions_list
            ]
            logger.info(f"Tracing existing benchmarks for {len(all_functions_to_optimize)} functions")
            trace_benchmarks_pytest(self.args.benchmarks_root, self.args.project_root, [fto.qualified_name_with_file_name for fto in all_functions_to_optimize])
            logger.info("Finished tracing existing benchmarks")
            trace_dir = Path(self.args.benchmarks_root) / ".codeflash_trace"
            function_benchmark_timings = get_function_benchmark_timings(trace_dir, all_functions_to_optimize)
            print(function_benchmark_timings)
            total_benchmark_timings = get_benchmark_timings(trace_dir)
            print("Total benchmark timings:")
            print(total_benchmark_timings)
            # for function in fully_qualified_function_names:


        optimizations_found: int = 0
        function_iterator_count: int = 0
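For reference, a hedged sketch of the whole --benchmark pipeline wired up above (paths hypothetical; helper names as defined in this commit):

# Illustrative only:
from pathlib import Path

from codeflash.benchmarking.get_trace_info import get_benchmark_timings, get_function_benchmark_timings
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest

trace_benchmarks_pytest(Path("benchmarks"), Path("."), [])   # 1) write .codeflash_trace/*.trace
trace_dir = Path("benchmarks") / ".codeflash_trace"
per_function = get_function_benchmark_timings(trace_dir, [])  # 2) {function_key: {benchmark: ms}}
totals = get_benchmark_timings(trace_dir)                     # 3) {benchmark: total ms}
# 4) both dicts are handed to FunctionOptimizer, which projects per-benchmark speedups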
@@ -93,6 +117,8 @@ class Optimizer:
            logger.info("No functions found to optimize. Exiting…")
            return

        console.rule()
        logger.info(f"Discovering existing unit tests in {self.test_cfg.tests_root}…")
        console.rule()
        function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg)
        num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])
@@ -136,7 +162,6 @@ class Optimizer:
                    validated_original_code[analysis.file_path] = ValidCode(
                        source_code=callee_original_code, normalized_code=normalized_callee_original_code
                    )

                if has_syntax_error:
                    continue
@@ -146,7 +171,7 @@ class Optimizer:
                        f"Optimizing function {function_iterator_count} of {num_optimizable_functions}: "
                        f"{function_to_optimize.qualified_name}"
                    )
                    console.rule()

                    if not (
                        function_to_optimize_ast := get_first_top_level_function_or_method_ast(
                            function_to_optimize.function_name, function_to_optimize.parents, original_module_ast
@@ -157,12 +182,17 @@ class Optimizer:
                            f"Skipping optimization."
                        )
                        continue
                    function_optimizer = self.create_function_optimizer(
                        function_to_optimize,
                        function_to_optimize_ast,
                        function_to_tests,
                        validated_original_code[original_module_path].source_code,
                    )
                    if self.args.benchmark:

                        function_optimizer = self.create_function_optimizer(
                            function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code, function_benchmark_timings, total_benchmark_timings
                        )
                    else:
                        function_optimizer = self.create_function_optimizer(
                            function_to_optimize, function_to_optimize_ast, function_to_tests,
                            validated_original_code[original_module_path].source_code
                        )

                    best_optimization = function_optimizer.optimize_function()
                    if is_successful(best_optimization):
                        optimizations_found += 1
@@ -191,6 +221,7 @@ class Optimizer:
                get_run_tmp_file.tmpdir.cleanup()


def run_with_args(args: Namespace) -> None:
    optimizer = Optimizer(args)
    optimizer.run()
@@ -18,21 +18,19 @@ import marshal
import os
import pathlib
import pickle
import re
import sqlite3
import sys
import threading
import time
from argparse import ArgumentParser
from collections import defaultdict
from copy import copy
from io import StringIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ClassVar
from types import FrameType
from typing import Any, ClassVar, List

import dill
import isort
from rich.align import Align
from rich.panel import Panel
from rich.table import Table
from rich.text import Text

from codeflash.cli_cmds.cli import project_root_from_module_root
from codeflash.cli_cmds.console import console
@@ -42,34 +40,14 @@ from codeflash.discovery.functions_to_optimize import filter_files_optimized
from codeflash.tracing.replay_test import create_trace_replay_test
from codeflash.tracing.tracing_utils import FunctionModules
from codeflash.verification.verification_utils import get_test_file_path

if TYPE_CHECKING:
    from types import FrameType, TracebackType


class FakeCode:
    def __init__(self, filename: str, line: int, name: str) -> None:
        self.co_filename = filename
        self.co_line = line
        self.co_name = name
        self.co_firstlineno = 0

    def __repr__(self) -> str:
        return repr((self.co_filename, self.co_line, self.co_name, None))


class FakeFrame:
    def __init__(self, code: FakeCode, prior: FakeFrame | None) -> None:
        self.f_code = code
        self.f_back = prior
        self.f_locals: dict = {}

# import warnings
# warnings.filterwarnings("ignore", category=dill.PickleWarning)
# warnings.filterwarnings("ignore", category=DeprecationWarning)

# Debug this file by simply adding print statements. This file is not meant to be debugged by the debugger.
class Tracer:
    """Use this class as a 'with' context manager to trace a function call.

    Traces function calls, input arguments, and profiling info.
    """Use this class as a 'with' context manager to trace a function call,
    input arguments, and profiling info.
    """

    def __init__(
@@ -81,9 +59,7 @@ class Tracer:
        max_function_count: int = 256,
        timeout: int | None = None,  # seconds
    ) -> None:
        """Use this class to trace function calls.

        :param output: The path to the output trace file
        """:param output: The path to the output trace file
        :param functions: List of functions to trace. If None, trace all functions
        :param disable: Disable the tracer if True
        :param config_file_path: Path to the pyproject.toml file, if None then it will be auto-discovered
@@ -94,9 +70,7 @@ class Tracer:
        if functions is None:
            functions = []
        if os.environ.get("CODEFLASH_TRACER_DISABLE", "0") == "1":
            console.rule(
                "Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE", style="bold red"
            )
            console.print("Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE")
            disable = True
        self.disable = disable
        if self.disable:
@@ -111,7 +85,7 @@ class Tracer:
        self.con = None
        self.output_file = Path(output).resolve()
        self.functions = functions
        self.function_modules: list[FunctionModules] = []
        self.function_modules: List[FunctionModules] = []
        self.function_count = defaultdict(int)
        self.current_file_path = Path(__file__).resolve()
        self.ignored_qualified_functions = {
@@ -121,10 +95,10 @@ class Tracer:
        self.max_function_count = max_function_count
        self.config, found_config_path = parse_config_file(config_file_path)
        self.project_root = project_root_from_module_root(Path(self.config["module_root"]), found_config_path)
        console.rule(f"Project Root: {self.project_root}", style="bold blue")
        print("project_root", self.project_root)
        self.ignored_functions = {"<listcomp>", "<genexpr>", "<dictcomp>", "<setcomp>", "<lambda>", "<module>"}

        self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_")  # noqa: SLF001
        self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_")

        assert timeout is None or timeout > 0, "Timeout should be greater than 0"
        self.timeout = timeout
@ -145,44 +119,48 @@ class Tracer:
|
|||
def __enter__(self) -> None:
|
||||
if self.disable:
|
||||
return
|
||||
if getattr(Tracer, "used_once", False):
|
||||
console.print(
|
||||
"Codeflash: Tracer can only be used once per program run. "
|
||||
"Please only enable the Tracer once. Skipping tracing this section."
|
||||
)
|
||||
self.disable = True
|
||||
return
|
||||
Tracer.used_once = True
|
||||
|
||||
# if getattr(Tracer, "used_once", False):
|
||||
# console.print(
|
||||
# "Codeflash: Tracer can only be used once per program run. "
|
||||
# "Please only enable the Tracer once. Skipping tracing this section."
|
||||
# )
|
||||
# self.disable = True
|
||||
# return
|
||||
# Tracer.used_once = True
|
||||
|
||||
if pathlib.Path(self.output_file).exists():
|
||||
console.rule("Removing existing trace file", style="bold red")
|
||||
console.rule()
|
||||
console.print("Codeflash: Removing existing trace file")
|
||||
pathlib.Path(self.output_file).unlink(missing_ok=True)
|
||||
|
||||
self.con = sqlite3.connect(self.output_file, check_same_thread=False)
|
||||
self.con = sqlite3.connect(self.output_file)
|
||||
cur = self.con.cursor()
|
||||
cur.execute("""PRAGMA synchronous = OFF""")
|
||||
cur.execute("""PRAGMA journal_mode = WAL""")
|
||||
# TODO: Check out if we need to export the function test name as well
|
||||
cur.execute(
|
||||
"CREATE TABLE function_calls(type TEXT, function TEXT, classname TEXT, filename TEXT, "
|
||||
"line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)"
|
||||
)
|
||||
console.rule("Codeflash: Traced Program Output Begin", style="bold blue")
|
||||
frame = sys._getframe(0) # Get this frame and simulate a call to it # noqa: SLF001
|
||||
console.print("Codeflash: Tracing started!")
|
||||
frame = sys._getframe(0) # Get this frame and simulate a call to it
|
||||
self.dispatch["call"](self, frame, 0)
|
||||
self.start_time = time.time()
|
||||
sys.setprofile(self.trace_callback)
|
||||
threading.setprofile(self.trace_callback)
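
The core mechanism wired up here is a profile callback, registered on the main thread and, via threading.setprofile, on threads started afterwards, with every Python-level "call" event written to SQLite. A stripped-down sketch of that pipeline (simplified; the real Tracer also pickles arguments, filters by project root, and batches commits):

import sqlite3
import sys

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE calls(function TEXT, filename TEXT)")

def profile_hook(frame, event, arg):
    # Record only Python-level call events, mirroring tracer_logic.
    if event == "call":
        code = frame.f_code
        con.execute("INSERT INTO calls VALUES (?, ?)", (code.co_name, code.co_filename))

def workload():
    return sum(range(10))

sys.setprofile(profile_hook)
workload()
sys.setprofile(None)
print(con.execute("SELECT count(*) FROM calls").fetchone()[0])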

    def __exit__(
        self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
    ) -> None:
    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        if self.disable:
            return
        sys.setprofile(None)
        self.con.commit()
        console.rule("Codeflash: Traced Program Output End", style="bold blue")
        # Check if any functions were actually traced
        if self.trace_count == 0:
            self.con.close()
            # Delete the trace file if no functions were traced
            if self.output_file.exists():
                self.output_file.unlink()
            console.print("Codeflash: No functions were traced. Removing trace database.")
            return

        self.create_stats()

        cur = self.con.cursor()

@@ -226,13 +204,14 @@ class Tracer:
            test_framework=self.config["test_framework"],
            max_run_count=self.max_function_count,
        )
        function_path = "_".join(self.functions) if self.functions else self.file_being_called_from
        # Need a better way to store the replay test
        # function_path = "_".join(self.functions) if self.functions else self.file_being_called_from
        function_path = self.file_being_called_from
        test_file_path = get_test_file_path(
            test_dir=Path(self.config["tests_root"]), function_name=function_path, test_type="replay"
        )
        replay_test = isort.code(replay_test)

        with Path(test_file_path).open("w", encoding="utf8") as file:
        with open(test_file_path, "w", encoding="utf8") as file:
            file.write(replay_test)

        console.print(

@@ -242,27 +221,25 @@ class Tracer:
            overflow="ignore",
        )

    def tracer_logic(self, frame: FrameType, event: str) -> None:
    def tracer_logic(self, frame: FrameType, event: str):
        if event != "call":
            return
        if self.timeout is not None and (time.time() - self.start_time) > self.timeout:
            sys.setprofile(None)
            threading.setprofile(None)
            console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.")
            return
        if self.timeout is not None:
            if (time.time() - self.start_time) > self.timeout:
                sys.setprofile(None)
                console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.")
                return
        code = frame.f_code

        file_name = Path(code.co_filename).resolve()
        # TODO: It currently doesn't log the last return call from the first function

        if code.co_name in self.ignored_functions:
            return
        if not file_name.is_relative_to(self.project_root):
            return
        if not file_name.exists():
            return
        if self.functions and code.co_name not in self.functions:
            return
        # if self.functions:
        #     if code.co_name not in self.functions:
        #         return
        class_name = None
        arguments = frame.f_locals
        try:

@@ -274,12 +251,16 @@ class Tracer:
                class_name = arguments["self"].__class__.__name__
            elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
                class_name = arguments["cls"].__name__
        except:  # noqa: E722
        except:
            # someone can override the getattr method and raise an exception. I'm looking at you wrapt
            return

        function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}"
        if function_qualified_name in self.ignored_qualified_functions:
            return
        if self.functions and function_qualified_name not in self.functions:
            return

        if function_qualified_name not in self.function_count:
            # seeing this function for the first time
            self.function_count[function_qualified_name] = 0
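
The count kept here presumably feeds the per-function cap: once a function has been traced max_function_count times, further calls can be skipped so hot loops don't bloat the trace database. The gating pattern in isolation (MAX_COUNT mirrors the constructor default; should_trace is a made-up helper name):

from collections import defaultdict

MAX_COUNT = 256
function_count = defaultdict(int)

def should_trace(qualified_name):
    if function_count[qualified_name] >= MAX_COUNT:
        return False  # already have enough samples of this function
    function_count[qualified_name] += 1
    return True

assert should_trace("bubble_sort.py:sorter")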

@@ -354,14 +335,17 @@ class Tracer:
            self.next_insert = 1000
            self.con.commit()

    def trace_callback(self, frame: FrameType, event: str, arg: str | None) -> None:
    def trace_callback(self, frame: FrameType, event: str, arg: Any) -> None:
        # profiler section
        timer = self.timer
        t = timer() - self.t - self.bias
        if event == "c_call":
            self.c_func_name = arg.__name__

        prof_success = bool(self.dispatch[event](self, frame, t))
        if self.dispatch[event](self, frame, t):
            prof_success = True
        else:
            prof_success = False
        # tracer section
        self.tracer_logic(frame, event)
        # measure the time as the last thing before return
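
The profiler half of this callback follows cProfile's timing discipline: compute the elapsed time since the last event minus a calibration bias, dispatch to a handler, then reset the clock as the very last step so the callback's own overhead is not charged to user code. The pattern in miniature:

import time

class TimerState:
    def __init__(self):
        self.timer = time.perf_counter_ns
        self.bias = 0  # a calibrated constant in cProfile-style profilers
        self.t = self.timer()

    def on_event(self):
        t = self.timer() - self.t - self.bias  # time attributable to traced code
        # ... dispatch to the per-event handler here ...
        self.t = self.timer()  # restart the clock as the last thing we do
        return t

state = TimerState()
delta = state.on_event()
assert delta >= 0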

@@ -370,60 +354,45 @@ class Tracer:
        else:
            self.t = timer() - t  # put back unrecorded delta

    def trace_dispatch_call(self, frame: FrameType, t: int) -> int:
        """Handle call events in the profiler."""
    def trace_dispatch_call(self, frame, t):
        if self.cur and frame.f_back is not self.cur[-2]:
            rpt, rit, ret, rfn, rframe, rcur = self.cur
            if not isinstance(rframe, Tracer.fake_frame):
                assert rframe.f_back is frame.f_back, ("Bad call", rfn, rframe, rframe.f_back, frame, frame.f_back)
                self.trace_dispatch_return(rframe, 0)
        assert self.cur is None or frame.f_back is self.cur[-2], ("Bad call", self.cur[-3])
        fcode = frame.f_code
        arguments = frame.f_locals
        class_name = None
        try:
            # In multi-threaded contexts, we need to be more careful about frame comparisons
            if self.cur and frame.f_back is not self.cur[-2]:
                # This happens when we're in a different thread
                rpt, rit, ret, rfn, rframe, rcur = self.cur
            if (
                "self" in arguments
                and hasattr(arguments["self"], "__class__")
                and hasattr(arguments["self"].__class__, "__name__")
            ):
                class_name = arguments["self"].__class__.__name__
            elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
                class_name = arguments["cls"].__name__
        except:
            pass
        fn = (fcode.co_filename, fcode.co_firstlineno, fcode.co_name, class_name)
        self.cur = (t, 0, 0, fn, frame, self.cur)
        timings = self.timings
        if fn in timings:
            cc, ns, tt, ct, callers = timings[fn]
            timings[fn] = cc, ns + 1, tt, ct, callers
        else:
            timings[fn] = 0, 0, 0, 0, {}
        return 1

                # Only attempt to handle the frame mismatch if we have a valid rframe
                if (
                    not isinstance(rframe, FakeFrame)
                    and hasattr(rframe, "f_back")
                    and hasattr(frame, "f_back")
                    and rframe.f_back is frame.f_back
                ):
                    self.trace_dispatch_return(rframe, 0)

            # Get function information
            fcode = frame.f_code
            arguments = frame.f_locals
            class_name = None
            try:
                if (
                    "self" in arguments
                    and hasattr(arguments["self"], "__class__")
                    and hasattr(arguments["self"].__class__, "__name__")
                ):
                    class_name = arguments["self"].__class__.__name__
                elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
                    class_name = arguments["cls"].__name__
            except Exception:  # noqa: BLE001, S110
                pass

            fn = (fcode.co_filename, fcode.co_firstlineno, fcode.co_name, class_name)
            self.cur = (t, 0, 0, fn, frame, self.cur)
            timings = self.timings
            if fn in timings:
                cc, ns, tt, ct, callers = timings[fn]
                timings[fn] = cc, ns + 1, tt, ct, callers
            else:
                timings[fn] = 0, 0, 0, 0, {}
            return 1  # noqa: TRY300
        except Exception:  # noqa: BLE001
            # Handle any errors gracefully
            return 0
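
Both variants above maintain cProfile's classic per-function record: a 5-tuple (cc, ns, tt, ct, callers) of completed primitive calls, current recursion depth on the stack, internal time, cumulative time, and per-caller counts. How a single call event updates it, in isolation:

# fn is keyed the same way as above: (filename, first line, name, class).
timings = {}
fn = ("bubble_sort.py", 1, "sorter", None)

cc, ns, tt, ct, callers = timings.get(fn, (0, 0, 0, 0, {}))
timings[fn] = (cc, ns + 1, tt, ct, callers)  # one more active occurrence on the stack
print(timings[fn])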

    def trace_dispatch_exception(self, frame: FrameType, t: int) -> int:
    def trace_dispatch_exception(self, frame, t):
        rpt, rit, ret, rfn, rframe, rcur = self.cur
        if (rframe is not frame) and rcur:
            return self.trace_dispatch_return(rframe, t)
        self.cur = rpt, rit + t, ret, rfn, rframe, rcur
        return 1

    def trace_dispatch_c_call(self, frame: FrameType, t: int) -> int:
    def trace_dispatch_c_call(self, frame, t):
        fn = ("", 0, self.c_func_name, None)
        self.cur = (t, 0, 0, fn, frame, self.cur)
        timings = self.timings

@@ -434,27 +403,15 @@ class Tracer:
            timings[fn] = 0, 0, 0, 0, {}
        return 1

    def trace_dispatch_return(self, frame: FrameType, t: int) -> int:
        if not self.cur or not self.cur[-2]:
            return 0

        # In multi-threaded environments, frames can get mismatched
    def trace_dispatch_return(self, frame, t):
        if frame is not self.cur[-2]:
            # Don't assert in threaded environments - frames can legitimately differ
            if hasattr(frame, "f_back") and hasattr(self.cur[-2], "f_back") and frame is self.cur[-2].f_back:
                self.trace_dispatch_return(self.cur[-2], 0)
            else:
                # We're in a different thread or context, can't continue with this frame
                return 0
            assert frame is self.cur[-2].f_back, ("Bad return", self.cur[-3])
            self.trace_dispatch_return(self.cur[-2], 0)

        # Prefix "r" means part of the Returning or exiting frame.
        # Prefix "p" means part of the Previous or Parent or older frame.

        rpt, rit, ret, rfn, frame, rcur = self.cur

        # Guard against invalid rcur (with threading)
        if not rcur:
            return 0

        rit = rit + t
        frame_total = rit + ret

@@ -462,9 +419,6 @@ class Tracer:
            self.cur = ppt, pit + rpt, pet + frame_total, pfn, pframe, pcur

        timings = self.timings
        if rfn not in timings:
            # with threading, rfn can be missing
            timings[rfn] = 0, 0, 0, 0, {}
        cc, ns, tt, ct, callers = timings[rfn]
        if not ns:
            # This is the only occurrence of the function on the stack.

@@ -486,7 +440,7 @@ class Tracer:

        return 1

    dispatch: ClassVar[dict[str, Callable[[Tracer, FrameType, int], int]]] = {
    dispatch: ClassVar[dict[str, callable]] = {
        "call": trace_dispatch_call,
        "exception": trace_dispatch_exception,
        "return": trace_dispatch_return,

@@ -495,13 +449,32 @@ class Tracer:
        "c_return": trace_dispatch_return,
    }

    def simulate_call(self, name: str) -> None:
        code = FakeCode("profiler", 0, name)
        pframe = self.cur[-2] if self.cur else None
        frame = FakeFrame(code, pframe)
    class fake_code:
        def __init__(self, filename, line, name):
            self.co_filename = filename
            self.co_line = line
            self.co_name = name
            self.co_firstlineno = 0

        def __repr__(self):
            return repr((self.co_filename, self.co_line, self.co_name, None))

    class fake_frame:
        def __init__(self, code, prior):
            self.f_code = code
            self.f_back = prior
            self.f_locals = {}

    def simulate_call(self, name):
        code = self.fake_code("profiler", 0, name)
        if self.cur:
            pframe = self.cur[-2]
        else:
            pframe = None
        frame = self.fake_frame(code, pframe)
        self.dispatch["call"](self, frame, 0)

    def simulate_cmd_complete(self) -> None:
    def simulate_cmd_complete(self):
        get_time = self.timer
        t = get_time() - self.t
        while self.cur[-1]:

@@ -511,174 +484,60 @@ class Tracer:
            t = 0
        self.t = get_time() - t

    def print_stats(self, sort: str | int | tuple = -1) -> None:
        if not self.stats:
            console.print("Codeflash: No stats available to print")
            self.total_tt = 0
            return
    def print_stats(self, sort=-1):
        import pstats

        if not isinstance(sort, tuple):
            sort = (sort,)
        # The following code customizes the default printing behavior to
        # print in milliseconds.
        s = StringIO()
        stats_obj = pstats.Stats(copy(self), stream=s)
        stats_obj.strip_dirs().sort_stats(*sort).print_stats(100)
        self.total_tt = stats_obj.total_tt
        console.print("total_tt", self.total_tt)
        raw_stats = s.getvalue()
        m = re.search(r"function calls?.*in (\d+)\.\d+ (seconds?)", raw_stats)
        total_time = None
        if m:
            total_time = int(m.group(1))
        if total_time is None:
            console.print("Failed to get total time from stats")
        total_time_ms = total_time / 1e6
        raw_stats = re.sub(
            r"(function calls?.*)in (\d+)\.\d+ (seconds?)", rf"\1 in {total_time_ms:.3f} milliseconds", raw_stats
        )
        match_pattern = r"^ *[\d\/]+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +"
        m = re.findall(match_pattern, raw_stats, re.MULTILINE)
        ms_times = []
        for tottime, percall, cumtime, percall_cum in m:
            tottime_ms = int(tottime) / 1e6
            percall_ms = int(percall) / 1e6
            cumtime_ms = int(cumtime) / 1e6
            percall_cum_ms = int(percall_cum) / 1e6
            ms_times.append([tottime_ms, percall_ms, cumtime_ms, percall_cum_ms])
        split_stats = raw_stats.split("\n")
        new_stats = []

        # First, convert stats to make them pstats-compatible
        try:
            # Initialize empty collections for pstats
            self.files = []
            self.top_level = []

            # Create entirely new dictionaries instead of modifying existing ones
            new_stats = {}
            new_timings = {}

            # Convert stats dictionary
            stats_items = list(self.stats.items())
            for func, stats_data in stats_items:
                try:
                    # Make sure we have 5 elements in stats_data
                    if len(stats_data) != 5:
                        console.print(f"Skipping malformed stats data for {func}: {stats_data}")
                        continue

                    cc, nc, tt, ct, callers = stats_data

                    if len(func) == 4:
                        file_name, line_num, func_name, class_name = func
                        new_func_name = f"{class_name}.{func_name}" if class_name else func_name
                        new_func = (file_name, line_num, new_func_name)
                    else:
                        new_func = func  # Keep as is if already in correct format

                    new_callers = {}
                    callers_items = list(callers.items())
                    for caller_func, count in callers_items:
                        if isinstance(caller_func, tuple):
                            if len(caller_func) == 4:
                                caller_file, caller_line, caller_name, caller_class = caller_func
                                caller_new_name = f"{caller_class}.{caller_name}" if caller_class else caller_name
                                new_caller_func = (caller_file, caller_line, caller_new_name)
                            else:
                                new_caller_func = caller_func
                        else:
                            console.print(f"Unexpected caller format: {caller_func}")
                            new_caller_func = str(caller_func)

                        new_callers[new_caller_func] = count

                    # Store with new format
                    new_stats[new_func] = (cc, nc, tt, ct, new_callers)
                except Exception as e:  # noqa: BLE001
                    console.print(f"Error converting stats for {func}: {e}")
                    continue

            timings_items = list(self.timings.items())
            for func, timing_data in timings_items:
                try:
                    if len(timing_data) != 5:
                        console.print(f"Skipping malformed timing data for {func}: {timing_data}")
                        continue

                    cc, ns, tt, ct, callers = timing_data

                    if len(func) == 4:
                        file_name, line_num, func_name, class_name = func
                        new_func_name = f"{class_name}.{func_name}" if class_name else func_name
                        new_func = (file_name, line_num, new_func_name)
                    else:
                        new_func = func

                    new_callers = {}
                    callers_items = list(callers.items())
                    for caller_func, count in callers_items:
                        if isinstance(caller_func, tuple):
                            if len(caller_func) == 4:
                                caller_file, caller_line, caller_name, caller_class = caller_func
                                caller_new_name = f"{caller_class}.{caller_name}" if caller_class else caller_name
                                new_caller_func = (caller_file, caller_line, caller_new_name)
                            else:
                                new_caller_func = caller_func
                        else:
                            console.print(f"Unexpected caller format: {caller_func}")
                            new_caller_func = str(caller_func)

                        new_callers[new_caller_func] = count

                    new_timings[new_func] = (cc, ns, tt, ct, new_callers)
                except Exception as e:  # noqa: BLE001
                    console.print(f"Error converting timings for {func}: {e}")
                    continue

            self.stats = new_stats
            self.timings = new_timings

            self.total_tt = sum(tt for _, _, tt, _, _ in self.stats.values())

            total_calls = sum(cc for cc, _, _, _, _ in self.stats.values())
            total_primitive = sum(nc for _, nc, _, _, _ in self.stats.values())

            summary = Text.assemble(
                f"{total_calls:,} function calls ",
                ("(" + f"{total_primitive:,} primitive calls" + ")", "dim"),
                f" in {self.total_tt / 1e6:.3f} milliseconds",
            )

            console.print(Align.center(Panel(summary, border_style="blue", width=80, padding=(0, 2), expand=False)))

            table = Table(
                show_header=True,
                header_style="bold magenta",
                border_style="blue",
                title="[bold]Function Profile[/bold] (ordered by internal time)",
                title_style="cyan",
                caption=f"Showing top 25 of {len(self.stats)} functions",
            )

            table.add_column("Calls", justify="right", style="green", width=10)
            table.add_column("Time (ms)", justify="right", style="cyan", width=10)
            table.add_column("Per Call", justify="right", style="cyan", width=10)
            table.add_column("Cum (ms)", justify="right", style="yellow", width=10)
            table.add_column("Cum/Call", justify="right", style="yellow", width=10)
            table.add_column("Function", style="blue")

            sorted_stats = sorted(
                ((func, stats) for func, stats in self.stats.items() if isinstance(func, tuple) and len(func) == 3),
                key=lambda x: x[1][2],  # Sort by tt (internal time)
                reverse=True,
            )[:25]  # Limit to top 25

            # Format and add each row to the table
            for func, (cc, nc, tt, ct, _) in sorted_stats:
                filename, lineno, funcname = func

                # Format calls - show recursive format if different
                calls_str = f"{cc}/{nc}" if cc != nc else f"{cc:,}"

                # Convert to milliseconds
                tt_ms = tt / 1e6
                ct_ms = ct / 1e6

                # Calculate per-call times
                per_call = tt_ms / cc if cc > 0 else 0
                cum_per_call = ct_ms / nc if nc > 0 else 0
                base_filename = Path(filename).name
                file_link = f"[link=file://{filename}]{base_filename}[/link]"

                table.add_row(
                    calls_str,
                    f"{tt_ms:.3f}",
                    f"{per_call:.3f}",
                    f"{ct_ms:.3f}",
                    f"{cum_per_call:.3f}",
                    f"{funcname} [dim]({file_link}:{lineno})[/dim]",
        replace_pattern = r"^( *[\d\/]+) +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(.*)"
        times_index = 0
        for line in split_stats:
            if times_index >= len(ms_times):
                replaced = line
            else:
                replaced, n = re.subn(
                    replace_pattern,
                    rf"\g<1>{ms_times[times_index][0]:8.3f} {ms_times[times_index][1]:8.3f} {ms_times[times_index][2]:8.3f} {ms_times[times_index][3]:8.3f} \g<6>",
                    line,
                    count=1,
                )
                if n > 0:
                    times_index += 1
            new_stats.append(replaced)

            console.print(Align.center(table))
        console.print("\n".join(new_stats))

        except Exception as e:  # noqa: BLE001
            console.print(f"[bold red]Error in stats processing:[/bold red] {e}")
            console.print(f"Traced {self.trace_count:,} function calls")
        self.total_tt = 0
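
Because the tracer's timer counts nanoseconds, the integer part of each "seconds" column in the pstats output is effectively a nanosecond count here, and dividing by 1e6 yields milliseconds, which is what the regex rewrite above does line by line. The transformation on one sample row:

import re

line = "   2 5000000.000    0.000 9000000.000    0.000 bubble_sort.py:3(sorter)"
pattern = r"^( *[\d\/]+) +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(.*)"
m = re.match(pattern, line)
assert m is not None
ms = [int(m.group(i)) / 1e6 for i in range(2, 6)]
print(f"{m.group(1)}{ms[0]:9.3f}{ms[1]:9.3f}{ms[2]:9.3f}{ms[3]:9.3f} {m.group(6)}")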

    def make_pstats_compatible(self) -> None:
    def make_pstats_compatible(self):
        # delete the extra class_name item from the function tuple
        self.files = []
        self.top_level = []

@@ -693,33 +552,36 @@ class Tracer:
        self.stats = new_stats
        self.timings = new_timings

    def dump_stats(self, file: str) -> None:
        with Path(file).open("wb") as f:
    def dump_stats(self, file):
        with open(file, "wb") as f:
            self.create_stats()
            marshal.dump(self.stats, f)
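
Once the stats keys have been converted to pstats' 3-tuple form, the dumped mapping should be loadable like any cProfile dump, so a saved trace can be inspected offline. A sketch, assuming a file named codeflash.stats was produced by dump_stats:

import pstats

stats = pstats.Stats("codeflash.stats")
stats.sort_stats("cumulative").print_stats(10)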

    def create_stats(self) -> None:
    def create_stats(self):
        self.simulate_cmd_complete()
        self.snapshot_stats()

    def snapshot_stats(self) -> None:
    def snapshot_stats(self):
        self.stats = {}
        for func, (cc, _ns, tt, ct, caller_dict) in self.timings.items():
            callers = caller_dict.copy()
        for func, (cc, ns, tt, ct, callers) in self.timings.items():
            callers = callers.copy()
            nc = 0
            for callcnt in callers.values():
                nc += callcnt
            self.stats[func] = cc, nc, tt, ct, callers

    def runctx(self, cmd: str, global_vars: dict[str, Any], local_vars: dict[str, Any]) -> Tracer | None:
    def runctx(self, cmd, globals, locals):
        self.__enter__()
        try:
            exec(cmd, global_vars, local_vars)  # noqa: S102
            exec(cmd, globals, locals)
        finally:
            self.__exit__(None, None, None)
        return self
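
runctx follows the same contract as cProfile.Profile.runctx: enter the tracer, exec a command string in the supplied namespaces, and guarantee teardown even if the command raises. The shape of that contract, reduced to a skeleton with a made-up class name:

class ContextRunner:
    def __enter__(self):
        print("tracing on")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print("tracing off")

    def runctx(self, cmd, global_vars, local_vars):
        self.__enter__()
        try:
            exec(cmd, global_vars, local_vars)
        finally:
            self.__exit__(None, None, None)
        return self

ContextRunner().runctx("x = sum(range(5))", {}, {})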


def main() -> ArgumentParser:
def main():
    from argparse import ArgumentParser

    parser = ArgumentParser(allow_abbrev=False)
    parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to <outfile>", required=True)
    parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None)

@@ -776,13 +638,16 @@ def main() -> ArgumentParser:
        "__cached__": None,
    }
    try:
        Tracer(
        tracer = Tracer(
            output=args.outfile,
            functions=args.only_functions,
            max_function_count=args.max_function_count,
            timeout=args.tracer_timeout,
            config_file_path=args.codeflash_config,
        ).runctx(code, globs, None)
        )

        tracer.runctx(code, globs, None)
        print(tracer.functions)

    except BrokenPipeError as exc:
        # Prevent "Exception ignored" during interpreter shutdown.

codeflash/verification/test_results.py

@@ -29,7 +29,6 @@ class TestType(Enum):
    REPLAY_TEST = 4
    CONCOLIC_COVERAGE_TEST = 5
    INIT_STATE_TEST = 6
    BENCHMARK_TEST = 7

    def to_name(self) -> str:
        if self == TestType.INIT_STATE_TEST:

@@ -40,7 +39,6 @@ class TestType(Enum):
            TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
            TestType.REPLAY_TEST: "⏪ Replay Tests",
            TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests",
            TestType.BENCHMARK_TEST: "📏 Benchmark Tests",
        }
        return names[self]
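
The BENCHMARK_TEST member follows the same pattern as the other entries: an enum value plus a display name in the to_name() table. Reduced to just that member (a sketch, not the full enum):

from enum import Enum

class TestType(Enum):
    BENCHMARK_TEST = 7

    def to_name(self) -> str:
        names = {TestType.BENCHMARK_TEST: "📏 Benchmark Tests"}
        return names[self]

assert TestType.BENCHMARK_TEST.to_name() == "📏 Benchmark Tests"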

codeflash/verification/verification_utils.py

@@ -75,3 +75,4 @@ class TestConfig:
    # or for unittest - project_root_from_module_root(args.tests_root, pyproject_file_path)
    concolic_test_root_dir: Optional[Path] = None
    pytest_cmd: str = "pytest"
    benchmark_tests_root: Optional[Path] = None

pyproject.toml

@@ -69,7 +69,7 @@ exclude = [
[tool.poetry.dependencies]
python = ">=3.9"
unidiff = ">=0.7.4"
pytest = ">=7.0.0,<8.3.4"
pytest = ">=7.0.0"
gitpython = ">=3.1.31"
libcst = ">=1.0.1"
jedi = ">=0.19.1"

@@ -92,7 +92,6 @@ rich = ">=13.8.1"
lxml = ">=5.3.0"
crosshair-tool = ">=0.0.78"
coverage = ">=7.6.4"
line_profiler = ">=4.2.0" # this is the minimum version which supports python 3.13
[tool.poetry.group.dev]
optional = true

@@ -120,7 +119,7 @@ types-gevent = "^24.11.0.20241230"
types-greenlet = "^3.1.0.20241221"
types-pexpect = "^4.9.0.20241208"
types-unidiff = "^0.7.0.20240505"
uv = ">=0.6.2"
sqlalchemy = "^2.0.38"

[tool.poetry.build]
script = "codeflash/update_license_version.py"

@@ -152,7 +151,7 @@ warn_required_dynamic_aliases = true
line-length = 120
fix = true
show-fixes = true
exclude = ["code_to_optimize/", "pie_test_set/", "tests/"]
exclude = ["code_to_optimize/", "pie_test_set/"]

[tool.ruff.lint]
select = ["ALL"]

@@ -164,11 +163,10 @@ ignore = [
    "D103",
    "D105",
    "D107",
    "D203",  # incorrect-blank-line-before-class (incompatible with D211)
    "D213",  # multi-line-summary-second-line (incompatible with D212)
    "S101",
    "S603",
    "S607",
    "ANN101",
    "COM812",
    "FIX002",
    "PLR0912",

@@ -177,14 +175,13 @@ ignore = [
    "TD002",
    "TD003",
    "TD004",
    "PLR2004",
    "UP007"  # remove once we drop 3.9 support.
    "PLR2004"
]

[tool.ruff.lint.flake8-type-checking]
strict = true
runtime-evaluated-base-classes = ["pydantic.BaseModel"]
runtime-evaluated-decorators = ["pydantic.validate_call", "pydantic.dataclasses.dataclass"]
runtime-evaluated-decorators = ["pydantic.validate_call"]

[tool.ruff.lint.pep8-naming]
classmethod-decorators = [

@@ -192,9 +189,6 @@ classmethod-decorators = [
    "pydantic.validator",
]

[tool.ruff.lint.isort]
split-on-trailing-comma = false

[tool.ruff.format]
docstring-code-format = true
skip-magic-trailing-comma = true

@@ -217,13 +211,13 @@ initial-content = """


[tool.codeflash]
module-root = "codeflash"
tests-root = "tests"
# All paths are relative to this pyproject.toml's directory.
module-root = "code_to_optimize"
tests-root = "code_to_optimize/tests"
benchmarks-root = "code_to_optimize/tests/pytest/benchmarks"
test-framework = "pytest"
formatter-cmds = [
    "uvx ruff check --exit-zero --fix $file",
    "uvx ruff format $file",
]
ignore-paths = []
formatter-cmds = ["ruff check --exit-zero --fix $file", "ruff format $file"]
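
The new benchmarks-root key sits under [tool.codeflash] like the existing settings, so it can be read with any TOML parser. A sketch using the stdlib parser (Python 3.11+), assuming pyproject.toml is in the current directory:

import tomllib

with open("pyproject.toml", "rb") as f:
    config = tomllib.load(f)["tool"]["codeflash"]
print(config.get("benchmarks-root"))  # e.g. "code_to_optimize/tests/pytest/benchmarks"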


[build-system]

8
tests/test_trace_benchmarks.py
Normal file

@@ -0,0 +1,8 @@
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
from pathlib import Path

def test_trace_benchmarks():
    # Test the trace_benchmarks function
    project_root = Path(__file__).parent.parent / "code_to_optimize"
    benchmarks_root = project_root / "tests" / "pytest" / "benchmarks"
    trace_benchmarks_pytest(benchmarks_root, project_root, ["sorter"])

@@ -3,6 +3,7 @@ import tempfile
from pathlib import Path

from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.verification.test_results import TestType
from codeflash.verification.verification_utils import TestConfig


@@ -21,7 +22,7 @@ def test_unit_test_discovery_pytest():

def test_benchmark_test_discovery_pytest():
    project_path = Path(__file__).parent.parent.resolve() / "code_to_optimize"
    tests_path = project_path / "tests" / "pytest"
    tests_path = project_path / "tests" / "pytest" / "benchmarks" / "test_benchmark_bubble_sort.py"
    test_config = TestConfig(
        tests_root=tests_path,
        project_root_path=project_path,

@@ -29,9 +30,10 @@ def test_benchmark_test_discovery_pytest():
        tests_project_rootdir=tests_path.parent,
    )
    tests = discover_unit_tests(test_config)
    print(tests)
    assert len(tests) > 0
    # print(tests)
    assert 'bubble_sort.sorter' in tests
    benchmark_tests = sum(1 for test in tests['bubble_sort.sorter'] if test.tests_in_file.test_type == TestType.BENCHMARK_TEST)
    assert benchmark_tests == 1


def test_unit_test_discovery_unittest():