mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
initial implementation for tracing benchmarks using a plugin, and projecting speedup
This commit is contained in:
parent
a73b541159
commit
965e2c818c
22 changed files with 790 additions and 571 deletions
28
code_to_optimize/process_and_bubble_sort.py
Normal file
28
code_to_optimize/process_and_bubble_sort.py
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
from code_to_optimize.bubble_sort import sorter
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_pairwise_products(arr):
    """Return the mean of ``arr[i] * arr[j]`` over all ordered pairs with ``i != j``.

    Each unordered pair contributes twice, once as (i, j) and once as (j, i).
    When ``arr`` has fewer than two elements there are no pairs and 0 is returned.
    """
    size = len(arr)
    running_total = 0
    pair_count = 0

    for first in range(size):
        for second in range(size):
            # An element is never paired with itself.
            if first == second:
                continue
            running_total += arr[first] * arr[second]
            pair_count += 1

    if pair_count > 0:
        return running_total / pair_count
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
def compute_and_sort(arr):
    """Compute the pairwise-product average of ``arr``, then run ``sorter`` on a copy.

    Only the average is returned; the value returned by ``sorter`` is discarded.
    """
    # Average of the pairwise products (computed before the sort call).
    average = calculate_pairwise_products(arr)

    # sorter is handed a copy of the input rather than ``arr`` itself.
    scratch = arr.copy()
    sorter(scratch)

    return average
|
||||||
|
|
@ -1,6 +1,13 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
from code_to_optimize.bubble_sort import sorter
|
from code_to_optimize.bubble_sort import sorter
|
||||||
|
|
||||||
|
|
||||||
def test_sort(benchmark):
    """Run ``sorter`` through the ``benchmark`` fixture on a descending 5000-element list."""
    descending = list(reversed(range(5000)))
    result = benchmark(sorter, descending)
    assert result == list(range(5000))
|
||||||
|
|
||||||
|
# This should not be picked up as a benchmark test
|
||||||
|
def test_sort2():
    """Call ``sorter`` directly (no ``benchmark`` fixture) on a descending 5000-element list."""
    descending = list(reversed(range(5000)))
    result = sorter(descending)
    assert result == list(range(5000))
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
from code_to_optimize.process_and_bubble_sort import compute_and_sort
|
||||||
|
from code_to_optimize.bubble_sort2 import sorter
|
||||||
|
def test_compute_and_sort(benchmark):
    """Run ``compute_and_sort`` through the ``benchmark`` fixture and check the returned average."""
    descending = list(reversed(range(5000)))
    result = benchmark(compute_and_sort, descending)
    assert result == 6247083.5
|
||||||
|
|
||||||
|
def test_no_func(benchmark):
    """Run ``sorter`` through the ``benchmark`` fixture without asserting on the result."""
    descending = list(reversed(range(5000)))
    benchmark(sorter, descending)
|
||||||
0
codeflash/benchmarking/__init__.py
Normal file
0
codeflash/benchmarking/__init__.py
Normal file
112
codeflash/benchmarking/get_trace_info.py
Normal file
112
codeflash/benchmarking/get_trace_info.py
Normal file
|
|
@ -0,0 +1,112 @@
|
||||||
|
import sqlite3
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Set
|
||||||
|
|
||||||
|
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||||
|
|
||||||
|
|
||||||
|
def get_function_benchmark_timings(trace_dir: Path, all_functions_to_optimize: list[FunctionToOptimize]) -> dict[str, dict[str, float]]:
    """Process all trace files in the given directory and extract timing data for the specified functions.

    Args:
        trace_dir: Path to the directory containing .trace files (SQLite databases with a ``pstats`` table)
        all_functions_to_optimize: List of FunctionToOptimize objects representing functions to include

    Returns:
        A nested dictionary where:
        - Outer keys are function qualified names with file name
        - Inner keys are benchmark names (trace filename without .trace extension)
        - Values are function timing in milliseconds

        Every requested function gets an outer entry; the inner dictionary stays empty
        when no trace file has a matching row.

    """
    # Map (file_path, function_name, class_name) -> qualified_name so each trace file
    # is queried once per function of interest.
    function_lookup = {}
    function_benchmark_timings = {}

    for func in all_functions_to_optimize:
        qualified_name = func.qualified_name_with_file_name

        # Only the first parent is used as the class name; functions without parents have none.
        class_name = func.parents[0].name if func.parents else None

        function_lookup[(func.file_path, func.function_name, class_name)] = qualified_name
        function_benchmark_timings[qualified_name] = {}

    for trace_file in trace_dir.glob("*.trace"):
        # Benchmark name is the trace file name without the .trace extension.
        benchmark_name = trace_file.stem

        conn = sqlite3.connect(trace_file)
        try:
            cursor = conn.cursor()

            for (filename, function_name, class_name), qualified_name in function_lookup.items():
                # Suffix match on the stored file name. The path is escaped so that "_" and "%"
                # inside it are matched literally instead of acting as LIKE wildcards.
                filename_pattern = f"%{_escape_like_literal(str(filename))}"

                if class_name:
                    cursor.execute(
                        "SELECT total_time_ns FROM pstats WHERE filename LIKE ? ESCAPE '\\' AND function = ? AND class_name = ?",
                        (filename_pattern, function_name, class_name),
                    )
                else:
                    cursor.execute(
                        "SELECT total_time_ns FROM pstats WHERE filename LIKE ? ESCAPE '\\' AND function = ? AND (class_name IS NULL OR class_name = '')",
                        (filename_pattern, function_name),
                    )

                # Only the first matching row is used.
                result = cursor.fetchone()
                if result:
                    time_ns = result[0]
                    function_benchmark_timings[qualified_name][benchmark_name] = time_ns / 1e6  # Convert to milliseconds
        finally:
            # Close the database even when a query raises, so the file handle is not leaked.
            conn.close()

    return function_benchmark_timings


def _escape_like_literal(text: str) -> str:
    """Escape backslash, percent and underscore so ``text`` matches literally in a LIKE pattern using a backslash ESCAPE."""
    # Backslash must be escaped first, otherwise the escapes added for % and _ would be doubled.
    return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||||
|
|
||||||
|
|
||||||
|
def get_benchmark_timings(trace_dir: Path) -> dict[str, float]:
    """Extract total benchmark timings from trace files.

    Args:
        trace_dir: Path to the directory containing .trace files (SQLite databases)

    Returns:
        A dictionary mapping benchmark names (trace filename without the .trace extension)
        to their total execution time in milliseconds. Trace files whose ``total_time``
        query fails with ``sqlite3.OperationalError``, or returns no row, are left out.

    """
    benchmark_timings = {}

    for trace_file in trace_dir.glob("*.trace"):
        benchmark_name = trace_file.stem

        conn = sqlite3.connect(trace_file)
        try:
            result = conn.execute("SELECT time_ns FROM total_time").fetchone()
        except sqlite3.OperationalError:
            # Handle case where total_time table might not exist
            print(f"Warning: Could not get total time for benchmark {benchmark_name}")
            continue
        finally:
            # Runs on success, on the handled error and on any other exception,
            # so the database handle is never leaked.
            conn.close()

        if result:
            # Convert nanoseconds to milliseconds
            benchmark_timings[benchmark_name] = result[0] / 1e6

    return benchmark_timings
|
||||||
0
codeflash/benchmarking/plugin/__init__.py
Normal file
0
codeflash/benchmarking/plugin/__init__.py
Normal file
79
codeflash/benchmarking/plugin/plugin.py
Normal file
79
codeflash/benchmarking/plugin/plugin.py
Normal file
|
|
@ -0,0 +1,79 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from codeflash.tracer import Tracer
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
class CodeFlashPlugin:
    """Pytest plugin that provides a ``benchmark`` fixture which runs the benchmarked callable under ``Tracer``."""

    @staticmethod
    def pytest_addoption(parser):
        """Register the ``--codeflash-trace``, ``--functions`` and ``--benchmarks-root`` command-line options."""
        parser.addoption(
            "--codeflash-trace",
            action="store_true",
            default=False,
            help="Enable CodeFlash tracing"
        )
        parser.addoption(
            "--functions",
            action="store",
            default="",
            help="Comma-separated list of additional functions to trace"
        )
        parser.addoption(
            "--benchmarks-root",
            action="store",
            default=".",
            help="Root directory for benchmarks"
        )

    @staticmethod
    def pytest_plugin_registered(plugin, manager):
        """Unregister any plugin whose ``name`` attribute equals ``"pytest-benchmark"``."""
        if hasattr(plugin, "name") and plugin.name == "pytest-benchmark":
            manager.unregister(plugin)

    @staticmethod
    def pytest_collection_modifyitems(config, items):
        """When tracing is enabled, mark every collected item that does not request ``benchmark`` as skipped."""
        if not config.getoption("--codeflash-trace"):
            return

        skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
        for item in items:
            if hasattr(item, "fixturenames") and "benchmark" in item.fixturenames:
                continue
            item.add_marker(skip_no_benchmark)

    @staticmethod
    @pytest.fixture
    def benchmark(request):
        """Return a callable that traces one invocation of the benchmarked function.

        Returns ``None`` when ``--codeflash-trace`` was not passed. Otherwise the returned
        object is called as ``benchmark(func, *args, **kwargs)``: it runs ``func`` once inside
        a ``Tracer`` context writing to ``<benchmarks-root>/.codeflash_trace/<test name>.trace``
        and returns whatever ``func`` returned.
        """
        if not request.config.getoption("--codeflash-trace"):
            return None

        class Benchmark:
            def __call__(self, func, *args, **kwargs):
                # func.__name__ is deliberately not read: the value was never used, and
                # reading it fails for callables without that attribute (e.g. functools.partial).
                test_name = request.node.name

                # Drop empty entries so an empty --functions value yields an empty list.
                additional_functions = request.config.getoption("--functions").split(",")
                trace_functions = [f for f in additional_functions if f]
                print("Tracing functions: ", trace_functions)

                # Get benchmarks root directory from command line option
                benchmarks_root = Path(request.config.getoption("--benchmarks-root"))

                # Create the trace directory (and any missing parents) if it doesn't exist
                trace_dir = benchmarks_root / '.codeflash_trace'
                trace_dir.mkdir(parents=True, exist_ok=True)

                # One trace file per test, named after the test node
                output_path = trace_dir / f"{test_name}.trace"

                tracer = Tracer(
                    output=str(output_path),  # Convert Path to string for Tracer
                    functions=trace_functions,
                    max_function_count=256
                )

                with tracer:
                    result = func(*args, **kwargs)

                return result

        return Benchmark()
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
import sys
|
||||||
|
from plugin.plugin import CodeFlashPlugin
|
||||||
|
|
||||||
|
# Positional arguments: the benchmarks directory, then a comma-separated function list.
# Both are read at module level, before the __main__ guard.
benchmarks_root = sys.argv[1]
function_list = sys.argv[2]

if __name__ == "__main__":
    import pytest

    # "-p no:benchmark" is passed alongside the CodeFlashPlugin instance, which supplies
    # its own "benchmark" fixture; "-s" disables output capturing.
    pytest_args = [
        benchmarks_root,
        "--benchmarks-root",
        benchmarks_root,
        "--codeflash-trace",
        "-p",
        "no:benchmark",
        "-s",
        "--functions",
        function_list,
    ]

    try:
        exitcode = pytest.main(pytest_args, plugins=[CodeFlashPlugin()])
    except Exception as e:
        print(f"Failed to collect tests: {e!s}")
        exitcode = -1
|
||||||
20
codeflash/benchmarking/trace_benchmarks.py
Normal file
20
codeflash/benchmarking/trace_benchmarks.py
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
|
||||||
|
from pathlib import Path
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
def trace_benchmarks_pytest(benchmarks_root: Path, project_root: Path, function_list: list[str] | None = None) -> None:
    """Run the benchmark-tracing pytest script in a subprocess and print its output.

    Args:
        benchmarks_root: Directory passed to the script as its first argument.
        project_root: Working directory for the subprocess.
        function_list: Function names joined with commas and passed as the script's
            second argument. ``None`` (the default) is treated as an empty list.

    The subprocess exit status is not checked (``check=False``); its captured stdout
    and stderr are printed.
    """
    # None replaces the former mutable ``[]`` default, which was shared across calls.
    functions = function_list if function_list is not None else []
    script_path = Path(__file__).parent / "pytest_new_process_trace_benchmarks.py"

    result = subprocess.run(
        [
            SAFE_SYS_EXECUTABLE,
            str(script_path),
            str(benchmarks_root),
            ",".join(functions),
        ],
        cwd=project_root,
        check=False,
        capture_output=True,
        text=True,
    )
    print("stdout:", result.stdout)
    print("stderr:", result.stderr)
