mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Merge pull request #2055 from codeflash-ai/perf/defer-cli-imports
perf: defer cli.py imports for 7.7x faster --help
This commit is contained in:
commit
72a41a5665
27 changed files with 606 additions and 110 deletions
2
.github/workflows/ci.yaml
vendored
2
.github/workflows/ci.yaml
vendored
|
|
@ -199,6 +199,8 @@ jobs:
|
|||
run: |
|
||||
uv run ruff check --fix . || true
|
||||
uv run ruff format .
|
||||
# uv-dynamic-versioning rewrites version.py on every `uv run` — discard those changes
|
||||
git checkout HEAD -- codeflash/version.py codeflash-benchmark/codeflash_benchmark/version.py 2>/dev/null || true
|
||||
|
||||
- name: Commit and push fixes
|
||||
run: |
|
||||
|
|
|
|||
0
benchmarks/__init__.py
Normal file
0
benchmarks/__init__.py
Normal file
72
benchmarks/bench_cli_startup.py
Normal file
72
benchmarks/bench_cli_startup.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
"""Benchmark CLI startup latency for codeflash compare --script mode.
|
||||
|
||||
Run from a worktree root. Installs deps via uv sync, then times several
|
||||
CLI entry points and writes a JSON file mapping command names to median
|
||||
wall-clock seconds.
|
||||
|
||||
Usage:
|
||||
codeflash compare main codeflash/optimize \
|
||||
--script "python benchmarks/bench_cli_startup.py" \
|
||||
--script-output benchmarks/results.json
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Number of untimed executions performed before measuring.
WARMUP = 3
# Number of timed executions whose median is reported.
RUNS = 30
# Destination of the JSON report; the BENCH_OUTPUT environment variable overrides the default.
OUTPUT = os.environ.get("BENCH_OUTPUT", "benchmarks/results.json")

# Benchmark name -> argv of the CLI invocation to time.
COMMANDS: dict[str, list[str]] = {
    "version": ["uv", "run", "codeflash", "--version"],
    "help": ["uv", "run", "codeflash", "--help"],
    "auth_status": ["uv", "run", "codeflash", "auth", "status"],
    "compare_help": ["uv", "run", "codeflash", "compare", "--help"],
}


def measure(cmd: list[str], warmup: int = WARMUP, runs: int = RUNS) -> float:
    """Return median wall-clock seconds for *cmd* over *runs* iterations.

    Args:
        cmd: Command line (argv list) to execute.
        warmup: Number of untimed executions performed before measuring.
        runs: Number of timed executions; must be at least 1.

    Returns:
        The median elapsed time, in seconds, of the timed executions.

    Raises:
        ValueError: If *runs* is less than 1 (a median of zero samples is undefined).
    """
    from statistics import median

    # Validate before spawning anything so a bad argument fails fast instead of
    # surfacing as an IndexError after the warmup runs.
    if runs < 1:
        msg = f"runs must be at least 1, got {runs}"
        raise ValueError(msg)

    # The child process sees CODEFLASH_API_KEY set to a placeholder value
    # (presumably so key lookups in the CLI do not depend on the caller's credentials).
    env = {**os.environ, "CODEFLASH_API_KEY": "bench_dummy_key"}

    for _ in range(warmup):
        subprocess.run(cmd, capture_output=True, check=False, env=env)

    times: list[float] = []
    for _ in range(runs):
        t0 = time.perf_counter()
        # check=False: the exit status is ignored on purpose, only latency is measured.
        subprocess.run(cmd, capture_output=True, check=False, env=env)
        times.append(time.perf_counter() - t0)

    # statistics.median averages the two middle samples for an even count,
    # matching the previous hand-written computation.
    return median(times)
|
||||
|
||||
|
||||
def main() -> None:
    """Time every command in ``COMMANDS`` and write the medians to ``OUTPUT`` as JSON.

    Runs ``uv sync`` first, then measures each command with ``measure`` and
    prints one progress line per command. The written JSON object maps each
    command name to its median in seconds (rounded to 4 decimals) and adds a
    ``"__total__"`` entry holding the sum of those rounded medians.

    Raises:
        subprocess.CalledProcessError: If ``uv sync`` exits with a non-zero status.
    """
    import sys

    # Ensure deps are installed in the worktree
    try:
        subprocess.run(["uv", "sync"], check=True, capture_output=True)
    except subprocess.CalledProcessError as exc:
        # Output is captured, so without this the reason for the failure would
        # be discarded along with the exception's stderr attribute.
        if exc.stderr:
            sys.stderr.write(exc.stderr.decode(errors="replace"))
        raise

    results: dict[str, float] = {}
    for name, cmd in COMMANDS.items():
        # Print the label first (unbuffered) so progress is visible while timing runs.
        print(f" {name}: ", end="", flush=True)
        median = measure(cmd)
        results[name] = round(median, 4)
        print(f"{median * 1000:.0f} ms")

    # Total = sum of medians (useful for a single summary number)
    results["__total__"] = round(sum(results.values()), 4)

    output_path = Path(OUTPUT)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults written to {OUTPUT}")


if __name__ == "__main__":
    main()
|
||||
|
|
@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional, Union
|
|||
|
||||
import libcst as cst
|
||||
|
||||
import codeflash.code_utils._libcst_cache # noqa: F401
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
|
|||
|
|
@ -5,15 +5,6 @@ from argparse import SUPPRESS, ArgumentParser, Namespace
|
|||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.cli_cmds import logging_config
|
||||
from codeflash.cli_cmds.console import apologize_and_exit, logger
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.code_utils import exit_with_message, normalize_ignore_paths
|
||||
from codeflash.code_utils.config_parser import parse_config_file
|
||||
from codeflash.languages.test_framework import set_current_test_framework
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.version import __version__ as version
|
||||
|
||||
|
||||
def parse_args() -> Namespace:
|
||||
parser = _build_parser()
|
||||
|
|
@ -30,12 +21,17 @@ def parse_args() -> Namespace:
|
|||
|
||||
|
||||
def process_and_validate_cmd_args(args: Namespace) -> Namespace:
|
||||
from codeflash.cli_cmds import logging_config
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.code_utils import exit_with_message
|
||||
from codeflash.code_utils.git_utils import (
|
||||
check_running_in_git_repo,
|
||||
confirm_proceeding_with_no_git_repo,
|
||||
get_repo_owner_and_name,
|
||||
)
|
||||
from codeflash.code_utils.github_utils import require_github_app_or_exit
|
||||
from codeflash.version import __version__ as version
|
||||
|
||||
if args.server:
|
||||
os.environ["CODEFLASH_AIS_SERVER"] = args.server
|
||||
|
|
@ -85,6 +81,12 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
|
|||
|
||||
|
||||
def process_pyproject_config(args: Namespace) -> Namespace:
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.code_utils import exit_with_message, normalize_ignore_paths
|
||||
from codeflash.code_utils.config_parser import parse_config_file
|
||||
from codeflash.languages.test_framework import set_current_test_framework
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
|
||||
try:
|
||||
pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
|
||||
except ValueError as e:
|
||||
|
|
@ -222,6 +224,9 @@ def project_root_from_module_root(module_root: Path, pyproject_file_path: Path)
|
|||
|
||||
|
||||
def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
|
||||
from codeflash.cli_cmds.console import apologize_and_exit, logger
|
||||
from codeflash.code_utils.code_utils import exit_with_message
|
||||
|
||||
if hasattr(args, "all") or (hasattr(args, "file") and args.file):
|
||||
no_pr = getattr(args, "no_pr", False)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,17 +2,17 @@ from __future__ import annotations
|
|||
|
||||
import os
|
||||
|
||||
import click
|
||||
|
||||
from codeflash.cli_cmds.console import console
|
||||
from codeflash.cli_cmds.oauth_handler import perform_oauth_signin
|
||||
from codeflash.code_utils.env_utils import get_codeflash_api_key
|
||||
from codeflash.code_utils.shell_utils import save_api_key_to_rc
|
||||
from codeflash.either import is_successful
|
||||
|
||||
|
||||
def auth_login() -> None:
|
||||
"""Perform OAuth login and save the API key."""
|
||||
import click
|
||||
|
||||
from codeflash.cli_cmds.console import console
|
||||
from codeflash.cli_cmds.oauth_handler import perform_oauth_signin
|
||||
from codeflash.code_utils.env_utils import get_codeflash_api_key
|
||||
from codeflash.code_utils.shell_utils import save_api_key_to_rc
|
||||
from codeflash.either import is_successful
|
||||
|
||||
try:
|
||||
existing_api_key = get_codeflash_api_key()
|
||||
except OSError:
|
||||
|
|
@ -41,6 +41,9 @@ def auth_login() -> None:
|
|||
|
||||
def auth_status() -> None:
|
||||
"""Check and display current authentication status."""
|
||||
from codeflash.cli_cmds.console import console
|
||||
from codeflash.code_utils.env_utils import get_codeflash_api_key
|
||||
|
||||
try:
|
||||
api_key = get_codeflash_api_key()
|
||||
except OSError:
|
||||
|
|
|
|||
64
codeflash/code_utils/_libcst_cache.py
Normal file
64
codeflash/code_utils/_libcst_cache.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
"""Cache libcst visitor dispatch table construction.
|
||||
|
||||
libcst's ``MatcherDecoratableTransformer`` and
|
||||
``MatcherDecoratableVisitor`` rebuild visitor dispatch tables on
|
||||
every instantiation by iterating ``dir(self)`` (~600 attributes)
|
||||
and calling ``getattr`` + ``inspect.ismethod`` on each. The
|
||||
results depend only on the class, not the instance, so caching
|
||||
by ``type(obj)`` is safe.
|
||||
|
||||
Import this module before any libcst visitors are instantiated
|
||||
to install the cache.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import libcst.matchers._visitors as _mv
|
||||
|
||||
_visit_cache: dict[type, Any] = {}
|
||||
_leave_cache: dict[type, Any] = {}
|
||||
_matchers_cache: dict[type, Any] = {}
|
||||
|
||||
_original_visit = _mv._gather_constructed_visit_funcs # noqa: SLF001
|
||||
_original_leave = _mv._gather_constructed_leave_funcs # noqa: SLF001
|
||||
_original_matchers = _mv._gather_matchers # noqa: SLF001
|
||||
|
||||
|
||||
def _unbind_dispatch_table(table: Any, owner: object) -> dict[Any, tuple[tuple[Any, bool], ...]]:
    """Turn a dispatch table built for *owner* into an instance-independent template.

    Every callable that is a method bound to *owner* is stored as
    ``(underlying function, True)``; any other callable (for example a
    classmethod, which is bound to the class) is stored as ``(callable, False)``.
    The template therefore holds no reference to *owner*.
    """
    from types import MethodType

    return {
        matcher: tuple(
            (func.__func__, True) if isinstance(func, MethodType) and func.__self__ is owner else (func, False)
            for func in funcs
        )
        for matcher, funcs in table.items()
    }


