Merge pull request #2055 from codeflash-ai/perf/defer-cli-imports

perf: defer cli.py imports for 7.7x faster --help
This commit is contained in:
Kevin Turcios 2026-04-10 01:59:57 -05:00 committed by GitHub
commit 72a41a5665
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
27 changed files with 606 additions and 110 deletions

View file

@ -199,6 +199,8 @@ jobs:
run: |
uv run ruff check --fix . || true
uv run ruff format .
# uv-dynamic-versioning rewrites version.py on every `uv run` — discard those changes
git checkout HEAD -- codeflash/version.py codeflash-benchmark/codeflash_benchmark/version.py 2>/dev/null || true
- name: Commit and push fixes
run: |

0
benchmarks/__init__.py Normal file
View file

View file

@ -0,0 +1,72 @@
"""Benchmark CLI startup latency for codeflash compare --script mode.
Run from a worktree root. Installs deps via uv sync, then times several
CLI entry points and writes a JSON file mapping command names to median
wall-clock seconds.
Usage:
codeflash compare main codeflash/optimize \
--script "python benchmarks/bench_cli_startup.py" \
--script-output benchmarks/results.json
"""
from __future__ import annotations
import json
import os
import subprocess
import time
from pathlib import Path
WARMUP = 3
RUNS = 30
OUTPUT = os.environ.get("BENCH_OUTPUT", "benchmarks/results.json")
COMMANDS: dict[str, list[str]] = {
"version": ["uv", "run", "codeflash", "--version"],
"help": ["uv", "run", "codeflash", "--help"],
"auth_status": ["uv", "run", "codeflash", "auth", "status"],
"compare_help": ["uv", "run", "codeflash", "compare", "--help"],
}
def measure(cmd: list[str], warmup: int = WARMUP, runs: int = RUNS) -> float:
"""Return median wall-clock seconds for *cmd* over *runs* iterations."""
env = {**os.environ, "CODEFLASH_API_KEY": "bench_dummy_key"}
for _ in range(warmup):
subprocess.run(cmd, capture_output=True, check=False, env=env)
times: list[float] = []
for _ in range(runs):
t0 = time.perf_counter()
subprocess.run(cmd, capture_output=True, check=False, env=env)
times.append(time.perf_counter() - t0)
times.sort()
mid = len(times) // 2
return times[mid] if len(times) % 2 else (times[mid - 1] + times[mid]) / 2
def main() -> None:
# Ensure deps are installed in the worktree
subprocess.run(["uv", "sync"], check=True, capture_output=True)
results: dict[str, float] = {}
for name, cmd in COMMANDS.items():
print(f" {name}: ", end="", flush=True)
median = measure(cmd)
results[name] = round(median, 4)
print(f"{median * 1000:.0f} ms")
# Total = sum of medians (useful for a single summary number)
results["__total__"] = round(sum(results.values()), 4)
output_path = Path(OUTPUT)
output_path.parent.mkdir(parents=True, exist_ok=True)
with output_path.open("w") as f:
json.dump(results, f, indent=2)
print(f"\nResults written to {OUTPUT}")
if __name__ == "__main__":
main()

View file

@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional, Union
import libcst as cst
import codeflash.code_utils._libcst_cache # noqa: F401
from codeflash.code_utils.formatter import sort_imports
if TYPE_CHECKING:

View file

@ -5,15 +5,6 @@ from argparse import SUPPRESS, ArgumentParser, Namespace
from functools import lru_cache
from pathlib import Path
from codeflash.cli_cmds import logging_config
from codeflash.cli_cmds.console import apologize_and_exit, logger
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import exit_with_message, normalize_ignore_paths
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.languages.test_framework import set_current_test_framework
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.version import __version__ as version
def parse_args() -> Namespace:
parser = _build_parser()
@ -30,12 +21,17 @@ def parse_args() -> Namespace:
def process_and_validate_cmd_args(args: Namespace) -> Namespace:
from codeflash.cli_cmds import logging_config
from codeflash.cli_cmds.console import logger
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.git_utils import (
check_running_in_git_repo,
confirm_proceeding_with_no_git_repo,
get_repo_owner_and_name,
)
from codeflash.code_utils.github_utils import require_github_app_or_exit
from codeflash.version import __version__ as version
if args.server:
os.environ["CODEFLASH_AIS_SERVER"] = args.server
@ -85,6 +81,12 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
def process_pyproject_config(args: Namespace) -> Namespace:
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import exit_with_message, normalize_ignore_paths
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.languages.test_framework import set_current_test_framework
from codeflash.lsp.helpers import is_LSP_enabled
try:
pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
except ValueError as e:
@ -222,6 +224,9 @@ def project_root_from_module_root(module_root: Path, pyproject_file_path: Path)
def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
from codeflash.cli_cmds.console import apologize_and_exit, logger
from codeflash.code_utils.code_utils import exit_with_message
if hasattr(args, "all") or (hasattr(args, "file") and args.file):
no_pr = getattr(args, "no_pr", False)