|
||||||
|
|
@ -62,6 +62,10 @@ def parse_args() -> Namespace:
|
||||||
)
|
)
|
||||||
parser.add_argument("-v", "--verbose", action="store_true", help="Print verbose debug logs")
|
parser.add_argument("-v", "--verbose", action="store_true", help="Print verbose debug logs")
|
||||||
parser.add_argument("--version", action="store_true", help="Print the version of codeflash")
|
parser.add_argument("--version", action="store_true", help="Print the version of codeflash")
|
||||||
|
parser.add_argument("--benchmark", action="store_true", help="Trace benchmark tests and calculate optimization impact on benchmarks")
|
||||||
|
parser.add_argument(
|
||||||
|
"--benchmarks-root", type=str, help="Path to the directory of the project, where all the pytest-benchmark tests are located."
|
||||||
|
)
|
||||||
args: Namespace = parser.parse_args()
|
args: Namespace = parser.parse_args()
|
||||||
return process_and_validate_cmd_args(args)
|
return process_and_validate_cmd_args(args)
|
||||||
|
|
||||||
|
|
@ -116,6 +120,7 @@ def process_pyproject_config(args: Namespace) -> Namespace:
|
||||||
"disable_telemetry",
|
"disable_telemetry",
|
||||||
"disable_imports_sorting",
|
"disable_imports_sorting",
|
||||||
"git_remote",
|
"git_remote",
|
||||||
|
"benchmarks_root"
|
||||||
]
|
]
|
||||||
for key in supported_keys:
|
for key in supported_keys:
|
||||||
if key in pyproject_config and (
|
if key in pyproject_config and (
|
||||||
|
|
@ -127,23 +132,17 @@ def process_pyproject_config(args: Namespace) -> Namespace:
|
||||||
assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory"
|
assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory"
|
||||||
assert args.tests_root is not None, "--tests-root must be specified"
|
assert args.tests_root is not None, "--tests-root must be specified"
|
||||||
assert Path(args.tests_root).is_dir(), f"--tests-root {args.tests_root} must be a valid directory"
|
assert Path(args.tests_root).is_dir(), f"--tests-root {args.tests_root} must be a valid directory"
|
||||||
|
if args.benchmark:
|
||||||
if env_utils.get_pr_number() is not None:
|
assert args.benchmarks_root is not None, "--benchmarks-root must be specified when running with --benchmark"
|
||||||
assert env_utils.ensure_codeflash_api_key(), (
|
assert Path(args.benchmarks_root).is_dir(), f"--benchmarks-root {args.benchmarks_root} must be a valid directory"
|
||||||
"Codeflash API key not found. When running in a Github Actions Context, provide the "
|
assert not (env_utils.get_pr_number() is not None and not env_utils.ensure_codeflash_api_key()), (
|
||||||
"'CODEFLASH_API_KEY' environment variable as a secret.\n"
|
"Codeflash API key not found. When running in a Github Actions Context, provide the "
|
||||||
"You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n"
|
"'CODEFLASH_API_KEY' environment variable as a secret.\n"
|
||||||
"Then, click 'New repository secret' and add your api key with the variable name CODEFLASH_API_KEY.\n"
|
"You can add a secret by going to your repository's settings page, then clicking 'Secrets' in the left sidebar.\n"
|
||||||
f"Here's a direct link: {get_github_secrets_page_url()}\n"
|
"Then, click 'New repository secret' and add your api key with the variable name CODEFLASH_API_KEY.\n"
|
||||||
"Exiting..."
|
f"Here's a direct link: {get_github_secrets_page_url()}\n"
|
||||||
)
|
"Exiting..."
|
||||||
|
)
|
||||||
repo = git.Repo(search_parent_directories=True)
|
|
||||||
|
|
||||||
owner, repo_name = get_repo_owner_and_name(repo)
|
|
||||||
|
|
||||||
require_github_app_or_exit(owner, repo_name)
|
|
||||||
|
|
||||||
if hasattr(args, "ignore_paths") and args.ignore_paths is not None:
|
if hasattr(args, "ignore_paths") and args.ignore_paths is not None:
|
||||||
normalized_ignore_paths = []
|
normalized_ignore_paths = []
|
||||||
for path in args.ignore_paths:
|
for path in args.ignore_paths:
|
||||||
|
|
|
||||||
|
|
@ -107,8 +107,6 @@ def discover_tests_pytest(
|
||||||
test_type = TestType.REPLAY_TEST
|
test_type = TestType.REPLAY_TEST
|
||||||
elif "test_concolic_coverage" in test["test_file"]:
|
elif "test_concolic_coverage" in test["test_file"]:
|
||||||
test_type = TestType.CONCOLIC_COVERAGE_TEST
|
test_type = TestType.CONCOLIC_COVERAGE_TEST
|
||||||
elif test["test_type"] == "benchmark": # New condition for benchmark tests
|
|
||||||
test_type = TestType.BENCHMARK_TEST
|
|
||||||
else:
|
else:
|
||||||
test_type = TestType.EXISTING_UNIT_TEST
|
test_type = TestType.EXISTING_UNIT_TEST
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -121,7 +121,6 @@ class FunctionToOptimize:
|
||||||
method extends this with the module name from the project root.
|
method extends this with the module name from the project root.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
function_name: str
|
function_name: str
|
||||||
file_path: Path
|
file_path: Path
|
||||||
parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef]
|
parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef]
|
||||||
|
|
@ -145,6 +144,11 @@ class FunctionToOptimize:
|
||||||
def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
|
def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
|
||||||
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
|
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
|
||||||
|
|
||||||
|
@property
def qualified_name_with_file_name(self) -> str:
    """Return ``"<file_path>:<function_name>"``, with ``"<first parent name>:"`` inserted when there is one.

    Only the first entry of ``parents`` is considered; an empty name adds no segment.
    """
    owner = self.parents[0].name if self.parents else None
    owner_segment = owner + ":" if owner else ""
    return f"{self.file_path}:{owner_segment}{self.function_name}"
|
||||||
|
|
||||||
|
|
||||||
def get_functions_to_optimize(
|
def get_functions_to_optimize(
|
||||||
optimize_all: str | None,
|
optimize_all: str | None,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,54 @@
|
||||||
|
import sys
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
# This script should not have any relation to the codeflash package, be careful with imports
|
||||||
|
cwd = sys.argv[1]
|
||||||
|
tests_root = sys.argv[2]
|
||||||
|
pickle_path = sys.argv[3]
|
||||||
|
collected_tests = []
|
||||||
|
pytest_rootdir = None
|
||||||
|
sys.path.insert(1, str(cwd))
|
||||||
|
|
||||||
|
|
||||||
|
class PytestCollectionPlugin:
    """Pytest plugin that records the collected items and the session rootdir in module-level globals."""

    def pytest_collection_finish(self, session) -> None:
        """Append ``session.items`` to ``collected_tests`` and store ``session.config.rootdir``."""
        global collected_tests, pytest_rootdir
        # In-place extension of the shared module-level list.
        collected_tests += session.items
        pytest_rootdir = session.config.rootdir
|
||||||
|
|
||||||
|
|
||||||
|
def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, str]]:
    """Convert collected pytest items into plain dictionaries.

    Each dictionary has the keys ``test_file``, ``test_class``, ``test_function`` and
    ``test_type``. ``test_class`` is the parent's name when the item has a class and
    ``None`` otherwise. ``test_type`` is ``'benchmark'`` when the item requests a
    fixture named ``benchmark`` and ``'regular'`` otherwise.
    """
    parsed = []
    for item in pytest_tests:
        owner = item.parent.name if item.cls else None

        # Items without a fixturenames attribute are treated as requesting no fixtures.
        requested_fixtures = getattr(item, "fixturenames", ())
        kind = "benchmark" if "benchmark" in requested_fixtures else "regular"

        parsed.append(
            {
                "test_file": str(item.path),
                "test_class": owner,
                "test_function": item.name,
                "test_type": kind,
            }
        )
    return parsed
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import pickle

    import pytest

    # Collection only: no tests are executed, and items marked "skip" are deselected.
    collect_args = [tests_root, "-pno:logging", "--collect-only", "-m", "not skip"]

    try:
        exitcode = pytest.main(collect_args, plugins=[PytestCollectionPlugin()])
    except Exception as e:
        print(f"Failed to collect tests: {e!s}")
        exitcode = -1

    tests = parse_pytest_collection_results(collected_tests)

    # Results are written as a pickled (exitcode, tests, rootdir) tuple to pickle_path.
    payload = (exitcode, tests, pytest_rootdir)
    with open(pickle_path, "wb") as f:
        pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
@ -29,17 +29,7 @@ def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, s
|
||||||
test_class = None
|
test_class = None
|
||||||
if test.cls:
|
if test.cls:
|
||||||
test_class = test.parent.name
|
test_class = test.parent.name
|
||||||
|
test_results.append({"test_file": str(test.path), "test_class": test_class, "test_function": test.name})
|
||||||
# Determine if this is a benchmark test by checking for the benchmark fixture
|
|
||||||
is_benchmark = hasattr(test, 'fixturenames') and 'benchmark' in test.fixturenames
|
|
||||||
test_type = 'benchmark' if is_benchmark else 'regular'
|
|
||||||
|
|
||||||
test_results.append({
|
|
||||||
"test_file": str(test.path),
|
|
||||||
"test_class": test_class,
|
|
||||||
"test_function": test.name,
|
|
||||||
"test_type": test_type
|
|
||||||
})
|
|
||||||
return test_results
|
return test_results
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
|
@ -21,12 +21,12 @@ from rich.tree import Tree
|
||||||
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
|
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
|
||||||
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
|
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
|
||||||
from codeflash.code_utils import env_utils
|
from codeflash.code_utils import env_utils
|
||||||
|
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code
|
||||||
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
|
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
|
||||||
from codeflash.code_utils.code_utils import (
|
from codeflash.code_utils.code_utils import (
|
||||||
cleanup_paths,
|
cleanup_paths,
|
||||||
file_name_from_test_module_name,
|
file_name_from_test_module_name,
|
||||||
get_run_tmp_file,
|
get_run_tmp_file,
|
||||||
has_any_async_functions,
|
|
||||||
module_name_from_file_path,
|
module_name_from_file_path,
|
||||||
)
|
)
|
||||||
from codeflash.code_utils.config_consts import (
|
from codeflash.code_utils.config_consts import (
|
||||||
|
|
@ -37,7 +37,6 @@ from codeflash.code_utils.config_consts import (
|
||||||
)
|
)
|
||||||
from codeflash.code_utils.formatter import format_code, sort_imports
|
from codeflash.code_utils.formatter import format_code, sort_imports
|
||||||
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
|
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
|
||||||
from codeflash.code_utils.line_profile_utils import add_decorator_imports
|
|
||||||
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
|
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
|
||||||
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
|
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
|
||||||
from codeflash.code_utils.time_utils import humanize_runtime
|
from codeflash.code_utils.time_utils import humanize_runtime
|
||||||
|
|
@ -49,6 +48,7 @@ from codeflash.models.models import (
|
||||||
BestOptimization,
|
BestOptimization,
|
||||||
CodeOptimizationContext,
|
CodeOptimizationContext,
|
||||||
FunctionCalledInTest,
|
FunctionCalledInTest,
|
||||||
|
FunctionParent,
|
||||||
GeneratedTests,
|
GeneratedTests,
|
||||||
GeneratedTestsList,
|
GeneratedTestsList,
|
||||||
OptimizationSet,
|
OptimizationSet,
|
||||||
|
|
@ -57,9 +57,8 @@ from codeflash.models.models import (
|
||||||
TestFile,
|
TestFile,
|
||||||
TestFiles,
|
TestFiles,
|
||||||
TestingMode,
|
TestingMode,
|
||||||
TestResults,
|
|
||||||
TestType,
|
|
||||||
)
|
)
|
||||||
|
from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions
|
||||||
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
|
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
|
||||||
from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
|
from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
|
||||||
from codeflash.result.explanation import Explanation
|
from codeflash.result.explanation import Explanation
|
||||||
|
|
@ -67,15 +66,18 @@ from codeflash.telemetry.posthog_cf import ph
|
||||||
from codeflash.verification.concolic_testing import generate_concolic_tests
|
from codeflash.verification.concolic_testing import generate_concolic_tests
|
||||||
from codeflash.verification.equivalence import compare_test_results
|
from codeflash.verification.equivalence import compare_test_results
|
||||||
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
|
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
|
||||||
from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
|
|
||||||
from codeflash.verification.parse_test_output import parse_test_results
|
from codeflash.verification.parse_test_output import parse_test_results
|
||||||
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_line_profile_tests
|
from codeflash.verification.test_results import TestResults, TestType
|
||||||
|
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests
|
||||||
from codeflash.verification.verification_utils import get_test_file_path
|
from codeflash.verification.verification_utils import get_test_file_path
|
||||||
from codeflash.verification.verifier import generate_tests
|
from codeflash.verification.verifier import generate_tests
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
|
||||||
from codeflash.either import Result
|
from codeflash.either import Result
|
||||||
from codeflash.models.models import CoverageData, FunctionSource, OptimizedCandidate
|
from codeflash.models.models import CoverageData, FunctionSource, OptimizedCandidate
|
||||||
from codeflash.verification.verification_utils import TestConfig
|
from codeflash.verification.verification_utils import TestConfig
|
||||||
|
|
@ -90,6 +92,8 @@ class FunctionOptimizer:
|
||||||
function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
|
function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
|
||||||
function_to_optimize_ast: ast.FunctionDef | None = None,
|
function_to_optimize_ast: ast.FunctionDef | None = None,
|
||||||
aiservice_client: AiServiceClient | None = None,
|
aiservice_client: AiServiceClient | None = None,
|
||||||
|
function_benchmark_timings: dict[str, dict[str, float]] | None = None,
|
||||||
|
total_benchmark_timings: dict[str, float] | None = None,
|
||||||
args: Namespace | None = None,
|
args: Namespace | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.project_root = test_cfg.project_root_path
|
self.project_root = test_cfg.project_root_path
|
||||||
|
|
@ -118,6 +122,9 @@ class FunctionOptimizer:
|
||||||
self.function_trace_id: str = str(uuid.uuid4())
|
self.function_trace_id: str = str(uuid.uuid4())
|
||||||
self.original_module_path = module_name_from_file_path(self.function_to_optimize.file_path, self.project_root)
|
self.original_module_path = module_name_from_file_path(self.function_to_optimize.file_path, self.project_root)
|
||||||
|
|
||||||
|
self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {}
|
||||||
|
self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {}
|
||||||
|
|
||||||
def optimize_function(self) -> Result[BestOptimization, str]:
|
def optimize_function(self) -> Result[BestOptimization, str]:
|
||||||
should_run_experiment = self.experiment_id is not None
|
should_run_experiment = self.experiment_id is not None
|
||||||
logger.debug(f"Function Trace ID: {self.function_trace_id}")
|
logger.debug(f"Function Trace ID: {self.function_trace_id}")
|
||||||
|
|
@ -134,10 +141,19 @@ class FunctionOptimizer:
|
||||||
with helper_function_path.open(encoding="utf8") as f:
|
with helper_function_path.open(encoding="utf8") as f:
|
||||||
helper_code = f.read()
|
helper_code = f.read()
|
||||||
original_helper_code[helper_function_path] = helper_code
|
original_helper_code[helper_function_path] = helper_code
|
||||||
if has_any_async_functions(code_context.read_writable_code):
|
|
||||||
return Failure("Codeflash does not support async functions in the code to optimize.")
|
logger.info("Code to be optimized:")
|
||||||
code_print(code_context.read_writable_code)
|
code_print(code_context.read_writable_code)
|
||||||
|
|
||||||
|
for module_abspath, helper_code_source in original_helper_code.items():
|
||||||
|
code_context.code_to_optimize_with_helpers = add_needed_imports_from_module(
|
||||||
|
helper_code_source,
|
||||||
|
code_context.code_to_optimize_with_helpers,
|
||||||
|
module_abspath,
|
||||||
|
self.function_to_optimize.file_path,
|
||||||
|
self.args.project_root,
|
||||||
|
)
|
||||||
|
|
||||||
generated_test_paths = [
|
generated_test_paths = [
|
||||||
get_test_file_path(
|
get_test_file_path(
|
||||||
self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="unit"
|
self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="unit"
|
||||||
|
|
@ -156,7 +172,7 @@ class FunctionOptimizer:
|
||||||
transient=True,
|
transient=True,
|
||||||
):
|
):
|
||||||
generated_results = self.generate_tests_and_optimizations(
|
generated_results = self.generate_tests_and_optimizations(
|
||||||
testgen_context_code=code_context.testgen_context_code,
|
code_to_optimize_with_helpers=code_context.code_to_optimize_with_helpers,
|
||||||
read_writable_code=code_context.read_writable_code,
|
read_writable_code=code_context.read_writable_code,
|
||||||
read_only_context_code=code_context.read_only_context_code,
|
read_only_context_code=code_context.read_only_context_code,
|
||||||
helper_functions=code_context.helper_functions,
|
helper_functions=code_context.helper_functions,
|
||||||
|
|
@ -232,11 +248,10 @@ class FunctionOptimizer:
|
||||||
):
|
):
|
||||||
cleanup_paths(paths_to_cleanup)
|
cleanup_paths(paths_to_cleanup)
|
||||||
return Failure("The threshold for test coverage was not met.")
|
return Failure("The threshold for test coverage was not met.")
|
||||||
# request for new optimizations but don't block execution, check for completion later
|
|
||||||
# adding to control and experiment set but with same traceid
|
|
||||||
best_optimization = None
|
best_optimization = None
|
||||||
|
|
||||||
for _u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]):
|
for u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]):
|
||||||
if candidates is None:
|
if candidates is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -270,6 +285,20 @@ class FunctionOptimizer:
|
||||||
function_name=function_to_optimize_qualified_name,
|
function_name=function_to_optimize_qualified_name,
|
||||||
file_path=self.function_to_optimize.file_path,
|
file_path=self.function_to_optimize.file_path,
|
||||||
)
|
)
|
||||||
|
speedup = explanation.speedup # eg. 1.2 means 1.2x faster
|
||||||
|
if self.args.benchmark:
|
||||||
|
fto_benchmark_timings = self.function_benchmark_timings[self.function_to_optimize.qualified_name_with_file_name]
|
||||||
|
for benchmark_name, og_benchmark_timing in fto_benchmark_timings.items():
|
||||||
|
print(f"Calculating speedup for benchmark {benchmark_name}")
|
||||||
|
total_benchmark_timing = self.total_benchmark_timings[benchmark_name]
|
||||||
|
# find out expected new benchmark timing, then calculate how much total benchmark was sped up. print out intermediate values
|
||||||
|
expected_new_benchmark_timing = total_benchmark_timing - og_benchmark_timing + og_benchmark_timing / speedup
|
||||||
|
print(f"Expected new benchmark timing: {expected_new_benchmark_timing}")
|
||||||
|
print(f"Original benchmark timing: {total_benchmark_timing}")
|
||||||
|
print(f"Benchmark speedup: {total_benchmark_timing / expected_new_benchmark_timing}")
|
||||||
|
|
||||||
|
speedup = total_benchmark_timing / expected_new_benchmark_timing
|
||||||
|
print(f"Speedup: {speedup}")
|
||||||
|
|
||||||
self.log_successful_optimization(explanation, generated_tests)
|
self.log_successful_optimization(explanation, generated_tests)
|
||||||
|
|
||||||
|
|
@ -359,123 +388,94 @@ class FunctionOptimizer:
|
||||||
f"{self.function_to_optimize.qualified_name}…"
|
f"{self.function_to_optimize.qualified_name}…"
|
||||||
)
|
)
|
||||||
console.rule()
|
console.rule()
|
||||||
candidates = deque(candidates)
|
try:
|
||||||
# Start a new thread for AI service request, start loop in main thread
|
for candidate_index, candidate in enumerate(candidates, start=1):
|
||||||
# check if aiservice request is complete, when it is complete, append result to the candidates list
|
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
|
||||||
future_line_profile_results = executor.submit(
|
logger.info(f"Optimization candidate {candidate_index}/{len(candidates)}:")
|
||||||
self.aiservice_client.optimize_python_code_line_profiler,
|
code_print(candidate.source_code)
|
||||||
source_code=code_context.read_writable_code,
|
try:
|
||||||
dependency_code=code_context.read_only_context_code,
|
did_update = self.replace_function_and_helpers_with_optimized_code(
|
||||||
trace_id=self.function_trace_id,
|
code_context=code_context, optimized_code=candidate.source_code
|
||||||
line_profiler_results=original_code_baseline.line_profile_results["str_out"],
|
|
||||||
num_candidates=10,
|
|
||||||
experiment_metadata=None,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
candidate_index = 0
|
|
||||||
done = False
|
|
||||||
original_len = len(candidates)
|
|
||||||
while candidates:
|
|
||||||
# for candidate_index, candidate in enumerate(candidates, start=1):
|
|
||||||
done = True if future_line_profile_results is None else future_line_profile_results.done()
|
|
||||||
if done and (future_line_profile_results is not None):
|
|
||||||
line_profile_results = future_line_profile_results.result()
|
|
||||||
candidates.extend(line_profile_results)
|
|
||||||
original_len+= len(line_profile_results)
|
|
||||||
logger.info(f"Added {len(line_profile_results)} results from line profiler to candidates, total candidates now: {original_len}")
|
|
||||||
future_line_profile_results = None
|
|
||||||
candidate_index += 1
|
|
||||||
candidate = candidates.popleft()
|
|
||||||
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
|
|
||||||
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
|
|
||||||
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
|
|
||||||
code_print(candidate.source_code)
|
|
||||||
try:
|
|
||||||
did_update = self.replace_function_and_helpers_with_optimized_code(
|
|
||||||
code_context=code_context, optimized_code=candidate.source_code
|
|
||||||
)
|
|
||||||
if not did_update:
|
|
||||||
logger.warning(
|
|
||||||
"No functions were replaced in the optimized code. Skipping optimization candidate."
|
|
||||||
)
|
|
||||||
console.rule()
|
|
||||||
continue
|
|
||||||
except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
|
|
||||||
logger.error(e)
|
|
||||||
self.write_code_and_helpers(
|
|
||||||
self.function_to_optimize_source_code,
|
|
||||||
original_helper_code,
|
|
||||||
self.function_to_optimize.file_path,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Instrument codeflash capture
|
|
||||||
run_results = self.run_optimized_candidate(
|
|
||||||
optimization_candidate_index=candidate_index,
|
|
||||||
baseline_results=original_code_baseline,
|
|
||||||
original_helper_code=original_helper_code,
|
|
||||||
file_path_to_helper_classes=file_path_to_helper_classes,
|
|
||||||
)
|
)
|
||||||
console.rule()
|
if not did_update:
|
||||||
|
logger.warning(
|
||||||
if not is_successful(run_results):
|
"No functions were replaced in the optimized code. Skipping optimization candidate."
|
||||||
optimized_runtimes[candidate.optimization_id] = None
|
|
||||||
is_correct[candidate.optimization_id] = False
|
|
||||||
speedup_ratios[candidate.optimization_id] = None
|
|
||||||
else:
|
|
||||||
candidate_result: OptimizedCandidateResult = run_results.unwrap()
|
|
||||||
best_test_runtime = candidate_result.best_test_runtime
|
|
||||||
optimized_runtimes[candidate.optimization_id] = best_test_runtime
|
|
||||||
is_correct[candidate.optimization_id] = True
|
|
||||||
perf_gain = performance_gain(
|
|
||||||
original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_test_runtime
|
|
||||||
)
|
)
|
||||||
speedup_ratios[candidate.optimization_id] = perf_gain
|
|
||||||
|
|
||||||
tree = Tree(f"Candidate #{candidate_index} - Runtime Information")
|
|
||||||
if speedup_critic(
|
|
||||||
candidate_result, original_code_baseline.runtime, best_runtime_until_now
|
|
||||||
) and quantity_of_tests_critic(candidate_result):
|
|
||||||
tree.add("This candidate is faster than the previous best candidate. 🚀")
|
|
||||||
tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")
|
|
||||||
tree.add(
|
|
||||||
f"Best summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} "
|
|
||||||
f"(measured over {candidate_result.max_loop_count} "
|
|
||||||
f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
|
|
||||||
)
|
|
||||||
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
|
|
||||||
tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X")
|
|
||||||
|
|
||||||
best_optimization = BestOptimization(
|
|
||||||
candidate=candidate,
|
|
||||||
helper_functions=code_context.helper_functions,
|
|
||||||
runtime=best_test_runtime,
|
|
||||||
winning_behavioral_test_results=candidate_result.behavior_test_results,
|
|
||||||
winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
|
|
||||||
)
|
|
||||||
best_runtime_until_now = best_test_runtime
|
|
||||||
else:
|
|
||||||
tree.add(
|
|
||||||
f"Summed runtime: {humanize_runtime(best_test_runtime)} "
|
|
||||||
f"(measured over {candidate_result.max_loop_count} "
|
|
||||||
f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
|
|
||||||
)
|
|
||||||
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
|
|
||||||
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
|
|
||||||
console.print(tree)
|
|
||||||
console.rule()
|
console.rule()
|
||||||
|
continue
|
||||||
|
except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
|
||||||
|
logger.error(e)
|
||||||
self.write_code_and_helpers(
|
self.write_code_and_helpers(
|
||||||
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
|
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
|
||||||
)
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Instrument codeflash capture
|
||||||
|
run_results = self.run_optimized_candidate(
|
||||||
|
optimization_candidate_index=candidate_index,
|
||||||
|
baseline_results=original_code_baseline,
|
||||||
|
original_helper_code=original_helper_code,
|
||||||
|
file_path_to_helper_classes=file_path_to_helper_classes,
|
||||||
|
)
|
||||||
|
console.rule()
|
||||||
|
|
||||||
|
if not is_successful(run_results):
|
||||||
|
optimized_runtimes[candidate.optimization_id] = None
|
||||||
|
is_correct[candidate.optimization_id] = False
|
||||||
|
speedup_ratios[candidate.optimization_id] = None
|
||||||
|
else:
|
||||||
|
candidate_result: OptimizedCandidateResult = run_results.unwrap()
|
||||||
|
best_test_runtime = candidate_result.best_test_runtime
|
||||||
|
optimized_runtimes[candidate.optimization_id] = best_test_runtime
|
||||||
|
is_correct[candidate.optimization_id] = True
|
||||||
|
perf_gain = performance_gain(
|
||||||
|
original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_test_runtime
|
||||||
|
)
|
||||||
|
speedup_ratios[candidate.optimization_id] = perf_gain
|
||||||
|
|
||||||
|
tree = Tree(f"Candidate #{candidate_index} - Runtime Information")
|
||||||
|
if speedup_critic(
|
||||||
|
candidate_result, original_code_baseline.runtime, best_runtime_until_now
|
||||||
|
) and quantity_of_tests_critic(candidate_result):
|
||||||
|
tree.add("This candidate is faster than the previous best candidate. 🚀")
|
||||||
|
tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")
|
||||||
|
tree.add(
|
||||||
|
f"Best summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} "
|
||||||
|
f"(measured over {candidate_result.max_loop_count} "
|
||||||
|
f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
|
||||||
|
)
|
||||||
|
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
|
||||||
|
tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X")
|
||||||
|
|
||||||
|
best_optimization = BestOptimization(
|
||||||
|
candidate=candidate,
|
||||||
|
helper_functions=code_context.helper_functions,
|
||||||
|
runtime=best_test_runtime,
|
||||||
|
winning_behavioral_test_results=candidate_result.behavior_test_results,
|
||||||
|
winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
|
||||||
|
)
|
||||||
|
best_runtime_until_now = best_test_runtime
|
||||||
|
else:
|
||||||
|
tree.add(
|
||||||
|
f"Summed runtime: {humanize_runtime(best_test_runtime)} "
|
||||||
|
f"(measured over {candidate_result.max_loop_count} "
|
||||||
|
f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
|
||||||
|
)
|
||||||
|
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
|
||||||
|
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
|
||||||
|
console.print(tree)
|
||||||
|
console.rule()
|
||||||
|
|
||||||
except KeyboardInterrupt as e:
|
|
||||||
self.write_code_and_helpers(
|
self.write_code_and_helpers(
|
||||||
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
|
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
|
||||||
)
|
)
|
||||||
logger.exception(f"Optimization interrupted: {e}")
|
except KeyboardInterrupt as e:
|
||||||
raise
|
self.write_code_and_helpers(
|
||||||
|
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
|
||||||
|
)
|
||||||
|
logger.exception(f"Optimization interrupted: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
self.aiservice_client.log_results(
|
self.aiservice_client.log_results(
|
||||||
function_trace_id=self.function_trace_id,
|
function_trace_id=self.function_trace_id,
|
||||||
|
|
@ -575,6 +575,50 @@ class FunctionOptimizer:
|
||||||
return did_update
|
return did_update
|
||||||
|
|
||||||
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
|
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
|
||||||
|
code_to_optimize, contextual_dunder_methods = extract_code([self.function_to_optimize])
|
||||||
|
if code_to_optimize is None:
|
||||||
|
return Failure("Could not find function to optimize.")
|
||||||
|
(helper_code, helper_functions, helper_dunder_methods) = get_constrained_function_context_and_helper_functions(
|
||||||
|
self.function_to_optimize, self.project_root, code_to_optimize
|
||||||
|
)
|
||||||
|
if self.function_to_optimize.parents:
|
||||||
|
function_class = self.function_to_optimize.parents[0].name
|
||||||
|
same_class_helper_methods = [
|
||||||
|
df
|
||||||
|
for df in helper_functions
|
||||||
|
if df.qualified_name.count(".") > 0 and df.qualified_name.split(".")[0] == function_class
|
||||||
|
]
|
||||||
|
optimizable_methods = [
|
||||||
|
FunctionToOptimize(
|
||||||
|
df.qualified_name.split(".")[-1],
|
||||||
|
df.file_path,
|
||||||
|
[FunctionParent(df.qualified_name.split(".")[0], "ClassDef")],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
for df in same_class_helper_methods
|
||||||
|
] + [self.function_to_optimize]
|
||||||
|
dedup_optimizable_methods = []
|
||||||
|
added_methods = set()
|
||||||
|
for method in reversed(optimizable_methods):
|
||||||
|
if f"{method.file_path}.{method.qualified_name}" not in added_methods:
|
||||||
|
dedup_optimizable_methods.append(method)
|
||||||
|
added_methods.add(f"{method.file_path}.{method.qualified_name}")
|
||||||
|
if len(dedup_optimizable_methods) > 1:
|
||||||
|
code_to_optimize, contextual_dunder_methods = extract_code(list(reversed(dedup_optimizable_methods)))
|
||||||
|
if code_to_optimize is None:
|
||||||
|
return Failure("Could not find function to optimize.")
|
||||||
|
code_to_optimize_with_helpers = helper_code + "\n" + code_to_optimize
|
||||||
|
|
||||||
|
code_to_optimize_with_helpers_and_imports = add_needed_imports_from_module(
|
||||||
|
self.function_to_optimize_source_code,
|
||||||
|
code_to_optimize_with_helpers,
|
||||||
|
self.function_to_optimize.file_path,
|
||||||
|
self.function_to_optimize.file_path,
|
||||||
|
self.project_root,
|
||||||
|
helper_functions,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
new_code_ctx = code_context_extractor.get_code_optimization_context(
|
new_code_ctx = code_context_extractor.get_code_optimization_context(
|
||||||
self.function_to_optimize, self.project_root
|
self.function_to_optimize, self.project_root
|
||||||
|
|
@ -584,7 +628,7 @@ class FunctionOptimizer:
|
||||||
|
|
||||||
return Success(
|
return Success(
|
||||||
CodeOptimizationContext(
|
CodeOptimizationContext(
|
||||||
testgen_context_code=new_code_ctx.testgen_context_code,
|
code_to_optimize_with_helpers=code_to_optimize_with_helpers_and_imports,
|
||||||
read_writable_code=new_code_ctx.read_writable_code,
|
read_writable_code=new_code_ctx.read_writable_code,
|
||||||
read_only_context_code=new_code_ctx.read_only_context_code,
|
read_only_context_code=new_code_ctx.read_only_context_code,
|
||||||
helper_functions=new_code_ctx.helper_functions, # only functions that are read writable
|
helper_functions=new_code_ctx.helper_functions, # only functions that are read writable
|
||||||
|
|
@ -686,7 +730,7 @@ class FunctionOptimizer:
|
||||||
|
|
||||||
def generate_tests_and_optimizations(
|
def generate_tests_and_optimizations(
|
||||||
self,
|
self,
|
||||||
testgen_context_code: str,
|
code_to_optimize_with_helpers: str,
|
||||||
read_writable_code: str,
|
read_writable_code: str,
|
||||||
read_only_context_code: str,
|
read_only_context_code: str,
|
||||||
helper_functions: list[FunctionSource],
|
helper_functions: list[FunctionSource],
|
||||||
|
|
@ -701,7 +745,7 @@ class FunctionOptimizer:
|
||||||
# Submit the test generation task as future
|
# Submit the test generation task as future
|
||||||
future_tests = self.generate_and_instrument_tests(
|
future_tests = self.generate_and_instrument_tests(
|
||||||
executor,
|
executor,
|
||||||
testgen_context_code,
|
code_to_optimize_with_helpers,
|
||||||
[definition.fully_qualified_name for definition in helper_functions],
|
[definition.fully_qualified_name for definition in helper_functions],
|
||||||
generated_test_paths,
|
generated_test_paths,
|
||||||
generated_perf_test_paths,
|
generated_perf_test_paths,
|
||||||
|
|
@ -790,7 +834,6 @@ class FunctionOptimizer:
|
||||||
original_helper_code: dict[Path, str],
|
original_helper_code: dict[Path, str],
|
||||||
file_path_to_helper_classes: dict[Path, set[str]],
|
file_path_to_helper_classes: dict[Path, set[str]],
|
||||||
) -> Result[tuple[OriginalCodeBaseline, list[str]], str]:
|
) -> Result[tuple[OriginalCodeBaseline, list[str]], str]:
|
||||||
line_profile_results = {"timings": {}, "unit": 0, "str_out": ""}
|
|
||||||
# For the original function - run the tests and get the runtime, plus coverage
|
# For the original function - run the tests and get the runtime, plus coverage
|
||||||
with progress_bar(f"Establishing original code baseline for {self.function_to_optimize.function_name}"):
|
with progress_bar(f"Establishing original code baseline for {self.function_to_optimize.function_name}"):
|
||||||
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
|
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
|
||||||
|
|
@ -831,31 +874,11 @@ class FunctionOptimizer:
|
||||||
)
|
)
|
||||||
console.rule()
|
console.rule()
|
||||||
return Failure("Failed to establish a baseline for the original code - bevhavioral tests failed.")
|
return Failure("Failed to establish a baseline for the original code - bevhavioral tests failed.")
|
||||||
if not coverage_critic(coverage_results, self.args.test_framework):
|
if not coverage_critic(
|
||||||
|
coverage_results, self.args.test_framework
|
||||||
|
):
|
||||||
return Failure("The threshold for test coverage was not met.")
|
return Failure("The threshold for test coverage was not met.")
|
||||||
if test_framework == "pytest":
|
if test_framework == "pytest":
|
||||||
try:
|
|
||||||
line_profiler_output_file = add_decorator_imports(self.function_to_optimize, code_context)
|
|
||||||
line_profile_results, _ = self.run_and_parse_tests(
|
|
||||||
testing_type=TestingMode.LINE_PROFILE,
|
|
||||||
test_env=test_env,
|
|
||||||
test_files=self.test_files,
|
|
||||||
optimization_iteration=0,
|
|
||||||
testing_time=TOTAL_LOOPING_TIME,
|
|
||||||
enable_coverage=False,
|
|
||||||
code_context=code_context,
|
|
||||||
line_profiler_output_file=line_profiler_output_file,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
# Remove codeflash capture
|
|
||||||
self.write_code_and_helpers(
|
|
||||||
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
|
|
||||||
)
|
|
||||||
if line_profile_results["str_out"] == "":
|
|
||||||
logger.warning(
|
|
||||||
f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
|
|
||||||
)
|
|
||||||
console.rule()
|
|
||||||
benchmarking_results, _ = self.run_and_parse_tests(
|
benchmarking_results, _ = self.run_and_parse_tests(
|
||||||
testing_type=TestingMode.PERFORMANCE,
|
testing_type=TestingMode.PERFORMANCE,
|
||||||
test_env=test_env,
|
test_env=test_env,
|
||||||
|
|
@ -894,6 +917,7 @@ class FunctionOptimizer:
|
||||||
)
|
)
|
||||||
console.rule()
|
console.rule()
|
||||||
|
|
||||||
|
|
||||||
total_timing = benchmarking_results.total_passed_runtime() # caution: doesn't handle the loop index
|
total_timing = benchmarking_results.total_passed_runtime() # caution: doesn't handle the loop index
|
||||||
functions_to_remove = [
|
functions_to_remove = [
|
||||||
result.id.test_function_name
|
result.id.test_function_name
|
||||||
|
|
@ -927,7 +951,6 @@ class FunctionOptimizer:
|
||||||
benchmarking_test_results=benchmarking_results,
|
benchmarking_test_results=benchmarking_results,
|
||||||
runtime=total_timing,
|
runtime=total_timing,
|
||||||
coverage_results=coverage_results,
|
coverage_results=coverage_results,
|
||||||
line_profile_results=line_profile_results,
|
|
||||||
),
|
),
|
||||||
functions_to_remove,
|
functions_to_remove,
|
||||||
)
|
)
|
||||||
|
|
@ -1063,77 +1086,59 @@ class FunctionOptimizer:
|
||||||
pytest_max_loops: int = 100_000,
|
pytest_max_loops: int = 100_000,
|
||||||
code_context: CodeOptimizationContext | None = None,
|
code_context: CodeOptimizationContext | None = None,
|
||||||
unittest_loop_index: int | None = None,
|
unittest_loop_index: int | None = None,
|
||||||
line_profiler_output_file: Path | None = None,
|
) -> tuple[TestResults, CoverageData | None]:
|
||||||
) -> tuple[TestResults | dict, CoverageData | None]:
|
|
||||||
coverage_database_file = None
|
coverage_database_file = None
|
||||||
coverage_config_file = None
|
|
||||||
try:
|
try:
|
||||||
if testing_type == TestingMode.BEHAVIOR:
|
if testing_type == TestingMode.BEHAVIOR:
|
||||||
result_file_path, run_result, coverage_database_file, coverage_config_file = run_behavioral_tests(
|
result_file_path, run_result, coverage_database_file = run_behavioral_tests(
|
||||||
test_files,
|
test_files,
|
||||||
test_framework=self.test_cfg.test_framework,
|
test_framework=self.test_cfg.test_framework,
|
||||||
cwd=self.project_root,
|
cwd=self.project_root,
|
||||||
test_env=test_env,
|
test_env=test_env,
|
||||||
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
|
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
|
||||||
|
pytest_cmd=self.test_cfg.pytest_cmd,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
enable_coverage=enable_coverage,
|
enable_coverage=enable_coverage,
|
||||||
)
|
)
|
||||||
elif testing_type == TestingMode.LINE_PROFILE:
|
|
||||||
result_file_path, run_result = run_line_profile_tests(
|
|
||||||
test_files,
|
|
||||||
cwd=self.project_root,
|
|
||||||
test_env=test_env,
|
|
||||||
pytest_cmd=self.test_cfg.pytest_cmd,
|
|
||||||
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
|
|
||||||
pytest_target_runtime_seconds=testing_time,
|
|
||||||
pytest_min_loops=pytest_min_loops,
|
|
||||||
pytest_max_loops=pytest_min_loops,
|
|
||||||
test_framework=self.test_cfg.test_framework,
|
|
||||||
line_profiler_output_file=line_profiler_output_file,
|
|
||||||
)
|
|
||||||
elif testing_type == TestingMode.PERFORMANCE:
|
elif testing_type == TestingMode.PERFORMANCE:
|
||||||
result_file_path, run_result = run_benchmarking_tests(
|
result_file_path, run_result = run_benchmarking_tests(
|
||||||
test_files,
|
test_files,
|
||||||
cwd=self.project_root,
|
cwd=self.project_root,
|
||||||
test_env=test_env,
|
test_env=test_env,
|
||||||
pytest_cmd=self.test_cfg.pytest_cmd,
|
|
||||||
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
|
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
|
||||||
|
pytest_cmd=self.test_cfg.pytest_cmd,
|
||||||
pytest_target_runtime_seconds=testing_time,
|
pytest_target_runtime_seconds=testing_time,
|
||||||
pytest_min_loops=pytest_min_loops,
|
pytest_min_loops=pytest_min_loops,
|
||||||
pytest_max_loops=pytest_max_loops,
|
pytest_max_loops=pytest_max_loops,
|
||||||
test_framework=self.test_cfg.test_framework,
|
test_framework=self.test_cfg.test_framework,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
msg = f"Unexpected testing type: {testing_type}"
|
raise ValueError(f"Unexpected testing type: {testing_type}")
|
||||||
raise ValueError(msg)
|
|
||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
f"Error running tests in {', '.join(str(f) for f in test_files.test_files)}.\nTimeout Error"
|
f'Error running tests in {", ".join(str(f) for f in test_files.test_files)}.\nTimeout Error'
|
||||||
)
|
)
|
||||||
return TestResults(), None
|
return TestResults(), None
|
||||||
if run_result.returncode != 0 and testing_type == TestingMode.BEHAVIOR:
|
if run_result.returncode != 0 and testing_type == TestingMode.BEHAVIOR:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Nonzero return code {run_result.returncode} when running tests in "
|
f'Nonzero return code {run_result.returncode} when running tests in '
|
||||||
f"{', '.join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n"
|
f'{", ".join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n'
|
||||||
f"stdout: {run_result.stdout}\n"
|
f"stdout: {run_result.stdout}\n"
|
||||||
f"stderr: {run_result.stderr}\n"
|
f"stderr: {run_result.stderr}\n"
|
||||||
)
|
)
|
||||||
if testing_type in [TestingMode.BEHAVIOR, TestingMode.PERFORMANCE]:
|
# print(test_files)
|
||||||
results, coverage_results = parse_test_results(
|
results, coverage_results = parse_test_results(
|
||||||
test_xml_path=result_file_path,
|
test_xml_path=result_file_path,
|
||||||
test_files=test_files,
|
test_files=test_files,
|
||||||
test_config=self.test_cfg,
|
test_config=self.test_cfg,
|
||||||
optimization_iteration=optimization_iteration,
|
optimization_iteration=optimization_iteration,
|
||||||
run_result=run_result,
|
run_result=run_result,
|
||||||
unittest_loop_index=unittest_loop_index,
|
unittest_loop_index=unittest_loop_index,
|
||||||
function_name=self.function_to_optimize.function_name,
|
function_name=self.function_to_optimize.function_name,
|
||||||
source_file=self.function_to_optimize.file_path,
|
source_file=self.function_to_optimize.file_path,
|
||||||
code_context=code_context,
|
code_context=code_context,
|
||||||
coverage_database_file=coverage_database_file,
|
coverage_database_file=coverage_database_file,
|
||||||
coverage_config_file=coverage_config_file,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
|
|
||||||
return results, coverage_results
|
return results, coverage_results
|
||||||
|
|
||||||
def generate_and_instrument_tests(
|
def generate_and_instrument_tests(
|
||||||
|
|
@ -1163,3 +1168,4 @@ class FunctionOptimizer:
|
||||||
zip(generated_test_paths, generated_perf_test_paths)
|
zip(generated_test_paths, generated_perf_test_paths)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,8 @@ from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
|
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
|
||||||
from codeflash.cli_cmds.console import console, logger, progress_bar
|
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
|
||||||
|
from codeflash.cli_cmds.console import console, logger
|
||||||
from codeflash.code_utils import env_utils
|
from codeflash.code_utils import env_utils
|
||||||
from codeflash.code_utils.code_replacer import normalize_code, normalize_node
|
from codeflash.code_utils.code_replacer import normalize_code, normalize_node
|
||||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||||
|
|
@ -16,10 +17,12 @@ from codeflash.code_utils.static_analysis import analyze_imported_modules, get_f
|
||||||
from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
||||||
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
|
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
|
||||||
from codeflash.either import is_successful
|
from codeflash.either import is_successful
|
||||||
from codeflash.models.models import TestType, ValidCode
|
from codeflash.models.models import TestFiles, ValidCode
|
||||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||||
from codeflash.telemetry.posthog_cf import ph
|
from codeflash.telemetry.posthog_cf import ph
|
||||||
|
from codeflash.verification.test_results import TestType
|
||||||
from codeflash.verification.verification_utils import TestConfig
|
from codeflash.verification.verification_utils import TestConfig
|
||||||
|
from codeflash.benchmarking.get_trace_info import get_function_benchmark_timings, get_benchmark_timings
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
|
@ -50,6 +53,8 @@ class Optimizer:
|
||||||
function_to_optimize_ast: ast.FunctionDef | None = None,
|
function_to_optimize_ast: ast.FunctionDef | None = None,
|
||||||
function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
|
function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
|
||||||
function_to_optimize_source_code: str | None = "",
|
function_to_optimize_source_code: str | None = "",
|
||||||
|
function_benchmark_timings: dict[str, dict[str, float]] | None = None,
|
||||||
|
total_benchmark_timings: dict[str, float] | None = None,
|
||||||
) -> FunctionOptimizer:
|
) -> FunctionOptimizer:
|
||||||
return FunctionOptimizer(
|
return FunctionOptimizer(
|
||||||
function_to_optimize=function_to_optimize,
|
function_to_optimize=function_to_optimize,
|
||||||
|
|
@ -59,6 +64,8 @@ class Optimizer:
|
||||||
function_to_optimize_ast=function_to_optimize_ast,
|
function_to_optimize_ast=function_to_optimize_ast,
|
||||||
aiservice_client=self.aiservice_client,
|
aiservice_client=self.aiservice_client,
|
||||||
args=self.args,
|
args=self.args,
|
||||||
|
function_benchmark_timings=function_benchmark_timings if function_benchmark_timings else None,
|
||||||
|
total_benchmark_timings=total_benchmark_timings if total_benchmark_timings else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
|
|
@ -80,6 +87,23 @@ class Optimizer:
|
||||||
project_root=self.args.project_root,
|
project_root=self.args.project_root,
|
||||||
module_root=self.args.module_root,
|
module_root=self.args.module_root,
|
||||||
)
|
)
|
||||||
|
if self.args.benchmark:
|
||||||
|
all_functions_to_optimize = [
|
||||||
|
function
|
||||||
|
for functions_list in file_to_funcs_to_optimize.values()
|
||||||
|
for function in functions_list
|
||||||
|
]
|
||||||
|
logger.info(f"Tracing existing benchmarks for {len(all_functions_to_optimize)} functions")
|
||||||
|
trace_benchmarks_pytest(self.args.benchmarks_root, self.args.project_root, [fto.qualified_name_with_file_name for fto in all_functions_to_optimize])
|
||||||
|
logger.info("Finished tracing existing benchmarks")
|
||||||
|
trace_dir = Path(self.args.benchmarks_root) / ".codeflash_trace"
|
||||||
|
function_benchmark_timings = get_function_benchmark_timings(trace_dir, all_functions_to_optimize)
|
||||||
|
print(function_benchmark_timings)
|
||||||
|
total_benchmark_timings = get_benchmark_timings(trace_dir)
|
||||||
|
print("Total benchmark timings:")
|
||||||
|
print(total_benchmark_timings)
|
||||||
|
# for function in fully_qualified_function_names:
|
||||||
|
|
||||||
|
|
||||||
optimizations_found: int = 0
|
optimizations_found: int = 0
|
||||||
function_iterator_count: int = 0
|
function_iterator_count: int = 0
|
||||||
|
|
@ -93,6 +117,8 @@ class Optimizer:
|
||||||
logger.info("No functions found to optimize. Exiting…")
|
logger.info("No functions found to optimize. Exiting…")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
console.rule()
|
||||||
|
logger.info(f"Discovering existing unit tests in {self.test_cfg.tests_root}…")
|
||||||
console.rule()
|
console.rule()
|
||||||
function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg)
|
function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg)
|
||||||
num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])
|
num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])
|
||||||
|
|
@ -136,7 +162,6 @@ class Optimizer:
|
||||||
validated_original_code[analysis.file_path] = ValidCode(
|
validated_original_code[analysis.file_path] = ValidCode(
|
||||||
source_code=callee_original_code, normalized_code=normalized_callee_original_code
|
source_code=callee_original_code, normalized_code=normalized_callee_original_code
|
||||||
)
|
)
|
||||||
|
|
||||||
if has_syntax_error:
|
if has_syntax_error:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -146,7 +171,7 @@ class Optimizer:
|
||||||
f"Optimizing function {function_iterator_count} of {num_optimizable_functions}: "
|
f"Optimizing function {function_iterator_count} of {num_optimizable_functions}: "
|
||||||
f"{function_to_optimize.qualified_name}"
|
f"{function_to_optimize.qualified_name}"
|
||||||
)
|
)
|
||||||
console.rule()
|
|
||||||
if not (
|
if not (
|
||||||
function_to_optimize_ast := get_first_top_level_function_or_method_ast(
|
function_to_optimize_ast := get_first_top_level_function_or_method_ast(
|
||||||
function_to_optimize.function_name, function_to_optimize.parents, original_module_ast
|
function_to_optimize.function_name, function_to_optimize.parents, original_module_ast
|
||||||
|
|
@ -157,12 +182,17 @@ class Optimizer:
|
||||||
f"Skipping optimization."
|
f"Skipping optimization."
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
function_optimizer = self.create_function_optimizer(
|
if self.args.benchmark:
|
||||||
function_to_optimize,
|
|
||||||
function_to_optimize_ast,
|
function_optimizer = self.create_function_optimizer(
|
||||||
function_to_tests,
|
function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code, function_benchmark_timings, total_benchmark_timings
|
||||||
validated_original_code[original_module_path].source_code,
|
)
|
||||||
)
|
else:
|
||||||
|
function_optimizer = self.create_function_optimizer(
|
||||||
|
function_to_optimize, function_to_optimize_ast, function_to_tests,
|
||||||
|
validated_original_code[original_module_path].source_code
|
||||||
|
)
|
||||||
|
|
||||||
best_optimization = function_optimizer.optimize_function()
|
best_optimization = function_optimizer.optimize_function()
|
||||||
if is_successful(best_optimization):
|
if is_successful(best_optimization):
|
||||||
optimizations_found += 1
|
optimizations_found += 1
|
||||||
|
|
@ -191,6 +221,7 @@ class Optimizer:
|
||||||
get_run_tmp_file.tmpdir.cleanup()
|
get_run_tmp_file.tmpdir.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run_with_args(args: Namespace) -> None:
|
def run_with_args(args: Namespace) -> None:
|
||||||
optimizer = Optimizer(args)
|
optimizer = Optimizer(args)
|
||||||
optimizer.run()
|
optimizer.run()
|
||||||
|
|
|
||||||
|
|
@ -18,21 +18,19 @@ import marshal
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import pickle
|
import pickle
|
||||||
|
import re
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
from argparse import ArgumentParser
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from copy import copy
|
||||||
|
from io import StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar
|
from types import FrameType
|
||||||
|
from typing import Any, ClassVar, List
|
||||||
|
|
||||||
import dill
|
import dill
|
||||||
import isort
|
import isort
|
||||||
from rich.align import Align
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.table import Table
|
|
||||||
from rich.text import Text
|
|
||||||
|
|
||||||
from codeflash.cli_cmds.cli import project_root_from_module_root
|
from codeflash.cli_cmds.cli import project_root_from_module_root
|
||||||
from codeflash.cli_cmds.console import console
|
from codeflash.cli_cmds.console import console
|
||||||
|
|
@ -42,34 +40,14 @@ from codeflash.discovery.functions_to_optimize import filter_files_optimized
|
||||||
from codeflash.tracing.replay_test import create_trace_replay_test
|
from codeflash.tracing.replay_test import create_trace_replay_test
|
||||||
from codeflash.tracing.tracing_utils import FunctionModules
|
from codeflash.tracing.tracing_utils import FunctionModules
|
||||||
from codeflash.verification.verification_utils import get_test_file_path
|
from codeflash.verification.verification_utils import get_test_file_path
|
||||||
|
# import warnings
|
||||||
if TYPE_CHECKING:
|
# warnings.filterwarnings("ignore", category=dill.PickleWarning)
|
||||||
from types import FrameType, TracebackType
|
# warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||||
|
|
||||||
|
|
||||||
class FakeCode:
|
|
||||||
def __init__(self, filename: str, line: int, name: str) -> None:
|
|
||||||
self.co_filename = filename
|
|
||||||
self.co_line = line
|
|
||||||
self.co_name = name
|
|
||||||
self.co_firstlineno = 0
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return repr((self.co_filename, self.co_line, self.co_name, None))
|
|
||||||
|
|
||||||
|
|
||||||
class FakeFrame:
|
|
||||||
def __init__(self, code: FakeCode, prior: FakeFrame | None) -> None:
|
|
||||||
self.f_code = code
|
|
||||||
self.f_back = prior
|
|
||||||
self.f_locals: dict = {}
|
|
||||||
|
|
||||||
|
|
||||||
# Debug this file by simply adding print statements. This file is not meant to be debugged by the debugger.
|
# Debug this file by simply adding print statements. This file is not meant to be debugged by the debugger.
|
||||||
class Tracer:
|
class Tracer:
|
||||||
"""Use this class as a 'with' context manager to trace a function call.
|
"""Use this class as a 'with' context manager to trace a function call,
|
||||||
|
input arguments, and profiling info.
|
||||||
Traces function calls, input arguments, and profiling info.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -81,9 +59,7 @@ class Tracer:
|
||||||
max_function_count: int = 256,
|
max_function_count: int = 256,
|
||||||
timeout: int | None = None, # seconds
|
timeout: int | None = None, # seconds
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Use this class to trace function calls.
|
""":param output: The path to the output trace file
|
||||||
|
|
||||||
:param output: The path to the output trace file
|
|
||||||
:param functions: List of functions to trace. If None, trace all functions
|
:param functions: List of functions to trace. If None, trace all functions
|
||||||
:param disable: Disable the tracer if True
|
:param disable: Disable the tracer if True
|
||||||
:param config_file_path: Path to the pyproject.toml file, if None then it will be auto-discovered
|
:param config_file_path: Path to the pyproject.toml file, if None then it will be auto-discovered
|
||||||
|
|
@ -94,9 +70,7 @@ class Tracer:
|
||||||
if functions is None:
|
if functions is None:
|
||||||
functions = []
|
functions = []
|
||||||
if os.environ.get("CODEFLASH_TRACER_DISABLE", "0") == "1":
|
if os.environ.get("CODEFLASH_TRACER_DISABLE", "0") == "1":
|
||||||
console.rule(
|
console.print("Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE")
|
||||||
"Codeflash: Tracer disabled by environment variable CODEFLASH_TRACER_DISABLE", style="bold red"
|
|
||||||
)
|
|
||||||
disable = True
|
disable = True
|
||||||
self.disable = disable
|
self.disable = disable
|
||||||
if self.disable:
|
if self.disable:
|
||||||
|
|
@ -111,7 +85,7 @@ class Tracer:
|
||||||
self.con = None
|
self.con = None
|
||||||
self.output_file = Path(output).resolve()
|
self.output_file = Path(output).resolve()
|
||||||
self.functions = functions
|
self.functions = functions
|
||||||
self.function_modules: list[FunctionModules] = []
|
self.function_modules: List[FunctionModules] = []
|
||||||
self.function_count = defaultdict(int)
|
self.function_count = defaultdict(int)
|
||||||
self.current_file_path = Path(__file__).resolve()
|
self.current_file_path = Path(__file__).resolve()
|
||||||
self.ignored_qualified_functions = {
|
self.ignored_qualified_functions = {
|
||||||
|
|
@ -121,10 +95,10 @@ class Tracer:
|
||||||
self.max_function_count = max_function_count
|
self.max_function_count = max_function_count
|
||||||
self.config, found_config_path = parse_config_file(config_file_path)
|
self.config, found_config_path = parse_config_file(config_file_path)
|
||||||
self.project_root = project_root_from_module_root(Path(self.config["module_root"]), found_config_path)
|
self.project_root = project_root_from_module_root(Path(self.config["module_root"]), found_config_path)
|
||||||
console.rule(f"Project Root: {self.project_root}", style="bold blue")
|
print("project_root", self.project_root)
|
||||||
self.ignored_functions = {"<listcomp>", "<genexpr>", "<dictcomp>", "<setcomp>", "<lambda>", "<module>"}
|
self.ignored_functions = {"<listcomp>", "<genexpr>", "<dictcomp>", "<setcomp>", "<lambda>", "<module>"}
|
||||||
|
|
||||||
self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_") # noqa: SLF001
|
self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(".", "_")
|
||||||
|
|
||||||
assert timeout is None or timeout > 0, "Timeout should be greater than 0"
|
assert timeout is None or timeout > 0, "Timeout should be greater than 0"
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
|
@ -145,44 +119,48 @@ class Tracer:
|
||||||
def __enter__(self) -> None:
|
def __enter__(self) -> None:
|
||||||
if self.disable:
|
if self.disable:
|
||||||
return
|
return
|
||||||
if getattr(Tracer, "used_once", False):
|
|
||||||
console.print(
|
# if getattr(Tracer, "used_once", False):
|
||||||
"Codeflash: Tracer can only be used once per program run. "
|
# console.print(
|
||||||
"Please only enable the Tracer once. Skipping tracing this section."
|
# "Codeflash: Tracer can only be used once per program run. "
|
||||||
)
|
# "Please only enable the Tracer once. Skipping tracing this section."
|
||||||
self.disable = True
|
# )
|
||||||
return
|
# self.disable = True
|
||||||
Tracer.used_once = True
|
# return
|
||||||
|
# Tracer.used_once = True
|
||||||
|
|
||||||
if pathlib.Path(self.output_file).exists():
|
if pathlib.Path(self.output_file).exists():
|
||||||
console.rule("Removing existing trace file", style="bold red")
|
console.print("Codeflash: Removing existing trace file")
|
||||||
console.rule()
|
|
||||||
pathlib.Path(self.output_file).unlink(missing_ok=True)
|
pathlib.Path(self.output_file).unlink(missing_ok=True)
|
||||||
|
|
||||||
self.con = sqlite3.connect(self.output_file, check_same_thread=False)
|
self.con = sqlite3.connect(self.output_file)
|
||||||
cur = self.con.cursor()
|
cur = self.con.cursor()
|
||||||
cur.execute("""PRAGMA synchronous = OFF""")
|
cur.execute("""PRAGMA synchronous = OFF""")
|
||||||
cur.execute("""PRAGMA journal_mode = WAL""")
|
|
||||||
# TODO: Check out if we need to export the function test name as well
|
# TODO: Check out if we need to export the function test name as well
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"CREATE TABLE function_calls(type TEXT, function TEXT, classname TEXT, filename TEXT, "
|
"CREATE TABLE function_calls(type TEXT, function TEXT, classname TEXT, filename TEXT, "
|
||||||
"line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)"
|
"line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)"
|
||||||
)
|
)
|
||||||
console.rule("Codeflash: Traced Program Output Begin", style="bold blue")
|
console.print("Codeflash: Tracing started!")
|
||||||
frame = sys._getframe(0) # Get this frame and simulate a call to it # noqa: SLF001
|
frame = sys._getframe(0) # Get this frame and simulate a call to it
|
||||||
self.dispatch["call"](self, frame, 0)
|
self.dispatch["call"](self, frame, 0)
|
||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
sys.setprofile(self.trace_callback)
|
sys.setprofile(self.trace_callback)
|
||||||
threading.setprofile(self.trace_callback)
|
|
||||||
|
|
||||||
def __exit__(
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||||
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
|
|
||||||
) -> None:
|
|
||||||
if self.disable:
|
if self.disable:
|
||||||
return
|
return
|
||||||
sys.setprofile(None)
|
sys.setprofile(None)
|
||||||
self.con.commit()
|
self.con.commit()
|
||||||
console.rule("Codeflash: Traced Program Output End", style="bold blue")
|
# Check if any functions were actually traced
|
||||||
|
if self.trace_count == 0:
|
||||||
|
self.con.close()
|
||||||
|
# Delete the trace file if no functions were traced
|
||||||
|
if self.output_file.exists():
|
||||||
|
self.output_file.unlink()
|
||||||
|
console.print("Codeflash: No functions were traced. Removing trace database.")
|
||||||
|
return
|
||||||
|
|
||||||
self.create_stats()
|
self.create_stats()
|
||||||
|
|
||||||
cur = self.con.cursor()
|
cur = self.con.cursor()
|
||||||
|
|
@ -226,13 +204,14 @@ class Tracer:
|
||||||
test_framework=self.config["test_framework"],
|
test_framework=self.config["test_framework"],
|
||||||
max_run_count=self.max_function_count,
|
max_run_count=self.max_function_count,
|
||||||
)
|
)
|
||||||
function_path = "_".join(self.functions) if self.functions else self.file_being_called_from
|
# Need a better way to store the replay test
|
||||||
|
# function_path = "_".join(self.functions) if self.functions else self.file_being_called_from
|
||||||
|
function_path = self.file_being_called_from
|
||||||
test_file_path = get_test_file_path(
|
test_file_path = get_test_file_path(
|
||||||
test_dir=Path(self.config["tests_root"]), function_name=function_path, test_type="replay"
|
test_dir=Path(self.config["tests_root"]), function_name=function_path, test_type="replay"
|
||||||
)
|
)
|
||||||
replay_test = isort.code(replay_test)
|
replay_test = isort.code(replay_test)
|
||||||
|
with open(test_file_path, "w", encoding="utf8") as file:
|
||||||
with Path(test_file_path).open("w", encoding="utf8") as file:
|
|
||||||
file.write(replay_test)
|
file.write(replay_test)
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
|
|
@ -242,27 +221,25 @@ class Tracer:
|
||||||
overflow="ignore",
|
overflow="ignore",
|
||||||
)
|
)
|
||||||
|
|
||||||
def tracer_logic(self, frame: FrameType, event: str) -> None:
|
def tracer_logic(self, frame: FrameType, event: str):
|
||||||
if event != "call":
|
if event != "call":
|
||||||
return
|
return
|
||||||
if self.timeout is not None and (time.time() - self.start_time) > self.timeout:
|
if self.timeout is not None:
|
||||||
sys.setprofile(None)
|
if (time.time() - self.start_time) > self.timeout:
|
||||||
threading.setprofile(None)
|
sys.setprofile(None)
|
||||||
console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.")
|
console.print(f"Codeflash: Timeout reached! Stopping tracing at {self.timeout} seconds.")
|
||||||
return
|
return
|
||||||
code = frame.f_code
|
code = frame.f_code
|
||||||
|
|
||||||
file_name = Path(code.co_filename).resolve()
|
file_name = Path(code.co_filename).resolve()
|
||||||
# TODO : It currently doesn't log the last return call from the first function
|
# TODO : It currently doesn't log the last return call from the first function
|
||||||
|
|
||||||
if code.co_name in self.ignored_functions:
|
if code.co_name in self.ignored_functions:
|
||||||
return
|
return
|
||||||
if not file_name.is_relative_to(self.project_root):
|
|
||||||
return
|
|
||||||
if not file_name.exists():
|
if not file_name.exists():
|
||||||
return
|
return
|
||||||
if self.functions and code.co_name not in self.functions:
|
# if self.functions:
|
||||||
return
|
# if code.co_name not in self.functions:
|
||||||
|
# return
|
||||||
class_name = None
|
class_name = None
|
||||||
arguments = frame.f_locals
|
arguments = frame.f_locals
|
||||||
try:
|
try:
|
||||||
|
|
@ -274,12 +251,16 @@ class Tracer:
|
||||||
class_name = arguments["self"].__class__.__name__
|
class_name = arguments["self"].__class__.__name__
|
||||||
elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
|
elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
|
||||||
class_name = arguments["cls"].__name__
|
class_name = arguments["cls"].__name__
|
||||||
except: # noqa: E722
|
except:
|
||||||
# someone can override the getattr method and raise an exception. I'm looking at you wrapt
|
# someone can override the getattr method and raise an exception. I'm looking at you wrapt
|
||||||
return
|
return
|
||||||
|
|
||||||
function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}"
|
function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}"
|
||||||
if function_qualified_name in self.ignored_qualified_functions:
|
if function_qualified_name in self.ignored_qualified_functions:
|
||||||
return
|
return
|
||||||
|
if self.functions and function_qualified_name not in self.functions:
|
||||||
|
return
|
||||||
|
|
||||||
if function_qualified_name not in self.function_count:
|
if function_qualified_name not in self.function_count:
|
||||||
# seeing this function for the first time
|
# seeing this function for the first time
|
||||||
self.function_count[function_qualified_name] = 0
|
self.function_count[function_qualified_name] = 0
|
||||||
|
|
@ -354,14 +335,17 @@ class Tracer:
|
||||||
self.next_insert = 1000
|
self.next_insert = 1000
|
||||||
self.con.commit()
|
self.con.commit()
|
||||||
|
|
||||||
def trace_callback(self, frame: FrameType, event: str, arg: str | None) -> None:
|
def trace_callback(self, frame: FrameType, event: str, arg: Any) -> None:
|
||||||
# profiler section
|
# profiler section
|
||||||
timer = self.timer
|
timer = self.timer
|
||||||
t = timer() - self.t - self.bias
|
t = timer() - self.t - self.bias
|
||||||
if event == "c_call":
|
if event == "c_call":
|
||||||
self.c_func_name = arg.__name__
|
self.c_func_name = arg.__name__
|
||||||
|
|
||||||
prof_success = bool(self.dispatch[event](self, frame, t))
|
if self.dispatch[event](self, frame, t):
|
||||||
|
prof_success = True
|
||||||
|
else:
|
||||||
|
prof_success = False
|
||||||
# tracer section
|
# tracer section
|
||||||
self.tracer_logic(frame, event)
|
self.tracer_logic(frame, event)
|
||||||
# measure the time as the last thing before return
|
# measure the time as the last thing before return
|
||||||
|
|
@ -370,60 +354,45 @@ class Tracer:
|
||||||
else:
|
else:
|
||||||
self.t = timer() - t # put back unrecorded delta
|
self.t = timer() - t # put back unrecorded delta
|
||||||
|
|
||||||
def trace_dispatch_call(self, frame: FrameType, t: int) -> int:
|
def trace_dispatch_call(self, frame, t):
|
||||||
"""Handle call events in the profiler."""
|
if self.cur and frame.f_back is not self.cur[-2]:
|
||||||
|
rpt, rit, ret, rfn, rframe, rcur = self.cur
|
||||||
|
if not isinstance(rframe, Tracer.fake_frame):
|
||||||
|
assert rframe.f_back is frame.f_back, ("Bad call", rfn, rframe, rframe.f_back, frame, frame.f_back)
|
||||||
|
self.trace_dispatch_return(rframe, 0)
|
||||||
|
assert self.cur is None or frame.f_back is self.cur[-2], ("Bad call", self.cur[-3])
|
||||||
|
fcode = frame.f_code
|
||||||
|
arguments = frame.f_locals
|
||||||
|
class_name = None
|
||||||
try:
|
try:
|
||||||
# In multi-threaded contexts, we need to be more careful about frame comparisons
|
if (
|
||||||
if self.cur and frame.f_back is not self.cur[-2]:
|
"self" in arguments
|
||||||
# This happens when we're in a different thread
|
and hasattr(arguments["self"], "__class__")
|
||||||
rpt, rit, ret, rfn, rframe, rcur = self.cur
|
and hasattr(arguments["self"].__class__, "__name__")
|
||||||
|
):
|
||||||
|
class_name = arguments["self"].__class__.__name__
|
||||||
|
elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
|
||||||
|
class_name = arguments["cls"].__name__
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
fn = (fcode.co_filename, fcode.co_firstlineno, fcode.co_name, class_name)
|
||||||
|
self.cur = (t, 0, 0, fn, frame, self.cur)
|
||||||
|
timings = self.timings
|
||||||
|
if fn in timings:
|
||||||
|
cc, ns, tt, ct, callers = timings[fn]
|
||||||
|
timings[fn] = cc, ns + 1, tt, ct, callers
|
||||||
|
else:
|
||||||
|
timings[fn] = 0, 0, 0, 0, {}
|
||||||
|
return 1
|
||||||
|
|
||||||
# Only attempt to handle the frame mismatch if we have a valid rframe
|
def trace_dispatch_exception(self, frame, t):
|
||||||
if (
|
|
||||||
not isinstance(rframe, FakeFrame)
|
|
||||||
and hasattr(rframe, "f_back")
|
|
||||||
and hasattr(frame, "f_back")
|
|
||||||
and rframe.f_back is frame.f_back
|
|
||||||
):
|
|
||||||
self.trace_dispatch_return(rframe, 0)
|
|
||||||
|
|
||||||
# Get function information
|
|
||||||
fcode = frame.f_code
|
|
||||||
arguments = frame.f_locals
|
|
||||||
class_name = None
|
|
||||||
try:
|
|
||||||
if (
|
|
||||||
"self" in arguments
|
|
||||||
and hasattr(arguments["self"], "__class__")
|
|
||||||
and hasattr(arguments["self"].__class__, "__name__")
|
|
||||||
):
|
|
||||||
class_name = arguments["self"].__class__.__name__
|
|
||||||
elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
|
|
||||||
class_name = arguments["cls"].__name__
|
|
||||||
except Exception: # noqa: BLE001, S110
|
|
||||||
pass
|
|
||||||
|
|
||||||
fn = (fcode.co_filename, fcode.co_firstlineno, fcode.co_name, class_name)
|
|
||||||
self.cur = (t, 0, 0, fn, frame, self.cur)
|
|
||||||
timings = self.timings
|
|
||||||
if fn in timings:
|
|
||||||
cc, ns, tt, ct, callers = timings[fn]
|
|
||||||
timings[fn] = cc, ns + 1, tt, ct, callers
|
|
||||||
else:
|
|
||||||
timings[fn] = 0, 0, 0, 0, {}
|
|
||||||
return 1 # noqa: TRY300
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
# Handle any errors gracefully
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def trace_dispatch_exception(self, frame: FrameType, t: int) -> int:
|
|
||||||
rpt, rit, ret, rfn, rframe, rcur = self.cur
|
rpt, rit, ret, rfn, rframe, rcur = self.cur
|
||||||
if (rframe is not frame) and rcur:
|
if (rframe is not frame) and rcur:
|
||||||
return self.trace_dispatch_return(rframe, t)
|
return self.trace_dispatch_return(rframe, t)
|
||||||
self.cur = rpt, rit + t, ret, rfn, rframe, rcur
|
self.cur = rpt, rit + t, ret, rfn, rframe, rcur
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
def trace_dispatch_c_call(self, frame: FrameType, t: int) -> int:
|
def trace_dispatch_c_call(self, frame, t):
|
||||||
fn = ("", 0, self.c_func_name, None)
|
fn = ("", 0, self.c_func_name, None)
|
||||||
self.cur = (t, 0, 0, fn, frame, self.cur)
|
self.cur = (t, 0, 0, fn, frame, self.cur)
|
||||||
timings = self.timings
|
timings = self.timings
|
||||||
|
|
@ -434,27 +403,15 @@ class Tracer:
|
||||||
timings[fn] = 0, 0, 0, 0, {}
|
timings[fn] = 0, 0, 0, 0, {}
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
def trace_dispatch_return(self, frame: FrameType, t: int) -> int:
|
def trace_dispatch_return(self, frame, t):
|
||||||
if not self.cur or not self.cur[-2]:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# In multi-threaded environments, frames can get mismatched
|
|
||||||
if frame is not self.cur[-2]:
|
if frame is not self.cur[-2]:
|
||||||
# Don't assert in threaded environments - frames can legitimately differ
|
assert frame is self.cur[-2].f_back, ("Bad return", self.cur[-3])
|
||||||
if hasattr(frame, "f_back") and hasattr(self.cur[-2], "f_back") and frame is self.cur[-2].f_back:
|
self.trace_dispatch_return(self.cur[-2], 0)
|
||||||
self.trace_dispatch_return(self.cur[-2], 0)
|
|
||||||
else:
|
|
||||||
# We're in a different thread or context, can't continue with this frame
|
|
||||||
return 0
|
|
||||||
# Prefix "r" means part of the Returning or exiting frame.
|
# Prefix "r" means part of the Returning or exiting frame.
|
||||||
# Prefix "p" means part of the Previous or Parent or older frame.
|
# Prefix "p" means part of the Previous or Parent or older frame.
|
||||||
|
|
||||||
rpt, rit, ret, rfn, frame, rcur = self.cur
|
rpt, rit, ret, rfn, frame, rcur = self.cur
|
||||||
|
|
||||||
# Guard against invalid rcur (w threading)
|
|
||||||
if not rcur:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
rit = rit + t
|
rit = rit + t
|
||||||
frame_total = rit + ret
|
frame_total = rit + ret
|
||||||
|
|
||||||
|
|
@ -462,9 +419,6 @@ class Tracer:
|
||||||
self.cur = ppt, pit + rpt, pet + frame_total, pfn, pframe, pcur
|
self.cur = ppt, pit + rpt, pet + frame_total, pfn, pframe, pcur
|
||||||
|
|
||||||
timings = self.timings
|
timings = self.timings
|
||||||
if rfn not in timings:
|
|
||||||
# w threading, rfn can be missing
|
|
||||||
timings[rfn] = 0, 0, 0, 0, {}
|
|
||||||
cc, ns, tt, ct, callers = timings[rfn]
|
cc, ns, tt, ct, callers = timings[rfn]
|
||||||
if not ns:
|
if not ns:
|
||||||
# This is the only occurrence of the function on the stack.
|
# This is the only occurrence of the function on the stack.
|
||||||
|
|
@ -486,7 +440,7 @@ class Tracer:
|
||||||
|
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
dispatch: ClassVar[dict[str, Callable[[Tracer, FrameType, int], int]]] = {
|
dispatch: ClassVar[dict[str, callable]] = {
|
||||||
"call": trace_dispatch_call,
|
"call": trace_dispatch_call,
|
||||||
"exception": trace_dispatch_exception,
|
"exception": trace_dispatch_exception,
|
||||||
"return": trace_dispatch_return,
|
"return": trace_dispatch_return,
|
||||||
|
|
@ -495,13 +449,32 @@ class Tracer:
|
||||||
"c_return": trace_dispatch_return,
|
"c_return": trace_dispatch_return,
|
||||||
}
|
}
|
||||||
|
|
||||||
def simulate_call(self, name: str) -> None:
|
class fake_code:
|
||||||
code = FakeCode("profiler", 0, name)
|
def __init__(self, filename, line, name):
|
||||||
pframe = self.cur[-2] if self.cur else None
|
self.co_filename = filename
|
||||||
frame = FakeFrame(code, pframe)
|
self.co_line = line
|
||||||
|
self.co_name = name
|
||||||
|
self.co_firstlineno = 0
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return repr((self.co_filename, self.co_line, self.co_name, None))
|
||||||
|
|
||||||
|
class fake_frame:
|
||||||
|
def __init__(self, code, prior):
|
||||||
|
self.f_code = code
|
||||||
|
self.f_back = prior
|
||||||
|
self.f_locals = {}
|
||||||
|
|
||||||
|
def simulate_call(self, name):
|
||||||
|
code = self.fake_code("profiler", 0, name)
|
||||||
|
if self.cur:
|
||||||
|
pframe = self.cur[-2]
|
||||||
|
else:
|
||||||
|
pframe = None
|
||||||
|
frame = self.fake_frame(code, pframe)
|
||||||
self.dispatch["call"](self, frame, 0)
|
self.dispatch["call"](self, frame, 0)
|
||||||
|
|
||||||
def simulate_cmd_complete(self) -> None:
|
def simulate_cmd_complete(self):
|
||||||
get_time = self.timer
|
get_time = self.timer
|
||||||
t = get_time() - self.t
|
t = get_time() - self.t
|
||||||
while self.cur[-1]:
|
while self.cur[-1]:
|
||||||
|
|
@ -511,174 +484,60 @@ class Tracer:
|
||||||
t = 0
|
t = 0
|
||||||
self.t = get_time() - t
|
self.t = get_time() - t
|
||||||
|
|
||||||
def print_stats(self, sort: str | int | tuple = -1) -> None:
|
def print_stats(self, sort=-1):
|
||||||
if not self.stats:
|
import pstats
|
||||||
console.print("Codeflash: No stats available to print")
|
|
||||||
self.total_tt = 0
|
|
||||||
return
|
|
||||||
|
|
||||||
if not isinstance(sort, tuple):
|
if not isinstance(sort, tuple):
|
||||||
sort = (sort,)
|
sort = (sort,)
|
||||||
|
# The following code customizes the default printing behavior to
|
||||||
|
# print in milliseconds.
|
||||||
|
s = StringIO()
|
||||||
|
stats_obj = pstats.Stats(copy(self), stream=s)
|
||||||
|
stats_obj.strip_dirs().sort_stats(*sort).print_stats(100)
|
||||||
|
self.total_tt = stats_obj.total_tt
|
||||||
|
console.print("total_tt", self.total_tt)
|
||||||
|
raw_stats = s.getvalue()
|
||||||
|
m = re.search(r"function calls?.*in (\d+)\.\d+ (seconds?)", raw_stats)
|
||||||
|
total_time = None
|
||||||
|
if m:
|
||||||
|
total_time = int(m.group(1))
|
||||||
|
if total_time is None:
|
||||||
|
console.print("Failed to get total time from stats")
|
||||||
|
total_time_ms = total_time / 1e6
|
||||||
|
raw_stats = re.sub(
|
||||||
|
r"(function calls?.*)in (\d+)\.\d+ (seconds?)", rf"\1 in {total_time_ms:.3f} milliseconds", raw_stats
|
||||||
|
)
|
||||||
|
match_pattern = r"^ *[\d\/]+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +"
|
||||||
|
m = re.findall(match_pattern, raw_stats, re.MULTILINE)
|
||||||
|
ms_times = []
|
||||||
|
for tottime, percall, cumtime, percall_cum in m:
|
||||||
|
tottime_ms = int(tottime) / 1e6
|
||||||
|
percall_ms = int(percall) / 1e6
|
||||||
|
cumtime_ms = int(cumtime) / 1e6
|
||||||
|
percall_cum_ms = int(percall_cum) / 1e6
|
||||||
|
ms_times.append([tottime_ms, percall_ms, cumtime_ms, percall_cum_ms])
|
||||||
|
split_stats = raw_stats.split("\n")
|
||||||
|
new_stats = []
|
||||||
|
|
||||||
# First, convert stats to make them pstats-compatible
|
replace_pattern = r"^( *[\d\/]+) +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(\d+)\.\d+ +(.*)"
|
||||||
try:
|
times_index = 0
|
||||||
# Initialize empty collections for pstats
|
for line in split_stats:
|
||||||
self.files = []
|
if times_index >= len(ms_times):
|
||||||
self.top_level = []
|
replaced = line
|
||||||
|
else:
|
||||||
# Create entirely new dictionaries instead of modifying existing ones
|
replaced, n = re.subn(
|
||||||
new_stats = {}
|
replace_pattern,
|
||||||
new_timings = {}
|
rf"\g<1>{ms_times[times_index][0]:8.3f} {ms_times[times_index][1]:8.3f} {ms_times[times_index][2]:8.3f} {ms_times[times_index][3]:8.3f} \g<6>",
|
||||||
|
line,
|
||||||
# Convert stats dictionary
|
count=1,
|
||||||
stats_items = list(self.stats.items())
|
|
||||||
for func, stats_data in stats_items:
|
|
||||||
try:
|
|
||||||
# Make sure we have 5 elements in stats_data
|
|
||||||
if len(stats_data) != 5:
|
|
||||||
console.print(f"Skipping malformed stats data for {func}: {stats_data}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
cc, nc, tt, ct, callers = stats_data
|
|
||||||
|
|
||||||
if len(func) == 4:
|
|
||||||
file_name, line_num, func_name, class_name = func
|
|
||||||
new_func_name = f"{class_name}.{func_name}" if class_name else func_name
|
|
||||||
new_func = (file_name, line_num, new_func_name)
|
|
||||||
else:
|
|
||||||
new_func = func # Keep as is if already in correct format
|
|
||||||
|
|
||||||
new_callers = {}
|
|
||||||
callers_items = list(callers.items())
|
|
||||||
for caller_func, count in callers_items:
|
|
||||||
if isinstance(caller_func, tuple):
|
|
||||||
if len(caller_func) == 4:
|
|
||||||
caller_file, caller_line, caller_name, caller_class = caller_func
|
|
||||||
caller_new_name = f"{caller_class}.{caller_name}" if caller_class else caller_name
|
|
||||||
new_caller_func = (caller_file, caller_line, caller_new_name)
|
|
||||||
else:
|
|
||||||
new_caller_func = caller_func
|
|
||||||
else:
|
|
||||||
console.print(f"Unexpected caller format: {caller_func}")
|
|
||||||
new_caller_func = str(caller_func)
|
|
||||||
|
|
||||||
new_callers[new_caller_func] = count
|
|
||||||
|
|
||||||
# Store with new format
|
|
||||||
new_stats[new_func] = (cc, nc, tt, ct, new_callers)
|
|
||||||
except Exception as e: # noqa: BLE001
|
|
||||||
console.print(f"Error converting stats for {func}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
timings_items = list(self.timings.items())
|
|
||||||
for func, timing_data in timings_items:
|
|
||||||
try:
|
|
||||||
if len(timing_data) != 5:
|
|
||||||
console.print(f"Skipping malformed timing data for {func}: {timing_data}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
cc, ns, tt, ct, callers = timing_data
|
|
||||||
|
|
||||||
if len(func) == 4:
|
|
||||||
file_name, line_num, func_name, class_name = func
|
|
||||||
new_func_name = f"{class_name}.{func_name}" if class_name else func_name
|
|
||||||
new_func = (file_name, line_num, new_func_name)
|
|
||||||
else:
|
|
||||||
new_func = func
|
|
||||||
|
|
||||||
new_callers = {}
|
|
||||||
callers_items = list(callers.items())
|
|
||||||
for caller_func, count in callers_items:
|
|
||||||
if isinstance(caller_func, tuple):
|
|
||||||
if len(caller_func) == 4:
|
|
||||||
caller_file, caller_line, caller_name, caller_class = caller_func
|
|
||||||
caller_new_name = f"{caller_class}.{caller_name}" if caller_class else caller_name
|
|
||||||
new_caller_func = (caller_file, caller_line, caller_new_name)
|
|
||||||
else:
|
|
||||||
new_caller_func = caller_func
|
|
||||||
else:
|
|
||||||
console.print(f"Unexpected caller format: {caller_func}")
|
|
||||||
new_caller_func = str(caller_func)
|
|
||||||
|
|
||||||
new_callers[new_caller_func] = count
|
|
||||||
|
|
||||||
new_timings[new_func] = (cc, ns, tt, ct, new_callers)
|
|
||||||
except Exception as e: # noqa: BLE001
|
|
||||||
console.print(f"Error converting timings for {func}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.stats = new_stats
|
|
||||||
self.timings = new_timings
|
|
||||||
|
|
||||||
self.total_tt = sum(tt for _, _, tt, _, _ in self.stats.values())
|
|
||||||
|
|
||||||
total_calls = sum(cc for cc, _, _, _, _ in self.stats.values())
|
|
||||||
total_primitive = sum(nc for _, nc, _, _, _ in self.stats.values())
|
|
||||||
|
|
||||||
summary = Text.assemble(
|
|
||||||
f"{total_calls:,} function calls ",
|
|
||||||
("(" + f"{total_primitive:,} primitive calls" + ")", "dim"),
|
|
||||||
f" in {self.total_tt / 1e6:.3f}milliseconds",
|
|
||||||
)
|
|
||||||
|
|
||||||
console.print(Align.center(Panel(summary, border_style="blue", width=80, padding=(0, 2), expand=False)))
|
|
||||||
|
|
||||||
table = Table(
|
|
||||||
show_header=True,
|
|
||||||
header_style="bold magenta",
|
|
||||||
border_style="blue",
|
|
||||||
title="[bold]Function Profile[/bold] (ordered by internal time)",
|
|
||||||
title_style="cyan",
|
|
||||||
caption=f"Showing top 25 of {len(self.stats)} functions",
|
|
||||||
)
|
|
||||||
|
|
||||||
table.add_column("Calls", justify="right", style="green", width=10)
|
|
||||||
table.add_column("Time (ms)", justify="right", style="cyan", width=10)
|
|
||||||
table.add_column("Per Call", justify="right", style="cyan", width=10)
|
|
||||||
table.add_column("Cum (ms)", justify="right", style="yellow", width=10)
|
|
||||||
table.add_column("Cum/Call", justify="right", style="yellow", width=10)
|
|
||||||
table.add_column("Function", style="blue")
|
|
||||||
|
|
||||||
sorted_stats = sorted(
|
|
||||||
((func, stats) for func, stats in self.stats.items() if isinstance(func, tuple) and len(func) == 3),
|
|
||||||
key=lambda x: x[1][2], # Sort by tt (internal time)
|
|
||||||
reverse=True,
|
|
||||||
)[:25] # Limit to top 25
|
|
||||||
|
|
||||||
# Format and add each row to the table
|
|
||||||
for func, (cc, nc, tt, ct, _) in sorted_stats:
|
|
||||||
filename, lineno, funcname = func
|
|
||||||
|
|
||||||
# Format calls - show recursive format if different
|
|
||||||
calls_str = f"{cc}/{nc}" if cc != nc else f"{cc:,}"
|
|
||||||
|
|
||||||
# Convert to milliseconds
|
|
||||||
tt_ms = tt / 1e6
|
|
||||||
ct_ms = ct / 1e6
|
|
||||||
|
|
||||||
# Calculate per-call times
|
|
||||||
per_call = tt_ms / cc if cc > 0 else 0
|
|
||||||
cum_per_call = ct_ms / nc if nc > 0 else 0
|
|
||||||
base_filename = Path(filename).name
|
|
||||||
file_link = f"[link=file://{filename}]{base_filename}[/link]"
|
|
||||||
|
|
||||||
table.add_row(
|
|
||||||
calls_str,
|
|
||||||
f"{tt_ms:.3f}",
|
|
||||||
f"{per_call:.3f}",
|
|
||||||
f"{ct_ms:.3f}",
|
|
||||||
f"{cum_per_call:.3f}",
|
|
||||||
f"{funcname} [dim]({file_link}:{lineno})[/dim]",
|
|
||||||
)
|
)
|
||||||
|
if n > 0:
|
||||||
|
times_index += 1
|
||||||
|
new_stats.append(replaced)
|
||||||
|
|
||||||
console.print(Align.center(table))
|
console.print("\n".join(new_stats))
|
||||||
|
|
||||||
except Exception as e: # noqa: BLE001
|
def make_pstats_compatible(self):
|
||||||
console.print(f"[bold red]Error in stats processing:[/bold red] {e}")
|
|
||||||
console.print(f"Traced {self.trace_count:,} function calls")
|
|
||||||
self.total_tt = 0
|
|
||||||
|
|
||||||
def make_pstats_compatible(self) -> None:
|
|
||||||
# delete the extra class_name item from the function tuple
|
# delete the extra class_name item from the function tuple
|
||||||
self.files = []
|
self.files = []
|
||||||
self.top_level = []
|
self.top_level = []
|
||||||
|
|
@ -693,33 +552,36 @@ class Tracer:
|
||||||
self.stats = new_stats
|
self.stats = new_stats
|
||||||
self.timings = new_timings
|
self.timings = new_timings
|
||||||
|
|
||||||
def dump_stats(self, file: str) -> None:
|
def dump_stats(self, file):
|
||||||
with Path(file).open("wb") as f:
|
with open(file, "wb") as f:
|
||||||
|
self.create_stats()
|
||||||
marshal.dump(self.stats, f)
|
marshal.dump(self.stats, f)
|
||||||
|
|
||||||
def create_stats(self) -> None:
|
def create_stats(self):
|
||||||
self.simulate_cmd_complete()
|
self.simulate_cmd_complete()
|
||||||
self.snapshot_stats()
|
self.snapshot_stats()
|
||||||
|
|
||||||
def snapshot_stats(self) -> None:
|
def snapshot_stats(self):
|
||||||
self.stats = {}
|
self.stats = {}
|
||||||
for func, (cc, _ns, tt, ct, caller_dict) in self.timings.items():
|
for func, (cc, ns, tt, ct, callers) in self.timings.items():
|
||||||
callers = caller_dict.copy()
|
callers = callers.copy()
|
||||||
nc = 0
|
nc = 0
|
||||||
for callcnt in callers.values():
|
for callcnt in callers.values():
|
||||||
nc += callcnt
|
nc += callcnt
|
||||||
self.stats[func] = cc, nc, tt, ct, callers
|
self.stats[func] = cc, nc, tt, ct, callers
|
||||||
|
|
||||||
def runctx(self, cmd: str, global_vars: dict[str, Any], local_vars: dict[str, Any]) -> Tracer | None:
|
def runctx(self, cmd, globals, locals):
|
||||||
self.__enter__()
|
self.__enter__()
|
||||||
try:
|
try:
|
||||||
exec(cmd, global_vars, local_vars) # noqa: S102
|
exec(cmd, globals, locals)
|
||||||
finally:
|
finally:
|
||||||
self.__exit__(None, None, None)
|
self.__exit__(None, None, None)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
def main() -> ArgumentParser:
|
def main():
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
parser = ArgumentParser(allow_abbrev=False)
|
parser = ArgumentParser(allow_abbrev=False)
|
||||||
parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to <outfile>", required=True)
|
parser.add_argument("-o", "--outfile", dest="outfile", help="Save trace to <outfile>", required=True)
|
||||||
parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None)
|
parser.add_argument("--only-functions", help="Trace only these functions", nargs="+", default=None)
|
||||||
|
|
@ -776,13 +638,16 @@ def main() -> ArgumentParser:
|
||||||
"__cached__": None,
|
"__cached__": None,
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
Tracer(
|
tracer = Tracer(
|
||||||
output=args.outfile,
|
output=args.outfile,
|
||||||
functions=args.only_functions,
|
functions=args.only_functions,
|
||||||
max_function_count=args.max_function_count,
|
max_function_count=args.max_function_count,
|
||||||
timeout=args.tracer_timeout,
|
timeout=args.tracer_timeout,
|
||||||
config_file_path=args.codeflash_config,
|
config_file_path=args.codeflash_config,
|
||||||
).runctx(code, globs, None)
|
)
|
||||||
|
|
||||||
|
tracer.runctx(code, globs, None)
|
||||||
|
print(tracer.functions)
|
||||||
|
|
||||||
except BrokenPipeError as exc:
|
except BrokenPipeError as exc:
|
||||||
# Prevent "Exception ignored" during interpreter shutdown.
|
# Prevent "Exception ignored" during interpreter shutdown.
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,6 @@ class TestType(Enum):
|
||||||
REPLAY_TEST = 4
|
REPLAY_TEST = 4
|
||||||
CONCOLIC_COVERAGE_TEST = 5
|
CONCOLIC_COVERAGE_TEST = 5
|
||||||
INIT_STATE_TEST = 6
|
INIT_STATE_TEST = 6
|
||||||
BENCHMARK_TEST = 7
|
|
||||||
|
|
||||||
def to_name(self) -> str:
|
def to_name(self) -> str:
|
||||||
if self == TestType.INIT_STATE_TEST:
|
if self == TestType.INIT_STATE_TEST:
|
||||||
|
|
@ -40,7 +39,6 @@ class TestType(Enum):
|
||||||
TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
|
TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
|
||||||
TestType.REPLAY_TEST: "⏪ Replay Tests",
|
TestType.REPLAY_TEST: "⏪ Replay Tests",
|
||||||
TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests",
|
TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests",
|
||||||
TestType.BENCHMARK_TEST: "📏 Benchmark Tests",
|
|
||||||
}
|
}
|
||||||
return names[self]
|
return names[self]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -75,3 +75,4 @@ class TestConfig:
|
||||||
# or for unittest - project_root_from_module_root(args.tests_root, pyproject_file_path)
|
# or for unittest - project_root_from_module_root(args.tests_root, pyproject_file_path)
|
||||||
concolic_test_root_dir: Optional[Path] = None
|
concolic_test_root_dir: Optional[Path] = None
|
||||||
pytest_cmd: str = "pytest"
|
pytest_cmd: str = "pytest"
|
||||||
|
benchmark_tests_root: Optional[Path] = None
|
||||||
|
|
|
||||||
|
|
@ -69,7 +69,7 @@ exclude = [
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.9"
|
python = ">=3.9"
|
||||||
unidiff = ">=0.7.4"
|
unidiff = ">=0.7.4"
|
||||||
pytest = ">=7.0.0,<8.3.4"
|
pytest = ">=7.0.0"
|
||||||
gitpython = ">=3.1.31"
|
gitpython = ">=3.1.31"
|
||||||
libcst = ">=1.0.1"
|
libcst = ">=1.0.1"
|
||||||
jedi = ">=0.19.1"
|
jedi = ">=0.19.1"
|
||||||
|
|
@ -92,7 +92,6 @@ rich = ">=13.8.1"
|
||||||
lxml = ">=5.3.0"
|
lxml = ">=5.3.0"
|
||||||
crosshair-tool = ">=0.0.78"
|
crosshair-tool = ">=0.0.78"
|
||||||
coverage = ">=7.6.4"
|
coverage = ">=7.6.4"
|
||||||
line_profiler=">=4.2.0" #this is the minimum version which supports python 3.13
|
|
||||||
[tool.poetry.group.dev]
|
[tool.poetry.group.dev]
|
||||||
optional = true
|
optional = true
|
||||||
|
|
||||||
|
|
@ -120,7 +119,7 @@ types-gevent = "^24.11.0.20241230"
|
||||||
types-greenlet = "^3.1.0.20241221"
|
types-greenlet = "^3.1.0.20241221"
|
||||||
types-pexpect = "^4.9.0.20241208"
|
types-pexpect = "^4.9.0.20241208"
|
||||||
types-unidiff = "^0.7.0.20240505"
|
types-unidiff = "^0.7.0.20240505"
|
||||||
uv = ">=0.6.2"
|
sqlalchemy = "^2.0.38"
|
||||||
|
|
||||||
[tool.poetry.build]
|
[tool.poetry.build]
|
||||||
script = "codeflash/update_license_version.py"
|
script = "codeflash/update_license_version.py"
|
||||||
|
|
@ -152,7 +151,7 @@ warn_required_dynamic_aliases = true
|
||||||
line-length = 120
|
line-length = 120
|
||||||
fix = true
|
fix = true
|
||||||
show-fixes = true
|
show-fixes = true
|
||||||
exclude = ["code_to_optimize/", "pie_test_set/", "tests/"]
|
exclude = ["code_to_optimize/", "pie_test_set/"]
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = ["ALL"]
|
select = ["ALL"]
|
||||||
|
|
@ -164,11 +163,10 @@ ignore = [
|
||||||
"D103",
|
"D103",
|
||||||
"D105",
|
"D105",
|
||||||
"D107",
|
"D107",
|
||||||
"D203", # incorrect-blank-line-before-class (incompatible with D211)
|
|
||||||
"D213", # multi-line-summary-second-line (incompatible with D212)
|
|
||||||
"S101",
|
"S101",
|
||||||
"S603",
|
"S603",
|
||||||
"S607",
|
"S607",
|
||||||
|
"ANN101",
|
||||||
"COM812",
|
"COM812",
|
||||||
"FIX002",
|
"FIX002",
|
||||||
"PLR0912",
|
"PLR0912",
|
||||||
|
|
@ -177,14 +175,13 @@ ignore = [
|
||||||
"TD002",
|
"TD002",
|
||||||
"TD003",
|
"TD003",
|
||||||
"TD004",
|
"TD004",
|
||||||
"PLR2004",
|
"PLR2004"
|
||||||
"UP007" # remove once we drop 3.9 support.
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff.lint.flake8-type-checking]
|
[tool.ruff.lint.flake8-type-checking]
|
||||||
strict = true
|
strict = true
|
||||||
runtime-evaluated-base-classes = ["pydantic.BaseModel"]
|
runtime-evaluated-base-classes = ["pydantic.BaseModel"]
|
||||||
runtime-evaluated-decorators = ["pydantic.validate_call", "pydantic.dataclasses.dataclass"]
|
runtime-evaluated-decorators = ["pydantic.validate_call"]
|
||||||
|
|
||||||
[tool.ruff.lint.pep8-naming]
|
[tool.ruff.lint.pep8-naming]
|
||||||
classmethod-decorators = [
|
classmethod-decorators = [
|
||||||
|
|
@ -192,9 +189,6 @@ classmethod-decorators = [
|
||||||
"pydantic.validator",
|
"pydantic.validator",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff.lint.isort]
|
|
||||||
split-on-trailing-comma = false
|
|
||||||
|
|
||||||
[tool.ruff.format]
|
[tool.ruff.format]
|
||||||
docstring-code-format = true
|
docstring-code-format = true
|
||||||
skip-magic-trailing-comma = true
|
skip-magic-trailing-comma = true
|
||||||
|
|
@ -217,13 +211,13 @@ initial-content = """
|
||||||
|
|
||||||
|
|
||||||
[tool.codeflash]
|
[tool.codeflash]
|
||||||
module-root = "codeflash"
|
# All paths are relative to this pyproject.toml's directory.
|
||||||
tests-root = "tests"
|
module-root = "code_to_optimize"
|
||||||
|
tests-root = "code_to_optimize/tests"
|
||||||
|
benchmarks-root = "code_to_optimize/tests/pytest/benchmarks"
|
||||||
test-framework = "pytest"
|
test-framework = "pytest"
|
||||||
formatter-cmds = [
|
ignore-paths = []
|
||||||
"uvx ruff check --exit-zero --fix $file",
|
formatter-cmds = ["ruff check --exit-zero --fix $file", "ruff format $file"]
|
||||||
"uvx ruff format $file",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|
|
||||||
8
tests/test_trace_benchmarks.py
Normal file
8
tests/test_trace_benchmarks.py
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def test_trace_benchmarks():
|
||||||
|
# Test the trace_benchmarks function
|
||||||
|
project_root = Path(__file__).parent.parent / "code_to_optimize"
|
||||||
|
benchmarks_root = project_root / "tests" / "pytest" / "benchmarks"
|
||||||
|
trace_benchmarks_pytest(benchmarks_root, project_root, ["sorter"])
|
||||||
|
|
@ -3,6 +3,7 @@ import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
||||||
|
from codeflash.verification.test_results import TestType
|
||||||
from codeflash.verification.verification_utils import TestConfig
|
from codeflash.verification.verification_utils import TestConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -21,7 +22,7 @@ def test_unit_test_discovery_pytest():
|
||||||
|
|
||||||
def test_benchmark_test_discovery_pytest():
|
def test_benchmark_test_discovery_pytest():
|
||||||
project_path = Path(__file__).parent.parent.resolve() / "code_to_optimize"
|
project_path = Path(__file__).parent.parent.resolve() / "code_to_optimize"
|
||||||
tests_path = project_path / "tests" / "pytest"
|
tests_path = project_path / "tests" / "pytest" / "benchmarks" / "test_benchmark_bubble_sort.py"
|
||||||
test_config = TestConfig(
|
test_config = TestConfig(
|
||||||
tests_root=tests_path,
|
tests_root=tests_path,
|
||||||
project_root_path=project_path,
|
project_root_path=project_path,
|
||||||
|
|
@ -29,9 +30,10 @@ def test_benchmark_test_discovery_pytest():
|
||||||
tests_project_rootdir=tests_path.parent,
|
tests_project_rootdir=tests_path.parent,
|
||||||
)
|
)
|
||||||
tests = discover_unit_tests(test_config)
|
tests = discover_unit_tests(test_config)
|
||||||
print(tests)
|
|
||||||
assert len(tests) > 0
|
assert len(tests) > 0
|
||||||
# print(tests)
|
assert 'bubble_sort.sorter' in tests
|
||||||
|
benchmark_tests = sum(1 for test in tests['bubble_sort.sorter'] if test.tests_in_file.test_type == TestType.BENCHMARK_TEST)
|
||||||
|
assert benchmark_tests == 1
|
||||||
|
|
||||||
|
|
||||||
def test_unit_test_discovery_unittest():
|
def test_unit_test_discovery_unittest():
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue