codeflash/codeflash/optimization/optimizer.py
from __future__ import annotations
import ast
import copy
import os
import tempfile
import time
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
from codeflash.api.cfapi import send_completion_email
from codeflash.cli_cmds.console import console, logger, progress_bar
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file
from codeflash.code_utils.env_utils import get_pr_number, is_pr_draft
from codeflash.code_utils.git_utils import check_running_in_git_repo
from codeflash.code_utils.git_worktree_utils import (
create_detached_worktree,
create_diff_patch_from_worktree,
create_worktree_snapshot_commit,
remove_worktree,
)
from codeflash.either import is_successful
from codeflash.models.models import ValidCode
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.verification_utils import TestConfig
if TYPE_CHECKING:
from argparse import Namespace
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import BenchmarkKey, FunctionCalledInTest
from codeflash.optimization.function_optimizer import FunctionOptimizer
class Optimizer:
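"""Top-level driver for a Codeflash optimization run.
Wires the parsed CLI arguments into a TestConfig, then discovers optimizable functions,
traces benchmarks, discovers existing tests, and hands each function to a FunctionOptimizer.
Typical usage is via run_with_args(args), which is equivalent to Optimizer(args).run().
"""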
def __init__(self, args: Namespace) -> None:
self.args = args
self.test_cfg = TestConfig(
tests_root=args.tests_root,
tests_project_rootdir=args.test_project_root,
project_root_path=args.project_root,
test_framework=args.test_framework,
pytest_cmd=args.pytest_cmd,
benchmark_tests_root=args.benchmarks_root if "benchmark" in args and "benchmarks_root" in args else None,
)
self.aiservice_client = AiServiceClient()
self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None)
self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None
self.replay_tests_dir = None
self.functions_checkpoint: CodeflashRunCheckpoint | None = None
self.current_function_being_optimized: FunctionToOptimize | None = None # current only for the LSP
self.current_function_optimizer: FunctionOptimizer | None = None
self.current_worktree: Path | None = None
self.original_args_and_test_cfg: tuple[Namespace, TestConfig] | None = None
self.patch_files: list[Path] = []
def run_benchmarks(
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int
) -> tuple[dict[str, dict[BenchmarkKey, float]], dict[BenchmarkKey, float]]:
"""Run benchmarks for the functions to optimize and collect timing information."""
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] = {}
total_benchmark_timings: dict[BenchmarkKey, float] = {}
if not (hasattr(self.args, "benchmark") and self.args.benchmark and num_optimizable_functions > 0):
return function_benchmark_timings, total_benchmark_timings
from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator
from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin
from codeflash.benchmarking.replay_test import generate_replay_test
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
console.rule()
with progress_bar(
f"Running benchmarks in {self.args.benchmarks_root}", transient=True, revert_to_print=bool(get_pr_number())
):
# Snapshot original sources before inserting the codeflash tracing decorator
file_path_to_source_code = defaultdict(str)
for file in file_to_funcs_to_optimize:
with file.open("r", encoding="utf8") as f:
file_path_to_source_code[file] = f.read()
try:
instrument_codeflash_trace_decorator(file_to_funcs_to_optimize)
trace_file = Path(self.args.benchmarks_root) / "benchmarks.trace"
if trace_file.exists():
trace_file.unlink()
self.replay_tests_dir = Path(
tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root)
)
trace_benchmarks_pytest(
self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file
) # Run all tests that use pytest-benchmark
replay_count = generate_replay_test(trace_file, self.replay_tests_dir)
if replay_count == 0:
logger.info(
f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization"
)
else:
function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file)
total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file)
function_to_results = validate_and_format_benchmark_table(
function_benchmark_timings, total_benchmark_timings
)
print_benchmark_table(function_to_results)
except Exception as e:
logger.info(f"Error while tracing existing benchmarks: {e}")
logger.info("Information on existing benchmarks will not be available for this run.")
finally:
# Restore original source code
for file in file_path_to_source_code:
with file.open("w", encoding="utf8") as f:
f.write(file_path_to_source_code[file])
console.rule()
return function_benchmark_timings, total_benchmark_timings
def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
"""Discover functions to optimize."""
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
return get_functions_to_optimize(
optimize_all=self.args.all,
replay_test=self.args.replay_test,
file=self.args.file,
only_get_this_function=self.args.function,
test_cfg=self.test_cfg,
ignore_paths=self.args.ignore_paths,
project_root=self.args.project_root,
module_root=self.args.module_root,
previous_checkpoint_functions=self.args.previous_checkpoint_functions,
enable_async=getattr(self.args, "async", False),
)
def create_function_optimizer(
self,
function_to_optimize: FunctionToOptimize,
function_to_optimize_ast: ast.FunctionDef | ast.AsyncFunctionDef | None = None,
function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
function_to_optimize_source_code: str | None = "",
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,
total_benchmark_timings: dict[BenchmarkKey, float] | None = None,
original_module_ast: ast.Module | None = None,
original_module_path: Path | None = None,
) -> FunctionOptimizer | None:
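"""Build a FunctionOptimizer for a single function, or return None if it cannot be located.
If no AST node is supplied, the function is looked up in original_module_ast; when that
lookup fails, the function is skipped. Benchmark timings are forwarded only when benchmark
mode is enabled and timings exist for this function's qualified name.
"""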
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
from codeflash.optimization.function_optimizer import FunctionOptimizer
if function_to_optimize_ast is None and original_module_ast is not None:
function_to_optimize_ast = get_first_top_level_function_or_method_ast(
function_to_optimize.function_name, function_to_optimize.parents, original_module_ast
)
if function_to_optimize_ast is None:
logger.info(
f"Function {function_to_optimize.qualified_name} not found in {original_module_path}.\n"
f"Skipping optimization."
)
return None
qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root)
function_specific_timings = None
if (
hasattr(self.args, "benchmark")
and self.args.benchmark
and function_benchmark_timings
and qualified_name_w_module in function_benchmark_timings
and total_benchmark_timings
):
function_specific_timings = function_benchmark_timings[qualified_name_w_module]
return FunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=self.test_cfg,
function_to_optimize_source_code=function_to_optimize_source_code,
function_to_tests=function_to_tests,
function_to_optimize_ast=function_to_optimize_ast,
aiservice_client=self.aiservice_client,
args=self.args,
function_benchmark_timings=function_specific_timings,
total_benchmark_timings=total_benchmark_timings if function_specific_timings else None,
replay_tests_dir=self.replay_tests_dir,
)
def prepare_module_for_optimization(
self, original_module_path: Path
) -> tuple[dict[Path, ValidCode], ast.Module] | None:
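"""Parse and normalize the module under optimization and every module it imports.
Returns a mapping of file path to ValidCode (original plus normalized source) together with
the parsed module AST, or None if the module or any imported helper fails to parse.
"""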
from codeflash.code_utils.code_replacer import normalize_code, normalize_node
from codeflash.code_utils.static_analysis import analyze_imported_modules
logger.info(f"loading|Examining file {original_module_path!s}")
console.rule()
original_module_code: str = original_module_path.read_text(encoding="utf8")
try:
original_module_ast = ast.parse(original_module_code)
except SyntaxError as e:
logger.warning(f"Syntax error parsing code in {original_module_path}: {e}")
logger.info("Skipping optimization due to file error.")
return None
normalized_original_module_code = ast.unparse(normalize_node(original_module_ast))
validated_original_code: dict[Path, ValidCode] = {
original_module_path: ValidCode(
source_code=original_module_code, normalized_code=normalized_original_module_code
)
}
imported_module_analyses = analyze_imported_modules(
original_module_code, original_module_path, self.args.project_root
)
has_syntax_error = False
for analysis in imported_module_analyses:
callee_original_code = analysis.file_path.read_text(encoding="utf8")
try:
normalized_callee_original_code = normalize_code(callee_original_code)
except SyntaxError as e:
logger.warning(f"Syntax error parsing code in callee module {analysis.file_path}: {e}")
logger.info("Skipping optimization due to helper file error.")
has_syntax_error = True
break
validated_original_code[analysis.file_path] = ValidCode(
source_code=callee_original_code, normalized_code=normalized_callee_original_code
)
if has_syntax_error:
return None
return validated_original_code, original_module_ast
def discover_tests(
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
) -> tuple[dict[str, set[FunctionCalledInTest]], int]:
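"""Discover existing unit tests and replay tests for the functions to optimize.
Returns the mapping of each function to the tests that exercise it, along with the number
of discovered unit tests.
"""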
from codeflash.discovery.discover_unit_tests import discover_unit_tests
console.rule()
with progress_bar("Discovering existing function tests..."):
start_time = time.time()
function_to_tests, num_discovered_tests, num_discovered_replay_tests = discover_unit_tests(
self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize
)
console.rule()
logger.info(
f"Discovered {num_discovered_tests} existing unit tests and {num_discovered_replay_tests} replay tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}"
)
console.rule()
ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests})
return function_to_tests, num_discovered_tests
def run(self) -> None:
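"""Run the end-to-end optimization loop.
Validates the environment (API key, formatter, PR draft state), optionally switches into a
detached worktree, runs benchmarks, discovers functions and tests, then optimizes each
function in turn, recording checkpoints, patches, and telemetry before cleaning up
temporary paths.
"""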
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
ph("cli-optimize-run-start")
logger.info("Running optimizer.")
console.rule()
if not env_utils.ensure_codeflash_api_key():
return
if not env_utils.check_formatter_installed(self.args.formatter_cmds):
return
if self.args.no_draft and is_pr_draft():
logger.warning("PR is in draft mode, skipping optimization")
return
if self.args.worktree:
self.worktree_mode()
cleanup_paths(Optimizer.find_leftover_instrumented_test_files(self.test_cfg.tests_root))
function_optimizer = None
file_to_funcs_to_optimize, num_optimizable_functions, trace_file_path = self.get_optimizable_functions()
function_benchmark_timings, total_benchmark_timings = self.run_benchmarks(
file_to_funcs_to_optimize, num_optimizable_functions
)
optimizations_found: int = 0
if self.args.test_framework == "pytest":
self.test_cfg.concolic_test_root_dir = Path(
tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
)
try:
ph("cli-optimize-functions-to-optimize", {"num_functions": num_optimizable_functions})
if num_optimizable_functions == 0:
logger.info("No functions found to optimize. Exiting…")
return
function_to_tests, _ = self.discover_tests(file_to_funcs_to_optimize)
if self.args.all:
self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_root)
for original_module_path in file_to_funcs_to_optimize:
module_prep_result = self.prepare_module_for_optimization(original_module_path)
if module_prep_result is None:
continue
validated_original_code, original_module_ast = module_prep_result
functions_to_optimize = file_to_funcs_to_optimize[original_module_path]
if trace_file_path and trace_file_path.exists() and len(functions_to_optimize) > 1:
try:
from codeflash.benchmarking.function_ranker import FunctionRanker
ranker = FunctionRanker(trace_file_path)
functions_to_optimize = ranker.rank_functions(functions_to_optimize)
logger.info(
f"Ranked {len(functions_to_optimize)} functions by performance impact in {original_module_path}"
)
console.rule()
except Exception as e:
logger.debug(f"Could not rank functions in {original_module_path}: {e}")
for i, function_to_optimize in enumerate(functions_to_optimize):
function_iterator_count = i + 1
logger.info(
f"Optimizing function {function_iterator_count} of {num_optimizable_functions}: "
f"{function_to_optimize.qualified_name}"
)
console.rule()
function_optimizer = None
try:
function_optimizer = self.create_function_optimizer(
function_to_optimize,
function_to_tests=function_to_tests,
function_to_optimize_source_code=validated_original_code[original_module_path].source_code,
function_benchmark_timings=function_benchmark_timings,
total_benchmark_timings=total_benchmark_timings,
original_module_ast=original_module_ast,
original_module_path=original_module_path,
)
if function_optimizer is None:
continue
self.current_function_optimizer = (
function_optimizer # needed to clean up from the outside of this function
)
best_optimization = function_optimizer.optimize_function()
if self.functions_checkpoint:
self.functions_checkpoint.add_function_to_checkpoint(
function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root)
)
if is_successful(best_optimization):
optimizations_found += 1
# create a diff patch for successful optimization
if self.current_worktree:
best_opt = best_optimization.unwrap()
read_writable_code = best_opt.code_context.read_writable_code
relative_file_paths = [
code_string.file_path for code_string in read_writable_code.code_strings
]
metadata = create_diff_patch_from_worktree(
self.current_worktree,
relative_file_paths,
fto_name=function_to_optimize.qualified_name,
metadata_input={},
)
self.patch_files.append(metadata["patch_path"])
if i < len(functions_to_optimize) - 1:
create_worktree_snapshot_commit(
self.current_worktree,
f"Optimizing {functions_to_optimize[i + 1].qualified_name}",
)
else:
logger.warning(best_optimization.failure())
console.rule()
continue
finally:
if function_optimizer is not None:
function_optimizer.executor.shutdown(wait=True)
function_optimizer.cleanup_generated_files()
ph("cli-optimize-run-finished", {"optimizations_found": optimizations_found})
if len(self.patch_files) > 0:
logger.info(
f"Created {len(self.patch_files)} patch(es) ({[str(patch_path) for patch_path in self.patch_files]})"
)
if self.functions_checkpoint:
self.functions_checkpoint.cleanup()
if hasattr(self.args, "command") and self.args.command == "optimize":
self.cleanup_replay_tests()
if optimizations_found == 0:
logger.info("❌ No optimizations found.")
elif self.args.all:
logger.info("✨ All functions have been optimized! ✨")
response = send_completion_email() # TODO: Include more details in the email
if response.ok:
logger.info("✅ Completion email sent successfully.")
else:
logger.warning("⚠️ Failed to send completion email. Status")
finally:
if function_optimizer:
function_optimizer.cleanup_generated_files()
self.cleanup_temporary_paths()
@staticmethod
def find_leftover_instrumented_test_files(test_root: Path) -> list[Path]:
"""Search for all paths within the test_root that match the following patterns.
- 'test.*__perf_test_\d?.py'
- 'test_.*__unit_test_\d?.py'
- 'test_.*__perfinstrumented.py'
- 'test_.*__perfonlyinstrumented.py'
Returns a list of matching file paths.
"""
import re
pattern = re.compile(
r"(?:test.*__perf_test_\d?\.py|test_.*__unit_test_\d?\.py|test_.*__perfinstrumented\.py|test_.*__perfonlyinstrumented\.py)$"
)
return [
file_path for file_path in test_root.rglob("*") if file_path.is_file() and pattern.match(file_path.name)
]
def cleanup_replay_tests(self) -> None:
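"""Delete the temporary replay-tests directory created during benchmark tracing, if any."""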
if self.replay_tests_dir and self.replay_tests_dir.exists():
logger.debug(f"Cleaning up replay tests directory: {self.replay_tests_dir}")
cleanup_paths([self.replay_tests_dir])
def cleanup_temporary_paths(self) -> None:
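"""Remove temporary artifacts from the run.
Cleans up the shared run tmpdir, then either removes the detached worktree (which contains
the remaining artifacts) or deletes generated test files and the concolic/replay test dirs.
"""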
if hasattr(get_run_tmp_file, "tmpdir"):
get_run_tmp_file.tmpdir.cleanup()
del get_run_tmp_file.tmpdir
if self.current_worktree:
remove_worktree(self.current_worktree)
return
if self.current_function_optimizer:
self.current_function_optimizer.cleanup_generated_files()
cleanup_paths([self.test_cfg.concolic_test_root_dir, self.replay_tests_dir])
def worktree_mode(self) -> None:
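"""Switch the run into a detached git worktree so edits never touch the user's checkout.
Creates the worktree, re-roots all path arguments into it, and ensures the tests directory
exists there. Does nothing if a worktree is already active, and only proceeds when the
module root is inside a git repository.
"""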
if self.current_worktree:
return
if check_running_in_git_repo(self.args.module_root):
worktree_dir = create_detached_worktree(self.args.module_root)
if worktree_dir is None:
logger.warning("Failed to create worktree. Skipping optimization.")
return
self.current_worktree = worktree_dir
self.mutate_args_for_worktree_mode(worktree_dir)
# Ensure the tests dir exists inside the worktree; it can be missing because git does not track empty directories
Path(self.args.tests_root).mkdir(parents=True, exist_ok=True)
def mutate_args_for_worktree_mode(self, worktree_dir: Path) -> None:
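"""Re-root args and test_cfg paths onto the worktree, preserving the originals.
The original Namespace and TestConfig are deep-copied into original_args_and_test_cfg, then
the project, module, tests, benchmarks, and file paths are rewritten relative to worktree_dir.
"""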
saved_args = copy.deepcopy(self.args)
saved_test_cfg = copy.deepcopy(self.test_cfg)
self.original_args_and_test_cfg = (saved_args, saved_test_cfg)
project_root = self.args.project_root
module_root = self.args.module_root
relative_module_root = module_root.relative_to(project_root)
relative_optimized_file = self.args.file.relative_to(project_root) if self.args.file else None
relative_tests_root = self.test_cfg.tests_root.relative_to(project_root)
relative_benchmarks_root = (
self.args.benchmarks_root.relative_to(project_root) if self.args.benchmarks_root else None
)
self.args.module_root = worktree_dir / relative_module_root
self.args.project_root = worktree_dir
self.args.test_project_root = worktree_dir
self.args.tests_root = worktree_dir / relative_tests_root
if relative_benchmarks_root:
self.args.benchmarks_root = worktree_dir / relative_benchmarks_root
self.test_cfg.project_root_path = worktree_dir
self.test_cfg.tests_project_rootdir = worktree_dir
self.test_cfg.tests_root = worktree_dir / relative_tests_root
if relative_benchmarks_root:
self.test_cfg.benchmark_tests_root = worktree_dir / relative_benchmarks_root
if relative_optimized_file is not None:
self.args.file = worktree_dir / relative_optimized_file
def run_with_args(args: Namespace) -> None:
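"""Construct an Optimizer from parsed CLI args and run it; on KeyboardInterrupt, clean up temporary paths and exit."""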
optimizer = None
try:
optimizer = Optimizer(args)
optimizer.run()
except KeyboardInterrupt:
logger.warning("Keyboard interrupt received. Cleaning up and exiting, please wait…")
if optimizer:
optimizer.cleanup_temporary_paths()
raise SystemExit from None