View file

@ -2,17 +2,17 @@ from __future__ import annotations
import os
import click
from codeflash.cli_cmds.console import console
from codeflash.cli_cmds.oauth_handler import perform_oauth_signin
from codeflash.code_utils.env_utils import get_codeflash_api_key
from codeflash.code_utils.shell_utils import save_api_key_to_rc
from codeflash.either import is_successful
def auth_login() -> None:
"""Perform OAuth login and save the API key."""
import click
from codeflash.cli_cmds.console import console
from codeflash.cli_cmds.oauth_handler import perform_oauth_signin
from codeflash.code_utils.env_utils import get_codeflash_api_key
from codeflash.code_utils.shell_utils import save_api_key_to_rc
from codeflash.either import is_successful
try:
existing_api_key = get_codeflash_api_key()
except OSError:
@ -41,6 +41,9 @@ def auth_login() -> None:
def auth_status() -> None:
"""Check and display current authentication status."""
from codeflash.cli_cmds.console import console
from codeflash.code_utils.env_utils import get_codeflash_api_key
try:
api_key = get_codeflash_api_key()
except OSError:

View file

@ -0,0 +1,64 @@
"""Cache libcst visitor dispatch table construction.
libcst's ``MatcherDecoratableTransformer`` and
``MatcherDecoratableVisitor`` rebuild visitor dispatch tables on
every instantiation by iterating ``dir(self)`` (~600 attributes)
and calling ``getattr`` + ``inspect.ismethod`` on each. The
results depend only on the class, not the instance, so caching
by ``type(obj)`` is safe.
Import this module before any libcst visitors are instantiated
to install the cache.
"""
from __future__ import annotations
from typing import Any
import libcst.matchers._visitors as _mv
_visit_cache: dict[type, Any] = {}
_leave_cache: dict[type, Any] = {}
_matchers_cache: dict[type, Any] = {}
_original_visit = _mv._gather_constructed_visit_funcs # noqa: SLF001
_original_leave = _mv._gather_constructed_leave_funcs # noqa: SLF001
_original_matchers = _mv._gather_matchers # noqa: SLF001
def _cached_visit(obj: object) -> Any:
"""Return cached visit-function dispatch table for the object's class."""
cls = type(obj)
try:
return _visit_cache[cls]
except KeyError:
result = _original_visit(obj)
_visit_cache[cls] = result
return result
def _cached_leave(obj: object) -> Any:
"""Return cached leave-function dispatch table for the object's class."""
cls = type(obj)
try:
return _leave_cache[cls]
except KeyError:
result = _original_leave(obj)
_leave_cache[cls] = result
return result
def _cached_matchers(obj: object) -> Any:
"""Return cached matcher dispatch table for the object's class."""
cls = type(obj)
try:
return dict(_matchers_cache[cls])
except KeyError:
result = _original_matchers(obj)
_matchers_cache[cls] = result
return dict(result)
_mv._gather_constructed_visit_funcs = _cached_visit # noqa: SLF001
_mv._gather_constructed_leave_funcs = _cached_leave # noqa: SLF001
_mv._gather_matchers = _cached_matchers # noqa: SLF001

View file