def _bind_dispatch_table(template: dict[Any, tuple[tuple[Any, bool], ...]], obj: object) -> dict[Any, tuple[Any, ...]]:
    """Build a fresh dispatch table for *obj* from a cached template."""
    from types import MethodType

    return {
        matcher: tuple(MethodType(func, obj) if rebind else func for func, rebind in entries)
        for matcher, entries in template.items()
    }


def _cached_visit(obj: object) -> Any:
    """Return the visit-function dispatch table for *obj*, cached per class.

    The first instance of a class pays for the original libcst scan. What is
    cached is a template with the instance stripped out, and later instances
    get the handlers re-bound to themselves. Caching the original table
    directly would hand every later instance the first instance's bound
    methods, and keep that first instance alive.
    NOTE(review): relies on libcst's table shape (mapping of matcher -> sequence
    of callables invoked with only the node); confirm against the pinned libcst.
    """
    cls = type(obj)
    try:
        template = _visit_cache[cls]
    except KeyError:
        table = _original_visit(obj)
        _visit_cache[cls] = _unbind_dispatch_table(table, obj)
        # The freshly built table is already bound to obj; return it untouched.
        return table
    return _bind_dispatch_table(template, obj)


def _cached_leave(obj: object) -> Any:
    """Return the leave-function dispatch table for *obj*, cached per class.

    Same scheme as ``_cached_visit``: an instance-independent template is
    cached per class and re-bound to each new instance.
    """
    cls = type(obj)
    try:
        template = _leave_cache[cls]
    except KeyError:
        table = _original_leave(obj)
        _leave_cache[cls] = _unbind_dispatch_table(table, obj)
        # The freshly built table is already bound to obj; return it untouched.
        return table
    return _bind_dispatch_table(template, obj)
|
||||
|
||||
|
||||
def _cached_matchers(obj: object) -> Any:
    """Return a per-call copy of the matcher collection for *obj*, cached per class.

    The pristine result of the original libcst scan is stored once per class
    and never handed out; each caller receives its own shallow copy, so
    per-instance mutation cannot leak into the cache or into other instances.
    """
    cls = type(obj)
    try:
        cached = _matchers_cache[cls]
    except KeyError:
        cached = _matchers_cache[cls] = _original_matchers(obj)
    # .copy() keeps whatever container type libcst returned. dict(...) would
    # raise TypeError on a non-empty set of matchers.
    # NOTE(review): upstream is believed to return a set (older code a dict);
    # confirm against the pinned libcst version.
    return cached.copy()