@ -9,17 +9,16 @@ from functools import lru_cache
from pathlib import Path
from typing import Any, Optional
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.formatter import format_code
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc
from codeflash.languages.registry import get_language_support_by_common_formatters
from codeflash.lsp.helpers import is_LSP_enabled
def check_formatter_installed(
formatter_cmds: list[str], exit_on_failure: bool = True, language: str = "python"
) -> bool:
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.formatter import format_code
from codeflash.languages.registry import get_language_support_by_common_formatters
if not formatter_cmds or formatter_cmds[0] == "disabled":
return True
first_cmd = formatter_cmds[0]
@ -69,6 +68,8 @@ def check_formatter_installed(
@lru_cache(maxsize=1)
def get_codeflash_api_key() -> str:
from codeflash.cli_cmds.console import logger
# Check environment variable first
env_api_key = os.environ.get("CODEFLASH_API_KEY")
shell_api_key = read_api_key_from_shell_config()
@ -96,7 +97,8 @@ def get_codeflash_api_key() -> str:
# Prefer the shell configuration over environment variables for lsp,
# as the API key may change in the RC file during lsp runtime. Since the LSP client (extension) can restart
# within the same process, the environment variable could become outdated.
api_key = shell_api_key or env_api_key if is_LSP_enabled() else env_api_key or shell_api_key
is_lsp = os.getenv("CODEFLASH_LSP", default="false").lower() == "true"
api_key = shell_api_key or env_api_key if is_lsp else env_api_key or shell_api_key
api_secret_docs_message = "For more information, refer to the documentation at [https://docs.codeflash.ai/optimizing-with-codeflash/codeflash-github-actions#manual-setup]." # noqa
if not api_key:
@ -106,6 +108,8 @@ def get_codeflash_api_key() -> str:
f"{api_secret_docs_message}"
)
if is_repo_a_fork():
from codeflash.code_utils.code_utils import exit_with_message
msg = (
"Codeflash API key not detected in your environment. It appears you're running Codeflash from a GitHub fork.\n"
"For external contributors, please ensure you've added your own API key to your fork's repository secrets and set it as the CODEFLASH_API_KEY environment variable.\n"
@ -124,6 +128,8 @@ def get_codeflash_api_key() -> str:
def ensure_codeflash_api_key() -> bool:
from codeflash.cli_cmds.console import logger
try:
get_codeflash_api_key()
except OSError:

View file

@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
import libcst as cst
import codeflash.code_utils._libcst_cache # noqa: F401
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
from codeflash.code_utils.formatter import sort_imports

View file

@ -8,7 +8,6 @@ import sys
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.compat import LF
from codeflash.either import Failure, Success
@ -41,6 +40,8 @@ def is_powershell() -> bool:
2. COMSPEC pointing to powershell.exe
3. TERM_PROGRAM indicating Windows Terminal (often uses PowerShell)
"""
from codeflash.cli_cmds.console import logger
if os.name != "nt":
return False
@ -72,6 +73,8 @@ def is_powershell() -> bool:
def read_api_key_from_shell_config() -> Optional[str]:
"""Read API key from shell configuration file."""
from codeflash.cli_cmds.console import logger
shell_rc_path = get_shell_rc_path()
# Ensure shell_rc_path is a Path object for consistent handling
if not isinstance(shell_rc_path, Path):
@ -127,6 +130,8 @@ def get_api_key_export_line(api_key: str) -> str:
def save_api_key_to_rc(api_key: str) -> Result[str, str]:
"""Save API key to the appropriate shell configuration file."""
from codeflash.cli_cmds.console import logger
shell_rc_path = get_shell_rc_path()
# Ensure shell_rc_path is a Path object for consistent handling
if not isinstance(shell_rc_path, Path):

View file

@ -20,6 +20,7 @@ from rich.syntax import Syntax
from rich.text import Text
from rich.tree import Tree
import codeflash.code_utils._libcst_cache # noqa: F401
from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient
from codeflash.api.cfapi import add_code_context_hash, create_staging, get_cfapi_base_urls, mark_optimization_success
from codeflash.benchmarking.utils import process_benchmark_data

View file

@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any
import libcst as cst
import codeflash.code_utils._libcst_cache # noqa: F401
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages
from codeflash.code_utils.config_consts import (

View file

@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Optional, Union
import libcst as cst
import codeflash.code_utils._libcst_cache # noqa: F401
from codeflash.cli_cmds.console import logger
from codeflash.languages import current_language
from codeflash.languages.base import Language

View file

@ -11,6 +11,7 @@ from libcst.codemod import CodemodContext
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
from libcst.helpers import calculate_module_and_package
import codeflash.code_utils._libcst_cache # noqa: F401
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.config_consts import MAX_CONTEXT_LEN_REVIEW
from codeflash.languages.base import Language

View file

@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, TypeVar
import libcst as cst
from libcst.metadata import PositionProvider
import codeflash.code_utils._libcst_cache # noqa: F401
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.config_parser import find_conftest_files
from codeflash.code_utils.formatter import sort_imports

View file

@ -10,6 +10,7 @@ import libcst as cst
from libcst import MetadataWrapper
from libcst.metadata import PositionProvider
import codeflash.code_utils._libcst_cache # noqa: F401
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.time_utils import format_perf, format_time
from codeflash.models.models import GeneratedTests, GeneratedTestsList

View file

@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Union
import libcst as cst
import codeflash.code_utils._libcst_cache # noqa: F401
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.formatter import sort_imports

View file

@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any
import libcst as cst
import codeflash.code_utils._libcst_cache # noqa: F401
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import (
CodeContext,

View file

@ -29,9 +29,36 @@ def main() -> None:
print(f"Codeflash version {__version__}")
return
from codeflash.cli_cmds.cli import parse_args
args = parse_args()
# Auth commands skip banner, telemetry, and version check entirely
if args.command == "auth":
from codeflash.cli_cmds.cmd_auth import auth_login, auth_status
if args.auth_command == "login":
auth_login()
elif args.auth_command == "status":
auth_status()
else:
from codeflash.code_utils.code_utils import exit_with_message
exit_with_message("Usage: codeflash auth {login,status}", error_on_exit=True)
return
# Compare command only needs its own imports
if args.command == "compare":
print_codeflash_banner()
from codeflash.cli_cmds.cmd_compare import run_compare
run_compare(args)
return
# All other commands need the full stack
from pathlib import Path
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
from codeflash.cli_cmds.cli import process_pyproject_config
from codeflash.code_utils import env_utils
from codeflash.code_utils.checkpoint import ask_should_use_checkpoint_get_functions
from codeflash.code_utils.config_parser import parse_config_file
@ -39,11 +66,7 @@ def main() -> None:
from codeflash.telemetry import posthog_cf
from codeflash.telemetry.sentry import init_sentry
args = parse_args()
if args.command != "auth":
print_codeflash_banner()
# Check for newer version for all commands
print_codeflash_banner()
check_for_newer_minor_version()
if args.command:
@ -54,18 +77,7 @@ def main() -> None:
init_sentry(enabled=not disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(enabled=not disable_telemetry)
if args.command == "auth":
from codeflash.cli_cmds.cmd_auth import auth_login, auth_status
if args.auth_command == "login":
auth_login()
elif args.auth_command == "status":
auth_status()
else:
from codeflash.code_utils.code_utils import exit_with_message
exit_with_message("Usage: codeflash auth {login,status}", error_on_exit=True)
elif args.command == "init":
if args.command == "init":
from codeflash.cli_cmds.cmd_init import init_codeflash
init_codeflash()
@ -77,10 +89,6 @@ def main() -> None:
from codeflash.cli_cmds.extension import install_vscode_extension
install_vscode_extension()
elif args.command == "compare":
from codeflash.cli_cmds.cmd_compare import run_compare
run_compare(args)
elif args.command == "optimize":
from codeflash.tracer import main as tracer_main

View file

@ -1,37 +1,26 @@
from __future__ import annotations
import enum
import re
import sys
from collections import Counter, defaultdict
from collections.abc import Collection
from enum import Enum, IntEnum
from functools import lru_cache
from typing import TYPE_CHECKING
from pathlib import Path
from re import Pattern
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, cast
import libcst as cst
from rich.tree import Tree
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, model_validator
from pydantic.dataclasses import dataclass
from codeflash.cli_cmds.console import DEBUG_MODE, lsp_log
from codeflash.languages.registry import get_language_support
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table
from codeflash.lsp.lsp_message import LspMarkdownMessage
from codeflash.models.test_type import TestType
if TYPE_CHECKING:
from collections.abc import Iterator
import enum
import re
import sys
from collections.abc import Collection
from enum import Enum, IntEnum
from pathlib import Path
from re import Pattern
from typing import Any, NamedTuple, Optional, cast
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, model_validator
from pydantic.dataclasses import dataclass
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.code_utils import diff_length, module_name_from_file_path, validate_python_code
from codeflash.code_utils.env_utils import is_end_to_end
from codeflash.verification.comparator import comparator
import libcst as cst
from rich.tree import Tree
@dataclass(frozen=True)
@ -254,6 +243,8 @@ class CodeString(BaseModel):
def validate_code_syntax(self) -> CodeString:
"""Validate code syntax for the specified language."""
if self.language == "python":
from codeflash.code_utils.code_utils import validate_python_code
validate_python_code(self.code)
else:
from codeflash.languages.registry import get_language_support
@ -267,6 +258,8 @@ class CodeString(BaseModel):
def get_comment_prefix(file_path: Path) -> str:
"""Get the comment prefix for a given language."""
from codeflash.languages.registry import get_language_support
support = get_language_support(file_path)
return support.comment_prefix
@ -565,6 +558,8 @@ class CandidateEvaluationContext:
self.optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown
# Update to shorter code if this candidate has a shorter diff
from codeflash.code_utils.code_utils import diff_length
new_diff_len = diff_length(candidate.source_code.flat, original_flat_code)
if new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"]:
self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code
@ -574,6 +569,8 @@ class CandidateEvaluationContext:
self, normalized_code: str, candidate: OptimizedCandidate, original_flat_code: str
) -> None:
"""Register a new candidate that hasn't been seen before."""
from codeflash.code_utils.code_utils import diff_length
self.ast_code_to_id[normalized_code] = {
"optimization_id": candidate.optimization_id,
"shorter_source_code": candidate.source_code,
@ -670,6 +667,9 @@ class CoverageData:
def log_coverage(self) -> None:
from rich.tree import Tree
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.env_utils import is_end_to_end
tree = Tree("Test Coverage Results")
tree.add(f"Main Function: {self.main_func_coverage.name}: {self.coverage:.2f}%")
if self.dependent_func_coverage:
@ -769,12 +769,16 @@ class InvocationId:
)
def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Optional[cst.FunctionDef]:
import libcst as cst
for stmt in class_node.body.body:
if isinstance(stmt, cst.FunctionDef) and stmt.name.value == func_name:
return stmt
return None
def get_src_code(self, test_path: Path) -> Optional[str]:
import libcst as cst
if not test_path.exists():
return None
try:
@ -856,6 +860,8 @@ class TestResults(BaseModel): # noqa: PLW1641
unique_id = function_test_invocation.unique_invocation_loop_id
test_result_idx = self.test_result_idx
if unique_id in test_result_idx:
from codeflash.cli_cmds.console import DEBUG_MODE, logger
if DEBUG_MODE:
logger.warning(f"Test result with id {unique_id} already exists. SKIPPING")
return
@ -876,6 +882,8 @@ class TestResults(BaseModel): # noqa: PLW1641
self, benchmark_keys: list[BenchmarkKey], benchmark_replay_test_dir: Path, project_root: Path
) -> dict[BenchmarkKey, TestResults]:
"""Group TestResults by benchmark for calculating improvements for each benchmark."""
from codeflash.code_utils.code_utils import module_name_from_file_path
test_results_by_benchmark = defaultdict(TestResults)
benchmark_module_path = {}
for benchmark_key in benchmark_keys:
@ -929,9 +937,17 @@ class TestResults(BaseModel): # noqa: PLW1641
@staticmethod
def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree:
from rich.tree import Tree
from codeflash.lsp.helpers import is_LSP_enabled
tree = Tree(title)
if is_LSP_enabled():
from codeflash.cli_cmds.console import lsp_log
from codeflash.lsp.helpers import report_to_markdown_table
from codeflash.lsp.lsp_message import LspMarkdownMessage
# Build markdown table
markdown = report_to_markdown_table(report, title)
lsp_log(LspMarkdownMessage(markdown=markdown))
@ -946,6 +962,8 @@ class TestResults(BaseModel): # noqa: PLW1641
return tree
def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]:
from codeflash.cli_cmds.console import logger
# Efficient single traversal, directly accumulating into a dict.
# can track mins here and only sums can be return in total_passed_runtime
by_id: dict[InvocationId, list[int]] = {}
@ -1025,6 +1043,8 @@ class TestResults(BaseModel): # noqa: PLW1641
return bool(self.test_results)
def __eq__(self, other: object) -> bool:
from codeflash.verification.comparator import comparator
# Unordered comparison
if type(self) is not type(other):
return False

View file

@ -74,6 +74,27 @@ _DICT_KEYS_TYPE = type({}.keys())
_DICT_VALUES_TYPE = type({}.values())
_DICT_ITEMS_TYPE = type({}.items())
_IDENTITY_EQ_TYPES: frozenset[type[Any]] = frozenset(
{
int,
bool,
complex,
type(None),
type(Ellipsis),
decimal.Decimal,
set,
bytes,
bytearray,
memoryview,
frozenset,
type,
range,
slice,
OrderedDict,
types.GenericAlias,
}
)
_EQUALITY_TYPES = (
int,
bool,
@ -184,32 +205,61 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
return False
if type(orig) is not type(new):
type_obj = type(orig)
new_type_obj = type(new)
orig_type = type(orig)
if orig_type is not type(new):
# distinct type objects are created at runtime, even if the class code is exactly the same, so we can only compare the names
if type_obj.__name__ != new_type_obj.__name__ or type_obj.__qualname__ != new_type_obj.__qualname__:
if orig_type.__name__ != type(new).__name__ or orig_type.__qualname__ != type(new).__qualname__:
return False
# Fast-path: type identity checks for the most common return-value types.
# `orig_type is T` is a single pointer comparison — cheaper than frozenset hash
# lookup or isinstance MRO traversal — and these 4 types dominate real workloads.
if orig_type is str:
if orig == new:
return True
if _is_temp_path(orig) and _is_temp_path(new):
return _normalize_temp_path(orig) == _normalize_temp_path(new)
return False
if orig_type is list or orig_type is tuple:
if len(orig) != len(new):
return False
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))
if orig_type is dict:
if superset_obj:
return all(k in new and comparator(v, new[k], superset_obj) for k, v in orig.items())
if len(orig) != len(new):
return False
for key in orig:
if key not in new:
return False
if not comparator(orig[key], new[key], superset_obj):
return False
return True
if orig_type is float:
if math.isnan(orig) and math.isnan(new):
return True
return math.isclose(orig, new)
# O(1) frozenset lookup for remaining common types (int, bool, None, Decimal, etc.)
if orig_type in _IDENTITY_EQ_TYPES:
return orig == new
# Slower isinstance path for subclasses (deque, ChainMap, etc.)
if isinstance(orig, (list, tuple, deque, ChainMap)):
if len(orig) != len(new):
return False
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))
# Handle strings separately to normalize temp paths
# Handle string subclasses separately to normalize temp paths
if isinstance(orig, str):
if orig == new:
return True
# If strings differ, check if they're temp paths that differ only in session number
if _is_temp_path(orig) and _is_temp_path(new):
return _normalize_temp_path(orig) == _normalize_temp_path(new)
return False
# enum.Enum subclasses and UnionType fall through from the frozenset fast-path
if isinstance(orig, _EQUALITY_TYPES):
return orig == new
if isinstance(orig, float):
if math.isnan(orig) and math.isnan(new):
return True
return math.isclose(orig, new)
# Handle weak references (e.g., found in torch.nn.LSTM/GRU modules)
if isinstance(orig, weakref.ref):