|
||||
|
||||
|
||||
_mv._gather_constructed_visit_funcs = _cached_visit # noqa: SLF001
|
||||
_mv._gather_constructed_leave_funcs = _cached_leave # noqa: SLF001
|
||||
_mv._gather_matchers = _cached_matchers # noqa: SLF001
|
||||
|
|
@ -9,17 +9,16 @@ from functools import lru_cache
|
|||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_utils import exit_with_message
|
||||
from codeflash.code_utils.formatter import format_code
|
||||
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc
|
||||
from codeflash.languages.registry import get_language_support_by_common_formatters
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
|
||||
|
||||
def check_formatter_installed(
|
||||
formatter_cmds: list[str], exit_on_failure: bool = True, language: str = "python"
|
||||
) -> bool:
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.formatter import format_code
|
||||
from codeflash.languages.registry import get_language_support_by_common_formatters
|
||||
|
||||
if not formatter_cmds or formatter_cmds[0] == "disabled":
|
||||
return True
|
||||
first_cmd = formatter_cmds[0]
|
||||
|
|
@ -69,6 +68,8 @@ def check_formatter_installed(
|
|||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_codeflash_api_key() -> str:
|
||||
from codeflash.cli_cmds.console import logger
|
||||
|
||||
# Check environment variable first
|
||||
env_api_key = os.environ.get("CODEFLASH_API_KEY")
|
||||
shell_api_key = read_api_key_from_shell_config()
|
||||
|
|
@ -96,7 +97,8 @@ def get_codeflash_api_key() -> str:
|
|||
# Prefer the shell configuration over environment variables for lsp,
|
||||
# as the API key may change in the RC file during lsp runtime. Since the LSP client (extension) can restart
|
||||
# within the same process, the environment variable could become outdated.
|
||||
api_key = shell_api_key or env_api_key if is_LSP_enabled() else env_api_key or shell_api_key
|
||||
is_lsp = os.getenv("CODEFLASH_LSP", default="false").lower() == "true"
|
||||
api_key = shell_api_key or env_api_key if is_lsp else env_api_key or shell_api_key
|
||||
|
||||
api_secret_docs_message = "For more information, refer to the documentation at [https://docs.codeflash.ai/optimizing-with-codeflash/codeflash-github-actions#manual-setup]." # noqa
|
||||
if not api_key:
|
||||
|
|
@ -106,6 +108,8 @@ def get_codeflash_api_key() -> str:
|
|||
f"{api_secret_docs_message}"
|
||||
)
|
||||
if is_repo_a_fork():
|
||||
from codeflash.code_utils.code_utils import exit_with_message
|
||||
|
||||
msg = (
|
||||
"Codeflash API key not detected in your environment. It appears you're running Codeflash from a GitHub fork.\n"
|
||||
"For external contributors, please ensure you've added your own API key to your fork's repository secrets and set it as the CODEFLASH_API_KEY environment variable.\n"
|
||||
|
|
@ -124,6 +128,8 @@ def get_codeflash_api_key() -> str:
|
|||
|
||||
|
||||
def ensure_codeflash_api_key() -> bool:
|
||||
from codeflash.cli_cmds.console import logger
|
||||
|
||||
try:
|
||||
get_codeflash_api_key()
|
||||
except OSError:
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
|
|||
|
||||
import libcst as cst
|
||||
|
||||
import codeflash.code_utils._libcst_cache # noqa: F401
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import sys
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.compat import LF
|
||||
from codeflash.either import Failure, Success
|
||||
|
||||
|
|
@ -41,6 +40,8 @@ def is_powershell() -> bool:
|
|||
2. COMSPEC pointing to powershell.exe
|
||||
3. TERM_PROGRAM indicating Windows Terminal (often uses PowerShell)
|
||||
"""
|
||||
from codeflash.cli_cmds.console import logger
|
||||
|
||||
if os.name != "nt":
|
||||
return False
|
||||
|
||||
|
|
@ -72,6 +73,8 @@ def is_powershell() -> bool:
|
|||
|
||||
def read_api_key_from_shell_config() -> Optional[str]:
|
||||
"""Read API key from shell configuration file."""
|
||||
from codeflash.cli_cmds.console import logger
|
||||
|
||||
shell_rc_path = get_shell_rc_path()
|
||||
# Ensure shell_rc_path is a Path object for consistent handling
|
||||
if not isinstance(shell_rc_path, Path):
|
||||
|
|
@ -127,6 +130,8 @@ def get_api_key_export_line(api_key: str) -> str:
|
|||
|
||||
def save_api_key_to_rc(api_key: str) -> Result[str, str]:
|
||||
"""Save API key to the appropriate shell configuration file."""
|
||||
from codeflash.cli_cmds.console import logger
|
||||
|
||||
shell_rc_path = get_shell_rc_path()
|
||||
# Ensure shell_rc_path is a Path object for consistent handling
|
||||
if not isinstance(shell_rc_path, Path):
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from rich.syntax import Syntax
|
|||
from rich.text import Text
|
||||
from rich.tree import Tree
|
||||
|
||||
import codeflash.code_utils._libcst_cache # noqa: F401
|
||||
from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient
|
||||
from codeflash.api.cfapi import add_code_context_hash, create_staging, get_cfapi_base_urls, mark_optimization_success
|
||||
from codeflash.benchmarking.utils import process_benchmark_data
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any
|
|||
|
||||
import libcst as cst
|
||||
|
||||
import codeflash.code_utils._libcst_cache # noqa: F401
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages
|
||||
from codeflash.code_utils.config_consts import (
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Optional, Union
|
|||
|
||||
import libcst as cst
|
||||
|
||||
import codeflash.code_utils._libcst_cache # noqa: F401
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.languages import current_language
|
||||
from codeflash.languages.base import Language
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from libcst.codemod import CodemodContext
|
|||
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
|
||||
from libcst.helpers import calculate_module_and_package
|
||||
|
||||
import codeflash.code_utils._libcst_cache # noqa: F401
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.config_consts import MAX_CONTEXT_LEN_REVIEW
|
||||
from codeflash.languages.base import Language
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, TypeVar
|
|||
import libcst as cst
|
||||
from libcst.metadata import PositionProvider
|
||||
|
||||
import codeflash.code_utils._libcst_cache # noqa: F401
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.config_parser import find_conftest_files
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import libcst as cst
|
|||
from libcst import MetadataWrapper
|
||||
from libcst.metadata import PositionProvider
|
||||
|
||||
import codeflash.code_utils._libcst_cache # noqa: F401
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.time_utils import format_perf, format_time
|
||||
from codeflash.models.models import GeneratedTests, GeneratedTestsList
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Union
|
|||
|
||||
import libcst as cst
|
||||
|
||||
import codeflash.code_utils._libcst_cache # noqa: F401
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any
|
|||
|
||||
import libcst as cst
|
||||
|
||||
import codeflash.code_utils._libcst_cache # noqa: F401
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import (
|
||||
CodeContext,
|
||||
|
|
|
|||
|
|
@ -29,9 +29,36 @@ def main() -> None:
|
|||
print(f"Codeflash version {__version__}")
|
||||
return
|
||||
|
||||
from codeflash.cli_cmds.cli import parse_args
|
||||
|
||||
args = parse_args()
|
||||
|
||||
# Auth commands skip banner, telemetry, and version check entirely
|
||||
if args.command == "auth":
|
||||
from codeflash.cli_cmds.cmd_auth import auth_login, auth_status
|
||||
|
||||
if args.auth_command == "login":
|
||||
auth_login()
|
||||
elif args.auth_command == "status":
|
||||
auth_status()
|
||||
else:
|
||||
from codeflash.code_utils.code_utils import exit_with_message
|
||||
|
||||
exit_with_message("Usage: codeflash auth {login,status}", error_on_exit=True)
|
||||
return
|
||||
|
||||
# Compare command only needs its own imports
|
||||
if args.command == "compare":
|
||||
print_codeflash_banner()
|
||||
from codeflash.cli_cmds.cmd_compare import run_compare
|
||||
|
||||
run_compare(args)
|
||||
return
|
||||
|
||||
# All other commands need the full stack
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
|
||||
from codeflash.cli_cmds.cli import process_pyproject_config
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.checkpoint import ask_should_use_checkpoint_get_functions
|
||||
from codeflash.code_utils.config_parser import parse_config_file
|
||||
|
|
@ -39,11 +66,7 @@ def main() -> None:
|
|||
from codeflash.telemetry import posthog_cf
|
||||
from codeflash.telemetry.sentry import init_sentry
|
||||
|
||||
args = parse_args()
|
||||
if args.command != "auth":
|
||||
print_codeflash_banner()
|
||||
|
||||
# Check for newer version for all commands
|
||||
print_codeflash_banner()
|
||||
check_for_newer_minor_version()
|
||||
|
||||
if args.command:
|
||||
|
|
@ -54,18 +77,7 @@ def main() -> None:
|
|||
init_sentry(enabled=not disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(enabled=not disable_telemetry)
|
||||
|
||||
if args.command == "auth":
|
||||
from codeflash.cli_cmds.cmd_auth import auth_login, auth_status
|
||||
|
||||
if args.auth_command == "login":
|
||||
auth_login()
|
||||
elif args.auth_command == "status":
|
||||
auth_status()
|
||||
else:
|
||||
from codeflash.code_utils.code_utils import exit_with_message
|
||||
|
||||
exit_with_message("Usage: codeflash auth {login,status}", error_on_exit=True)
|
||||
elif args.command == "init":
|
||||
if args.command == "init":
|
||||
from codeflash.cli_cmds.cmd_init import init_codeflash
|
||||
|
||||
init_codeflash()
|
||||
|
|
@ -77,10 +89,6 @@ def main() -> None:
|
|||
from codeflash.cli_cmds.extension import install_vscode_extension
|
||||
|
||||
install_vscode_extension()
|
||||
elif args.command == "compare":
|
||||
from codeflash.cli_cmds.cmd_compare import run_compare
|
||||
|
||||
run_compare(args)
|
||||
elif args.command == "optimize":
|
||||
from codeflash.tracer import main as tracer_main
|
||||
|
||||
|
|
|
|||
|
|
@ -1,37 +1,26 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import re
|
||||
import sys
|
||||
from collections import Counter, defaultdict
|
||||
from collections.abc import Collection
|
||||
from enum import Enum, IntEnum
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING
|
||||
from pathlib import Path
|
||||
from re import Pattern
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, cast
|
||||
|
||||
import libcst as cst
|
||||
from rich.tree import Tree
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from codeflash.cli_cmds.console import DEBUG_MODE, lsp_log
|
||||
from codeflash.languages.registry import get_language_support
|
||||
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table
|
||||
from codeflash.lsp.lsp_message import LspMarkdownMessage
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
import enum
|
||||
import re
|
||||
import sys
|
||||
from collections.abc import Collection
|
||||
from enum import Enum, IntEnum
|
||||
from pathlib import Path
|
||||
from re import Pattern
|
||||
from typing import Any, NamedTuple, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.code_utils.code_utils import diff_length, module_name_from_file_path, validate_python_code
|
||||
from codeflash.code_utils.env_utils import is_end_to_end
|
||||
from codeflash.verification.comparator import comparator
|
||||
import libcst as cst
|
||||
from rich.tree import Tree
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -254,6 +243,8 @@ class CodeString(BaseModel):
|
|||
def validate_code_syntax(self) -> CodeString:
|
||||
"""Validate code syntax for the specified language."""
|
||||
if self.language == "python":
|
||||
from codeflash.code_utils.code_utils import validate_python_code
|
||||
|
||||
validate_python_code(self.code)
|
||||
else:
|
||||
from codeflash.languages.registry import get_language_support
|
||||
|
|
@ -267,6 +258,8 @@ class CodeString(BaseModel):
|
|||
|
||||
def get_comment_prefix(file_path: Path) -> str:
|
||||
"""Get the comment prefix for a given language."""
|
||||
from codeflash.languages.registry import get_language_support
|
||||
|
||||
support = get_language_support(file_path)
|
||||
return support.comment_prefix
|
||||
|
||||
|
|
@ -565,6 +558,8 @@ class CandidateEvaluationContext:
|
|||
self.optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown
|
||||
|
||||
# Update to shorter code if this candidate has a shorter diff
|
||||
from codeflash.code_utils.code_utils import diff_length
|
||||
|
||||
new_diff_len = diff_length(candidate.source_code.flat, original_flat_code)
|
||||
if new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"]:
|
||||
self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code
|
||||
|
|
@ -574,6 +569,8 @@ class CandidateEvaluationContext:
|
|||
self, normalized_code: str, candidate: OptimizedCandidate, original_flat_code: str
|
||||
) -> None:
|
||||
"""Register a new candidate that hasn't been seen before."""
|
||||
from codeflash.code_utils.code_utils import diff_length
|
||||
|
||||
self.ast_code_to_id[normalized_code] = {
|
||||
"optimization_id": candidate.optimization_id,
|
||||
"shorter_source_code": candidate.source_code,
|
||||
|
|
@ -670,6 +667,9 @@ class CoverageData:
|
|||
def log_coverage(self) -> None:
|
||||
from rich.tree import Tree
|
||||
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.code_utils.env_utils import is_end_to_end
|
||||
|
||||
tree = Tree("Test Coverage Results")
|
||||
tree.add(f"Main Function: {self.main_func_coverage.name}: {self.coverage:.2f}%")
|
||||
if self.dependent_func_coverage:
|
||||
|
|
@ -769,12 +769,16 @@ class InvocationId:
|
|||
)
|
||||
|
||||
def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Optional[cst.FunctionDef]:
|
||||
import libcst as cst
|
||||
|
||||
for stmt in class_node.body.body:
|
||||
if isinstance(stmt, cst.FunctionDef) and stmt.name.value == func_name:
|
||||
return stmt
|
||||
return None
|
||||
|
||||
def get_src_code(self, test_path: Path) -> Optional[str]:
|
||||
import libcst as cst
|
||||
|
||||
if not test_path.exists():
|
||||
return None
|
||||
try:
|
||||
|
|
@ -856,6 +860,8 @@ class TestResults(BaseModel): # noqa: PLW1641
|
|||
unique_id = function_test_invocation.unique_invocation_loop_id
|
||||
test_result_idx = self.test_result_idx
|
||||
if unique_id in test_result_idx:
|
||||
from codeflash.cli_cmds.console import DEBUG_MODE, logger
|
||||
|
||||
if DEBUG_MODE:
|
||||
logger.warning(f"Test result with id {unique_id} already exists. SKIPPING")
|
||||
return
|
||||
|
|
@ -876,6 +882,8 @@ class TestResults(BaseModel): # noqa: PLW1641
|
|||
self, benchmark_keys: list[BenchmarkKey], benchmark_replay_test_dir: Path, project_root: Path
|
||||
) -> dict[BenchmarkKey, TestResults]:
|
||||
"""Group TestResults by benchmark for calculating improvements for each benchmark."""
|
||||
from codeflash.code_utils.code_utils import module_name_from_file_path
|
||||
|
||||
test_results_by_benchmark = defaultdict(TestResults)
|
||||
benchmark_module_path = {}
|
||||
for benchmark_key in benchmark_keys:
|
||||
|
|
@ -929,9 +937,17 @@ class TestResults(BaseModel): # noqa: PLW1641
|
|||
|
||||
@staticmethod
|
||||
def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree:
|
||||
from rich.tree import Tree
|
||||
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
|
||||
tree = Tree(title)
|
||||
|
||||
if is_LSP_enabled():
|
||||
from codeflash.cli_cmds.console import lsp_log
|
||||
from codeflash.lsp.helpers import report_to_markdown_table
|
||||
from codeflash.lsp.lsp_message import LspMarkdownMessage
|
||||
|
||||
# Build markdown table
|
||||
markdown = report_to_markdown_table(report, title)
|
||||
lsp_log(LspMarkdownMessage(markdown=markdown))
|
||||
|
|
@ -946,6 +962,8 @@ class TestResults(BaseModel): # noqa: PLW1641
|
|||
return tree
|
||||
|
||||
def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]:
|
||||
from codeflash.cli_cmds.console import logger
|
||||
|
||||
# Efficient single traversal, directly accumulating into a dict.
|
||||
# can track mins here and only sums can be return in total_passed_runtime
|
||||
by_id: dict[InvocationId, list[int]] = {}
|
||||
|
|
@ -1025,6 +1043,8 @@ class TestResults(BaseModel): # noqa: PLW1641
|
|||
return bool(self.test_results)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
from codeflash.verification.comparator import comparator
|
||||
|
||||
# Unordered comparison
|
||||
if type(self) is not type(other):
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -74,6 +74,27 @@ _DICT_KEYS_TYPE = type({}.keys())
|
|||
_DICT_VALUES_TYPE = type({}.values())
|
||||
_DICT_ITEMS_TYPE = type({}.items())
|
||||
|
||||
_IDENTITY_EQ_TYPES: frozenset[type[Any]] = frozenset(
|
||||
{
|
||||
int,
|
||||
bool,
|
||||
complex,
|
||||
type(None),
|
||||
type(Ellipsis),
|
||||
decimal.Decimal,
|
||||
set,
|
||||
bytes,
|
||||
bytearray,
|
||||
memoryview,
|
||||
frozenset,
|
||||
type,
|
||||
range,
|
||||
slice,
|
||||
OrderedDict,
|
||||
types.GenericAlias,
|
||||
}
|
||||
)
|
||||
|
||||
_EQUALITY_TYPES = (
|
||||
int,
|
||||
bool,
|
||||
|
|
@ -184,32 +205,61 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
|
|||
|
||||
return False
|
||||
|
||||
if type(orig) is not type(new):
|
||||
type_obj = type(orig)
|
||||
new_type_obj = type(new)
|
||||
orig_type = type(orig)
|
||||
if orig_type is not type(new):
|
||||
# distinct type objects are created at runtime, even if the class code is exactly the same, so we can only compare the names
|
||||
if type_obj.__name__ != new_type_obj.__name__ or type_obj.__qualname__ != new_type_obj.__qualname__:
|
||||
if orig_type.__name__ != type(new).__name__ or orig_type.__qualname__ != type(new).__qualname__:
|
||||
return False
|
||||
|
||||
# Fast-path: type identity checks for the most common return-value types.
|
||||
# `orig_type is T` is a single pointer comparison — cheaper than frozenset hash
|
||||
# lookup or isinstance MRO traversal — and these 4 types dominate real workloads.
|
||||
if orig_type is str:
|
||||
if orig == new:
|
||||
return True
|
||||
if _is_temp_path(orig) and _is_temp_path(new):
|
||||
return _normalize_temp_path(orig) == _normalize_temp_path(new)
|
||||
return False
|
||||
if orig_type is list or orig_type is tuple:
|
||||
if len(orig) != len(new):
|
||||
return False
|
||||
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))
|
||||
if orig_type is dict:
|
||||
if superset_obj:
|
||||
return all(k in new and comparator(v, new[k], superset_obj) for k, v in orig.items())
|
||||
if len(orig) != len(new):
|
||||
return False
|
||||
for key in orig:
|
||||
if key not in new:
|
||||
return False
|
||||
if not comparator(orig[key], new[key], superset_obj):
|
||||
return False
|
||||
return True
|
||||
if orig_type is float:
|
||||
if math.isnan(orig) and math.isnan(new):
|
||||
return True
|
||||
return math.isclose(orig, new)
|
||||
# O(1) frozenset lookup for remaining common types (int, bool, None, Decimal, etc.)
|
||||
if orig_type in _IDENTITY_EQ_TYPES:
|
||||
return orig == new
|
||||
|
||||
# Slower isinstance path for subclasses (deque, ChainMap, etc.)
|
||||
if isinstance(orig, (list, tuple, deque, ChainMap)):
|
||||
if len(orig) != len(new):
|
||||
return False
|
||||
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))
|
||||
|
||||
# Handle strings separately to normalize temp paths
|
||||
# Handle string subclasses separately to normalize temp paths
|
||||
if isinstance(orig, str):
|
||||
if orig == new:
|
||||
return True
|
||||
# If strings differ, check if they're temp paths that differ only in session number
|
||||
if _is_temp_path(orig) and _is_temp_path(new):
|
||||
return _normalize_temp_path(orig) == _normalize_temp_path(new)
|
||||
return False
|
||||
|
||||
# enum.Enum subclasses and UnionType fall through from the frozenset fast-path
|
||||
if isinstance(orig, _EQUALITY_TYPES):
|
||||
return orig == new
|
||||
if isinstance(orig, float):
|
||||
if math.isnan(orig) and math.isnan(new):
|
||||
return True
|
||||
return math.isclose(orig, new)
|
||||
|
||||
# Handle weak references (e.g., found in torch.nn.LSTM/GRU modules)
|
||||
if isinstance(orig, weakref.ref):
|
||||
|
|
|
|||
0
tests/benchmarks/__init__.py
Normal file
0
tests/benchmarks/__init__.py
Normal file
|
|
@ -1,31 +1,18 @@
|
|||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
||||
|
||||
def test_benchmark_extract(benchmark) -> None:
|
||||
file_path = Path(__file__).parent.parent.parent.resolve() / "codeflash"
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=file_path.resolve(),
|
||||
disable_telemetry=True,
|
||||
tests_root=(file_path / "tests").resolve(),
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
test_project_root=Path.cwd(),
|
||||
)
|
||||
)
|
||||
project_root = Path(__file__).parent.parent.parent.resolve() / "codeflash"
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="replace_function_and_helpers_with_optimized_code",
|
||||
file_path=file_path / "languages" / "function_optimizer.py",
|
||||
file_path=project_root / "languages" / "function_optimizer.py",
|
||||
parents=[FunctionParent(name="FunctionOptimizer", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
benchmark(get_code_optimization_context, function_to_optimize, opt.args.project_root)
|
||||
benchmark(get_code_optimization_context, function_to_optimize, project_root)
|
||||
|
|
|
|||
133
tests/benchmarks/test_benchmark_comparator.py
Normal file
133
tests/benchmarks/test_benchmark_comparator.py
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
"""Benchmark comparator type dispatch performance.
|
||||
|
||||
Exercises the fast-path frozenset lookup vs isinstance MRO traversal
|
||||
across realistic return value shapes: primitives, nested containers,
|
||||
and mixed-type structures typical of real optimization verification.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from decimal import Decimal
|
||||
|
||||
from codeflash.verification.comparator import comparator
|
||||
|
||||
# --- Test data: realistic return value shapes ---
|
||||
|
||||
# 1. Flat primitives (int, bool, None, str, float, bytes) — the fast-path sweet spot
|
||||
_PRIMITIVES_A = [
|
||||
42,
|
||||
True,
|
||||
None,
|
||||
3.14,
|
||||
"hello",
|
||||
b"bytes",
|
||||
0,
|
||||
False,
|
||||
"",
|
||||
1.0,
|
||||
-1,
|
||||
None,
|
||||
True,
|
||||
99,
|
||||
"world",
|
||||
b"\x00\x01",
|
||||
2**31,
|
||||
0.0,
|
||||
False,
|
||||
None,
|
||||
]
|
||||
_PRIMITIVES_B = list(_PRIMITIVES_A)
|
||||
|
||||
# 2. Nested dict of lists (common return value shape: API responses, parsed configs)
|
||||
_NESTED_DICT_A = {
|
||||
"users": [{"id": i, "name": f"user_{i}", "active": i % 2 == 0, "score": i * 1.5} for i in range(50)],
|
||||
"metadata": {"total": 50, "page": 1, "has_next": True},
|
||||
"tags": [f"tag_{i}" for i in range(20)],
|
||||
"config": {"timeout": 30, "retries": 3, "debug": False, "threshold": Decimal("0.95")},
|
||||
}
|
||||
_NESTED_DICT_B = {
|
||||
"users": [{"id": i, "name": f"user_{i}", "active": i % 2 == 0, "score": i * 1.5} for i in range(50)],
|
||||
"metadata": {"total": 50, "page": 1, "has_next": True},
|
||||
"tags": [f"tag_{i}" for i in range(20)],
|
||||
"config": {"timeout": 30, "retries": 3, "debug": False, "threshold": Decimal("0.95")},
|
||||
}
|
||||
|
||||
# 3. List of tuples (common: database rows, CSV data)
|
||||
_ROWS_A = [(i, f"row_{i}", i * 0.1, i % 3 == 0, None if i % 5 == 0 else i) for i in range(200)]
|
||||
_ROWS_B = [(i, f"row_{i}", i * 0.1, i % 3 == 0, None if i % 5 == 0 else i) for i in range(200)]
|
||||
|
||||
|
||||
# 4. Deeply nested structure (worst case for recursive comparator)
|
||||
def _make_deep(depth: int) -> dict:
|
||||
if depth == 0:
|
||||
return {"leaf": True, "value": 42, "items": [1, 2, 3], "label": "end"}
|
||||
return {"level": depth, "child": _make_deep(depth - 1), "siblings": list(range(depth))}
|
||||
|
||||
|
||||
_DEEP_A = _make_deep(15)
|
||||
_DEEP_B = _make_deep(15)
|
||||
|
||||
# 5. Mixed identity types (frozenset, range, complex, Decimal, OrderedDict, bytes, bytearray, memoryview, type(None))
|
||||
_IDENTITY_TYPES_A = [
|
||||
frozenset({1, 2, 3}),
|
||||
range(100),
|
||||
complex(1, 2),
|
||||
Decimal("3.14"),
|
||||
OrderedDict(a=1, b=2),
|
||||
b"binary",
|
||||
bytearray(b"mutable"),
|
||||
memoryview(b"view"),
|
||||
type(None),
|
||||
True,
|
||||
42,
|
||||
None,
|
||||
] * 10
|
||||
_IDENTITY_TYPES_B = list(_IDENTITY_TYPES_A)
|
||||
|
||||
|
||||
def _compare_all_primitives() -> None:
    """Call ``comparator`` on each aligned pair from the two primitive fixtures.

    The comparison results are discarded; the function exists only as a
    benchmark target.
    """
    for left, right in zip(_PRIMITIVES_A, _PRIMITIVES_B):
        comparator(left, right)
|
||||
|
||||
|
||||
def _compare_nested_dict() -> None:
    """Call ``comparator`` once on the nested-dict fixture pair, discarding the result."""
    comparator(_NESTED_DICT_A, _NESTED_DICT_B)
|
||||
|
||||
|
||||
def _compare_rows() -> None:
    """Call ``comparator`` once on the two row-list fixtures, discarding the result."""
    comparator(_ROWS_A, _ROWS_B)
|
||||
|
||||
|
||||
def _compare_deep() -> None:
    """Call ``comparator`` once on the two deeply nested fixtures, discarding the result."""
    comparator(_DEEP_A, _DEEP_B)
|
||||
|
||||
|
||||
def _compare_identity_types() -> None:
    """Call ``comparator`` on each aligned pair from the identity-type fixtures.

    The comparison results are discarded; the function exists only as a
    benchmark target.
    """
    for pair in zip(_IDENTITY_TYPES_A, _IDENTITY_TYPES_B):
        comparator(*pair)
|
||||
|
||||
|
||||
def test_benchmark_comparator_primitives(benchmark) -> None:
    """Benchmark one pass over 20 flat primitive pairs (int, bool, None, str, float, bytes)."""
    benchmark(_compare_all_primitives)
|
||||
|
||||
|
||||
def test_benchmark_comparator_nested_dict(benchmark) -> None:
    """Benchmark comparing nested dicts: a 50-element user list plus metadata, tags and config."""
    benchmark(_compare_nested_dict)
|
||||
|
||||
|
||||
def test_benchmark_comparator_rows(benchmark) -> None:
    """Benchmark comparing 200 row tuples of (int, str, float, bool, Optional[int])."""
    benchmark(_compare_rows)
|
||||
|
||||
|
||||
def test_benchmark_comparator_deep(benchmark) -> None:
    """Benchmark comparing two 15-level deep nested dict structures."""
    benchmark(_compare_deep)
|
||||
|
||||
|
||||
def test_benchmark_comparator_identity_types(benchmark) -> None:
    """Benchmark 120 comparisons over the mixed identity-type fixture.

    The fixture repeats 12 values ten times: frozenset, range, complex,
    Decimal, OrderedDict, bytes, bytearray, memoryview, type(None), bool,
    int and None.
    """
    benchmark(_compare_identity_types)
|
||||
75
tests/benchmarks/test_benchmark_libcst_multi_file.py
Normal file
75
tests/benchmarks/test_benchmark_libcst_multi_file.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
"""Benchmark libcst visitor performance across many files.
|
||||
|
||||
Exercises the visitor-heavy codepaths that benefit from the libcst dispatch
|
||||
table cache: discover_functions + get_code_optimization_context on multiple
|
||||
real source files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.languages.python.support import PythonSupport
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
# Real source files from the codeflash codebase, chosen for size and visitor diversity.
|
||||
_CODEFLASH_ROOT = Path(__file__).parent.parent.parent.resolve() / "codeflash"
|
||||
|
||||
_SOURCE_FILES: list[Path] = [
|
||||
_CODEFLASH_ROOT / "languages" / "function_optimizer.py",
|
||||
_CODEFLASH_ROOT / "languages" / "python" / "context" / "code_context_extractor.py",
|
||||
_CODEFLASH_ROOT / "languages" / "python" / "support.py",
|
||||
_CODEFLASH_ROOT / "languages" / "python" / "static_analysis" / "code_extractor.py",
|
||||
_CODEFLASH_ROOT / "languages" / "python" / "static_analysis" / "code_replacer.py",
|
||||
_CODEFLASH_ROOT / "code_utils" / "instrument_existing_tests.py",
|
||||
_CODEFLASH_ROOT / "benchmarking" / "compare.py",
|
||||
_CODEFLASH_ROOT / "models" / "models.py",
|
||||
_CODEFLASH_ROOT / "discovery" / "discover_unit_tests.py",
|
||||
_CODEFLASH_ROOT / "languages" / "base.py",
|
||||
]
|
||||
|
||||
# For each file, pick one function or method to extract context for.
|
||||
# Entries are (file_path, class_name, function_name); class_name=None means module-level.
|
||||
_TARGETS: list[tuple[Path, str | None, str]] = [
|
||||
(_SOURCE_FILES[0], "FunctionOptimizer", "replace_function_and_helpers_with_optimized_code"),
|
||||
(_SOURCE_FILES[1], None, "get_code_optimization_context"),
|
||||
(_SOURCE_FILES[2], "PythonSupport", "discover_functions"),
|
||||
(_SOURCE_FILES[3], None, "add_global_assignments"),
|
||||
(_SOURCE_FILES[4], None, "replace_functions_in_file"),
|
||||
(_SOURCE_FILES[5], None, "inject_profiling_into_existing_test"),
|
||||
(_SOURCE_FILES[6], None, "compare_branches"),
|
||||
(_SOURCE_FILES[7], None, "get_comment_prefix"),
|
||||
(_SOURCE_FILES[8], None, "discover_unit_tests"),
|
||||
(_SOURCE_FILES[9], None, "convert_parents_to_tuple"),
|
||||
]
|
||||
|
||||
|
||||
def _discover_all() -> None:
    """Read each file in ``_SOURCE_FILES`` as UTF-8 and run ``discover_functions`` on it.

    A single ``PythonSupport`` instance is reused for all files; the
    discovery results are discarded.
    """
    support = PythonSupport()
    for path in _SOURCE_FILES:
        text = path.read_text(encoding="utf-8")
        support.discover_functions(source=text, file_path=path)
|
||||
|
||||
|
||||
def _extract_all_contexts() -> None:
    """Call ``get_code_optimization_context`` for every entry in ``_TARGETS``.

    Each entry is turned into a ``FunctionToOptimize``. Entries with a class
    name get a single ``ClassDef`` parent; the others get no parents. The
    project root passed along is the parent directory of ``_CODEFLASH_ROOT``.
    Results are discarded.
    """
    root = _CODEFLASH_ROOT.parent
    for path, owner, name in _TARGETS:
        if owner:
            chain = [FunctionParent(name=owner, type="ClassDef")]
        else:
            chain = []
        target = FunctionToOptimize(
            function_name=name, file_path=path, parents=chain, starting_line=None, ending_line=None
        )
        get_code_optimization_context(target, root)
|
||||
|
||||
|
||||
def test_benchmark_discover_functions_multi_file(benchmark) -> None:
    """Benchmark function discovery across the 10 files in ``_SOURCE_FILES``."""
    benchmark(_discover_all)
|
||||
|
||||
|
||||
def test_benchmark_extract_context_multi_file(benchmark) -> None:
    """Benchmark code-optimization-context extraction for the 10 entries in ``_TARGETS``."""
    benchmark(_extract_all_contexts)
|
||||
56
tests/benchmarks/test_benchmark_libcst_pipeline.py
Normal file
56
tests/benchmarks/test_benchmark_libcst_pipeline.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
"""Benchmark the full libcst-heavy pipeline on a single file.
|
||||
|
||||
Runs discover → extract context → replace functions → add global assignments
|
||||
in sequence, exercising ~15 distinct visitor/transformer classes in one pass.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.languages.python.static_analysis.code_extractor import add_global_assignments
|
||||
from codeflash.languages.python.static_analysis.code_replacer import replace_functions_in_file
|
||||
from codeflash.languages.python.support import PythonSupport
|
||||
|
||||
_CODEFLASH_ROOT = Path(__file__).parent.parent.parent.resolve() / "codeflash"
|
||||
_PROJECT_ROOT = _CODEFLASH_ROOT.parent
|
||||
|
||||
# Target: a real, non-trivial file with classes and module-level functions.
|
||||
_TARGET_FILE = _CODEFLASH_ROOT / "languages" / "python" / "static_analysis" / "code_extractor.py"
|
||||
_TARGET_FUNC = "add_global_assignments"
|
||||
|
||||
# A second file to serve as "optimized" source for replace/merge steps.
|
||||
_SECOND_FILE = _CODEFLASH_ROOT / "languages" / "python" / "static_analysis" / "code_replacer.py"
|
||||
|
||||
|
||||
def _run_pipeline() -> None:
    """Simulate a single-file optimization pass through the full visitor pipeline.

    Reads ``_TARGET_FILE`` and ``_SECOND_FILE`` as UTF-8, then runs four steps
    in order: discover functions, extract the optimization context, replace
    functions, and add global assignments. Every return value is discarded;
    the calls are made only so the benchmark can time them.
    """
    source = _TARGET_FILE.read_text(encoding="utf-8")
    source2 = _SECOND_FILE.read_text(encoding="utf-8")

    # 1. Discover functions (FunctionVisitor + MetadataWrapper)
    # The result is not bound: later steps always use the fixed _TARGET_FUNC name.
    PythonSupport().discover_functions(source=source, file_path=_TARGET_FILE)

    # 2. Extract code optimization context (multiple collectors + dependency resolver)
    fto = FunctionToOptimize(
        function_name=_TARGET_FUNC, file_path=_TARGET_FILE, parents=[], starting_line=None, ending_line=None
    )
    get_code_optimization_context(fto, _PROJECT_ROOT)

    # 3. Replace functions (GlobalFunctionCollector + GlobalFunctionTransformer)
    # Only the module-level _TARGET_FUNC is replaced; the second file's source
    # stands in as the "optimized" code.
    replace_functions_in_file(
        source_code=source,
        original_function_names=[_TARGET_FUNC],
        optimized_code=source2,
        preexisting_objects=set(),
    )

    # 4. Add global assignments (6 visitors/transformers)
    add_global_assignments(source2, source)
|
||||
|
||||
|
||||
def test_benchmark_full_pipeline(benchmark) -> None:
    """Benchmark the discover → extract → replace → merge pipeline on a single file."""
    benchmark(_run_pipeline)
|
||||
|
|
@ -9,8 +9,8 @@ from codeflash.either import Success
|
|||
|
||||
|
||||
class TestAuthLogin:
|
||||
@patch("codeflash.cli_cmds.cmd_auth.get_codeflash_api_key")
|
||||
@patch("codeflash.cli_cmds.cmd_auth.console")
|
||||
@patch("codeflash.code_utils.env_utils.get_codeflash_api_key")
|
||||
@patch("codeflash.cli_cmds.console.console")
|
||||
def test_existing_api_key_skips_oauth(self, mock_console: MagicMock, mock_get_key: MagicMock) -> None:
|
||||
mock_get_key.return_value = "cf-test1234abcd"
|
||||
|
||||
|
|
@ -21,19 +21,19 @@ class TestAuthLogin:
|
|||
"To re-authenticate, unset [bold]CODEFLASH_API_KEY[/bold] and run this command again."
|
||||
)
|
||||
|
||||
@patch("codeflash.cli_cmds.cmd_auth.get_codeflash_api_key")
|
||||
@patch("codeflash.cli_cmds.cmd_auth.console")
|
||||
@patch("codeflash.code_utils.env_utils.get_codeflash_api_key")
|
||||
@patch("codeflash.cli_cmds.console.console")
|
||||
def test_existing_api_key_oserror_treated_as_missing(
|
||||
self, mock_console: MagicMock, mock_get_key: MagicMock
|
||||
) -> None:
|
||||
mock_get_key.side_effect = OSError("permission denied")
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
with patch("codeflash.cli_cmds.cmd_auth.perform_oauth_signin", return_value=None):
|
||||
with patch("codeflash.cli_cmds.oauth_handler.perform_oauth_signin", return_value=None):
|
||||
auth_login()
|
||||
|
||||
@patch("codeflash.cli_cmds.cmd_auth.perform_oauth_signin")
|
||||
@patch("codeflash.cli_cmds.cmd_auth.get_codeflash_api_key", return_value="")
|
||||
@patch("codeflash.cli_cmds.oauth_handler.perform_oauth_signin")
|
||||
@patch("codeflash.code_utils.env_utils.get_codeflash_api_key", return_value="")
|
||||
def test_oauth_failure_exits_with_code_1(self, mock_get_key: MagicMock, mock_oauth: MagicMock) -> None:
|
||||
mock_oauth.return_value = None
|
||||
|
||||
|
|
@ -41,10 +41,10 @@ class TestAuthLogin:
|
|||
auth_login()
|
||||
|
||||
@patch("codeflash.cli_cmds.cmd_auth.os")
|
||||
@patch("codeflash.cli_cmds.cmd_auth.save_api_key_to_rc")
|
||||
@patch("codeflash.cli_cmds.cmd_auth.perform_oauth_signin")
|
||||
@patch("codeflash.cli_cmds.cmd_auth.get_codeflash_api_key", return_value="")
|
||||
@patch("codeflash.cli_cmds.cmd_auth.console")
|
||||
@patch("codeflash.code_utils.shell_utils.save_api_key_to_rc")
|
||||
@patch("codeflash.cli_cmds.oauth_handler.perform_oauth_signin")
|
||||
@patch("codeflash.code_utils.env_utils.get_codeflash_api_key", return_value="")
|
||||
@patch("codeflash.cli_cmds.console.console")
|
||||
def test_successful_oauth_saves_key(
|
||||
self,
|
||||
mock_console: MagicMock,
|
||||
|
|
@ -63,10 +63,10 @@ class TestAuthLogin:
|
|||
mock_console.print.assert_called_with("[green]Signed in successfully![/green]")
|
||||
|
||||
@patch("codeflash.cli_cmds.cmd_auth.os")
|
||||
@patch("codeflash.cli_cmds.cmd_auth.save_api_key_to_rc")
|
||||
@patch("codeflash.cli_cmds.cmd_auth.perform_oauth_signin")
|
||||
@patch("codeflash.cli_cmds.cmd_auth.get_codeflash_api_key", return_value="")
|
||||
@patch("codeflash.cli_cmds.cmd_auth.console")
|
||||
@patch("codeflash.code_utils.shell_utils.save_api_key_to_rc")
|
||||
@patch("codeflash.cli_cmds.oauth_handler.perform_oauth_signin")
|
||||
@patch("codeflash.code_utils.env_utils.get_codeflash_api_key", return_value="")
|
||||
@patch("codeflash.cli_cmds.console.console")
|
||||
def test_windows_oauth_saves_key(
|
||||
self,
|
||||
mock_console: MagicMock,
|
||||
|
|
|
|||
Loading…
Reference in a new issue