View file

View file

@ -1,31 +1,18 @@
from argparse import Namespace
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
from codeflash.models.models import FunctionParent
from codeflash.optimization.optimizer import Optimizer
def test_benchmark_extract(benchmark) -> None:
file_path = Path(__file__).parent.parent.parent.resolve() / "codeflash"
opt = Optimizer(
Namespace(
project_root=file_path.resolve(),
disable_telemetry=True,
tests_root=(file_path / "tests").resolve(),
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
test_project_root=Path.cwd(),
)
)
project_root = Path(__file__).parent.parent.parent.resolve() / "codeflash"
function_to_optimize = FunctionToOptimize(
function_name="replace_function_and_helpers_with_optimized_code",
file_path=file_path / "languages" / "function_optimizer.py",
file_path=project_root / "languages" / "function_optimizer.py",
parents=[FunctionParent(name="FunctionOptimizer", type="ClassDef")],
starting_line=None,
ending_line=None,
)
benchmark(get_code_optimization_context, function_to_optimize, opt.args.project_root)
benchmark(get_code_optimization_context, function_to_optimize, project_root)

View file

@ -0,0 +1,133 @@
"""Benchmark comparator type dispatch performance.
Exercises the fast-path frozenset lookup vs isinstance MRO traversal
across realistic return value shapes: primitives, nested containers,
and mixed-type structures typical of real optimization verification.
"""
from __future__ import annotations
from collections import OrderedDict
from decimal import Decimal
from codeflash.verification.comparator import comparator
# --- Test data: realistic return value shapes ---
# 1. Flat primitives (int, bool, None, str, float, bytes) — the fast-path sweet spot
_PRIMITIVES_A = [
42,
True,
None,
3.14,
"hello",
b"bytes",
0,
False,
"",
1.0,
-1,
None,
True,
99,
"world",
b"\x00\x01",
2**31,
0.0,
False,
None,
]
_PRIMITIVES_B = list(_PRIMITIVES_A)
# 2. Nested dict of lists (common return value shape: API responses, parsed configs)
_NESTED_DICT_A = {
"users": [{"id": i, "name": f"user_{i}", "active": i % 2 == 0, "score": i * 1.5} for i in range(50)],
"metadata": {"total": 50, "page": 1, "has_next": True},
"tags": [f"tag_{i}" for i in range(20)],
"config": {"timeout": 30, "retries": 3, "debug": False, "threshold": Decimal("0.95")},
}
_NESTED_DICT_B = {
"users": [{"id": i, "name": f"user_{i}", "active": i % 2 == 0, "score": i * 1.5} for i in range(50)],
"metadata": {"total": 50, "page": 1, "has_next": True},
"tags": [f"tag_{i}" for i in range(20)],
"config": {"timeout": 30, "retries": 3, "debug": False, "threshold": Decimal("0.95")},
}
# 3. List of tuples (common: database rows, CSV data)
_ROWS_A = [(i, f"row_{i}", i * 0.1, i % 3 == 0, None if i % 5 == 0 else i) for i in range(200)]
_ROWS_B = [(i, f"row_{i}", i * 0.1, i % 3 == 0, None if i % 5 == 0 else i) for i in range(200)]
# 4. Deeply nested structure (worst case for recursive comparator)
def _make_deep(depth: int) -> dict:
if depth == 0:
return {"leaf": True, "value": 42, "items": [1, 2, 3], "label": "end"}
return {"level": depth, "child": _make_deep(depth - 1), "siblings": list(range(depth))}
_DEEP_A = _make_deep(15)
_DEEP_B = _make_deep(15)
# 5. Mixed identity types (frozenset, range, slice, OrderedDict, bytes, complex)
_IDENTITY_TYPES_A = [
frozenset({1, 2, 3}),
range(100),
complex(1, 2),
Decimal("3.14"),
OrderedDict(a=1, b=2),
b"binary",
bytearray(b"mutable"),
memoryview(b"view"),
type(None),
True,
42,
None,
] * 10
_IDENTITY_TYPES_B = list(_IDENTITY_TYPES_A)
def _compare_all_primitives() -> None:
for a, b in zip(_PRIMITIVES_A, _PRIMITIVES_B):
comparator(a, b)
def _compare_nested_dict() -> None:
comparator(_NESTED_DICT_A, _NESTED_DICT_B)
def _compare_rows() -> None:
comparator(_ROWS_A, _ROWS_B)
def _compare_deep() -> None:
comparator(_DEEP_A, _DEEP_B)
def _compare_identity_types() -> None:
for a, b in zip(_IDENTITY_TYPES_A, _IDENTITY_TYPES_B):
comparator(a, b)
def test_benchmark_comparator_primitives(benchmark) -> None:
"""20 flat primitive comparisons (int, bool, None, str, float, bytes)."""
benchmark(_compare_all_primitives)
def test_benchmark_comparator_nested_dict(benchmark) -> None:
"""Nested dict with 50-element user list, metadata, tags, config."""
benchmark(_compare_nested_dict)
def test_benchmark_comparator_rows(benchmark) -> None:
"""200 tuples of (int, str, float, bool, Optional[int])."""
benchmark(_compare_rows)
def test_benchmark_comparator_deep(benchmark) -> None:
"""15-level deep nested dict structure."""
benchmark(_compare_deep)
def test_benchmark_comparator_identity_types(benchmark) -> None:
"""120 frozenset/range/complex/Decimal/OrderedDict/bytes comparisons."""
benchmark(_compare_identity_types)

View file

@ -0,0 +1,75 @@
"""Benchmark libcst visitor performance across many files.
Exercises the visitor-heavy codepaths that benefit from the libcst dispatch
table cache: discover_functions + get_code_optimization_context on multiple
real source files.
"""
from __future__ import annotations
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
from codeflash.languages.python.support import PythonSupport
from codeflash.models.models import FunctionParent
# Real source files from the codeflash codebase, chosen for size and visitor diversity.
_CODEFLASH_ROOT = Path(__file__).parent.parent.parent.resolve() / "codeflash"
_SOURCE_FILES: list[Path] = [
_CODEFLASH_ROOT / "languages" / "function_optimizer.py",
_CODEFLASH_ROOT / "languages" / "python" / "context" / "code_context_extractor.py",
_CODEFLASH_ROOT / "languages" / "python" / "support.py",
_CODEFLASH_ROOT / "languages" / "python" / "static_analysis" / "code_extractor.py",
_CODEFLASH_ROOT / "languages" / "python" / "static_analysis" / "code_replacer.py",
_CODEFLASH_ROOT / "code_utils" / "instrument_existing_tests.py",
_CODEFLASH_ROOT / "benchmarking" / "compare.py",
_CODEFLASH_ROOT / "models" / "models.py",
_CODEFLASH_ROOT / "discovery" / "discover_unit_tests.py",
_CODEFLASH_ROOT / "languages" / "base.py",
]
# For each file, pick one top-level function to extract context for.
# (class, function_name) — class=None means module-level.
_TARGETS: list[tuple[Path, str | None, str]] = [
(_SOURCE_FILES[0], "FunctionOptimizer", "replace_function_and_helpers_with_optimized_code"),
(_SOURCE_FILES[1], None, "get_code_optimization_context"),
(_SOURCE_FILES[2], "PythonSupport", "discover_functions"),
(_SOURCE_FILES[3], None, "add_global_assignments"),
(_SOURCE_FILES[4], None, "replace_functions_in_file"),
(_SOURCE_FILES[5], None, "inject_profiling_into_existing_test"),
(_SOURCE_FILES[6], None, "compare_branches"),
(_SOURCE_FILES[7], None, "get_comment_prefix"),
(_SOURCE_FILES[8], None, "discover_unit_tests"),
(_SOURCE_FILES[9], None, "convert_parents_to_tuple"),
]
def _discover_all() -> None:
"""Run discover_functions on all source files."""
ps = PythonSupport()
for file_path in _SOURCE_FILES:
source = file_path.read_text(encoding="utf-8")
ps.discover_functions(source=source, file_path=file_path)
def _extract_all_contexts() -> None:
"""Run get_code_optimization_context on every target function."""
project_root = _CODEFLASH_ROOT.parent
for file_path, class_name, func_name in _TARGETS:
parents = [FunctionParent(name=class_name, type="ClassDef")] if class_name else []
fto = FunctionToOptimize(
function_name=func_name, file_path=file_path, parents=parents, starting_line=None, ending_line=None
)
get_code_optimization_context(fto, project_root)
def test_benchmark_discover_functions_multi_file(benchmark) -> None:
"""Discover functions across 10 source files."""
benchmark(_discover_all)
def test_benchmark_extract_context_multi_file(benchmark) -> None:
"""Extract code optimization context for 10 functions across 10 files."""
benchmark(_extract_all_contexts)

View file

@ -0,0 +1,56 @@
"""Benchmark the full libcst-heavy pipeline on a single file.
Runs discover extract context replace functions add global assignments
in sequence, exercising ~15 distinct visitor/transformer classes in one pass.
"""
from __future__ import annotations
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
from codeflash.languages.python.static_analysis.code_extractor import add_global_assignments
from codeflash.languages.python.static_analysis.code_replacer import replace_functions_in_file
from codeflash.languages.python.support import PythonSupport
_CODEFLASH_ROOT = Path(__file__).parent.parent.parent.resolve() / "codeflash"
_PROJECT_ROOT = _CODEFLASH_ROOT.parent
# Target: a real, non-trivial file with classes and module-level functions.
_TARGET_FILE = _CODEFLASH_ROOT / "languages" / "python" / "static_analysis" / "code_extractor.py"
_TARGET_FUNC = "add_global_assignments"
# A second file to serve as "optimized" source for replace/merge steps.
_SECOND_FILE = _CODEFLASH_ROOT / "languages" / "python" / "static_analysis" / "code_replacer.py"
def _run_pipeline() -> None:
"""Simulate a single-file optimization pass through the full visitor pipeline."""
source = _TARGET_FILE.read_text(encoding="utf-8")
source2 = _SECOND_FILE.read_text(encoding="utf-8")
# 1. Discover functions (FunctionVisitor + MetadataWrapper)
ps = PythonSupport()
functions = ps.discover_functions(source=source, file_path=_TARGET_FILE)
# 2. Extract code optimization context (multiple collectors + dependency resolver)
fto = FunctionToOptimize(
function_name=_TARGET_FUNC, file_path=_TARGET_FILE, parents=[], starting_line=None, ending_line=None
)
get_code_optimization_context(fto, _PROJECT_ROOT)
# 3. Replace functions (GlobalFunctionCollector + GlobalFunctionTransformer)
# Use a class method from discovered functions if available, else module-level.
func_names = [_TARGET_FUNC]
replace_functions_in_file(
source_code=source, original_function_names=func_names, optimized_code=source2, preexisting_objects=set()
)
# 4. Add global assignments (6 visitors/transformers)
add_global_assignments(source2, source)
def test_benchmark_full_pipeline(benchmark) -> None:
"""Full discover → extract → replace → merge pipeline on one file."""
benchmark(_run_pipeline)

View file

@ -9,8 +9,8 @@ from codeflash.either import Success
class TestAuthLogin:
@patch("codeflash.cli_cmds.cmd_auth.get_codeflash_api_key")
@patch("codeflash.cli_cmds.cmd_auth.console")
@patch("codeflash.code_utils.env_utils.get_codeflash_api_key")
@patch("codeflash.cli_cmds.console.console")
def test_existing_api_key_skips_oauth(self, mock_console: MagicMock, mock_get_key: MagicMock) -> None:
mock_get_key.return_value = "cf-test1234abcd"
@ -21,19 +21,19 @@ class TestAuthLogin:
"To re-authenticate, unset [bold]CODEFLASH_API_KEY[/bold] and run this command again."
)
@patch("codeflash.cli_cmds.cmd_auth.get_codeflash_api_key")
@patch("codeflash.cli_cmds.cmd_auth.console")
@patch("codeflash.code_utils.env_utils.get_codeflash_api_key")
@patch("codeflash.cli_cmds.console.console")
def test_existing_api_key_oserror_treated_as_missing(
self, mock_console: MagicMock, mock_get_key: MagicMock
) -> None:
mock_get_key.side_effect = OSError("permission denied")
with pytest.raises(SystemExit):
with patch("codeflash.cli_cmds.cmd_auth.perform_oauth_signin", return_value=None):
with patch("codeflash.cli_cmds.oauth_handler.perform_oauth_signin", return_value=None):
auth_login()
@patch("codeflash.cli_cmds.cmd_auth.perform_oauth_signin")
@patch("codeflash.cli_cmds.cmd_auth.get_codeflash_api_key", return_value="")
@patch("codeflash.cli_cmds.oauth_handler.perform_oauth_signin")
@patch("codeflash.code_utils.env_utils.get_codeflash_api_key", return_value="")
def test_oauth_failure_exits_with_code_1(self, mock_get_key: MagicMock, mock_oauth: MagicMock) -> None:
mock_oauth.return_value = None
@ -41,10 +41,10 @@ class TestAuthLogin:
auth_login()
@patch("codeflash.cli_cmds.cmd_auth.os")
@patch("codeflash.cli_cmds.cmd_auth.save_api_key_to_rc")
@patch("codeflash.cli_cmds.cmd_auth.perform_oauth_signin")
@patch("codeflash.cli_cmds.cmd_auth.get_codeflash_api_key", return_value="")
@patch("codeflash.cli_cmds.cmd_auth.console")
@patch("codeflash.code_utils.shell_utils.save_api_key_to_rc")
@patch("codeflash.cli_cmds.oauth_handler.perform_oauth_signin")
@patch("codeflash.code_utils.env_utils.get_codeflash_api_key", return_value="")
@patch("codeflash.cli_cmds.console.console")
def test_successful_oauth_saves_key(
self,
mock_console: MagicMock,
@ -63,10 +63,10 @@ class TestAuthLogin:
mock_console.print.assert_called_with("[green]Signed in successfully![/green]")
@patch("codeflash.cli_cmds.cmd_auth.os")
@patch("codeflash.cli_cmds.cmd_auth.save_api_key_to_rc")
@patch("codeflash.cli_cmds.cmd_auth.perform_oauth_signin")
@patch("codeflash.cli_cmds.cmd_auth.get_codeflash_api_key", return_value="")
@patch("codeflash.cli_cmds.cmd_auth.console")
@patch("codeflash.code_utils.shell_utils.save_api_key_to_rc")
@patch("codeflash.cli_cmds.oauth_handler.perform_oauth_signin")
@patch("codeflash.code_utils.env_utils.get_codeflash_api_key", return_value="")
@patch("codeflash.cli_cmds.console.console")
def test_windows_oauth_saves_key(
self,
mock_console: MagicMock,