Merge pull request #1915 from codeflash-ai/cf-remove-codeflash-core
chore: remove src/codeflash_core package
commit 1c5404f156
26 changed files with 1 addition and 2905 deletions
@@ -104,7 +104,7 @@ tests = [
 ]

 [tool.hatch.build.targets.sdist]
-include = ["codeflash", "src/codeflash_core"]
+include = ["codeflash"]
 exclude = [
     "docs/*",
     "experiments/*",
@@ -1,32 +0,0 @@
from codeflash_core.config import CoreConfig, TestConfig
from codeflash_core.models import (
    BenchmarkResults,
    Candidate,
    CodeContext,
    FunctionToOptimize,
    HelperFunction,
    OptimizationResult,
    ScoredCandidate,
    TestOutcome,
    TestOutcomeStatus,
    TestResults,
)
from codeflash_core.optimizer import Optimizer
from codeflash_core.protocols import LanguagePlugin

__all__ = [
    "BenchmarkResults",
    "Candidate",
    "CodeContext",
    "CoreConfig",
    "FunctionToOptimize",
    "HelperFunction",
    "LanguagePlugin",
    "OptimizationResult",
    "Optimizer",
    "ScoredCandidate",
    "TestConfig",
    "TestOutcome",
    "TestOutcomeStatus",
    "TestResults",
]
@@ -1,3 +0,0 @@
from codeflash_core.ai.client import AIClient

__all__ = ["AIClient"]
@@ -1,54 +0,0 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import requests

if TYPE_CHECKING:
    from codeflash_core.config import AIConfig
    from codeflash_core.models import Candidate, CodeContext

logger = logging.getLogger(__name__)


class AIClient:
    """Client for the Codeflash AI optimization service."""

    def __init__(self, config: AIConfig) -> None:
        self.base_url = config.base_url.rstrip("/")
        self.api_key = config.api_key
        self.timeout = config.timeout
        self.session = requests.Session()
        if self.api_key:
            self.session.headers["Authorization"] = f"Bearer {self.api_key}"

    def get_candidates(self, context: CodeContext) -> list[Candidate]:
        """Request optimization candidates from the AI service."""
        from codeflash_core.models import Candidate

        payload = {
            "function_name": context.target_function.qualified_name,
            "source_code": context.target_code,
            "helper_functions": [
                {"name": h.qualified_name, "source_code": h.source_code} for h in context.helper_functions
            ],
            "read_only_context": context.read_only_context,
            "imports": context.imports,
        }

        try:
            resp = self.session.post(f"{self.base_url}/ai/optimize", json=payload, timeout=self.timeout)
            resp.raise_for_status()
            data = resp.json()
        except requests.RequestException:
            logger.exception("AI service request failed")
            return []

        candidates = []
        for item in data.get("candidates", []):
            candidates.append(Candidate(code=item["code"], explanation=item.get("explanation", "")))
        return candidates

    def close(self) -> None:
        self.session.close()
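For reference, the removed client was driven roughly like this (a minimal usage sketch; the config values are placeholders and `context` stands for a CodeContext produced by a language plugin):

from codeflash_core.ai.client import AIClient
from codeflash_core.config import AIConfig

client = AIClient(AIConfig(base_url="https://app.codeflash.ai", api_key="cf-...", timeout=120.0))
try:
    candidates = client.get_candidates(context)  # returns an empty list on any request failure
    for candidate in candidates:
        print(candidate.explanation)
finally:
    client.close()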
@@ -1,201 +0,0 @@
|
|||
"""CLI entry point for codeflash."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash_core.strategy_utils import OptimizationStrategy
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(prog="cfnext", description="Optimize Python code with AI.")
|
||||
parser.add_argument("--version", action="store_true", help="Print the version and exit.")
|
||||
parser.add_argument(
|
||||
"--show-config", action="store_true", help="Show current or auto-detected configuration and exit."
|
||||
)
|
||||
sub = parser.add_subparsers(dest="command")
|
||||
|
||||
opt = sub.add_parser("optimize", help="Optimize functions in the given files.")
|
||||
opt.add_argument("files", nargs="*", help="Python files to optimize.")
|
||||
opt.add_argument("--file", dest="target_file", default=None, help="Single file to optimize.")
|
||||
opt.add_argument("--function", dest="target_function", default=None, help="Specific function name to optimize.")
|
||||
opt.add_argument("--all", action="store_true", dest="optimize_all", help="Optimize all files in module-root.")
|
||||
opt.add_argument("--project-root", type=Path, default=None, help="Override project root directory.")
|
||||
opt.add_argument("--module-root", default=None, help="Override module root (relative to project root).")
|
||||
opt.add_argument("--tests-root", default=None, help="Override tests root (relative to project root).")
|
||||
opt.add_argument("--benchmarks-root", default=None, help="Override benchmarks root (relative to project root).")
|
||||
opt.add_argument("--api-key", default=None, help="Codeflash API key (default: $CODEFLASH_API_KEY).")
|
||||
opt.add_argument("--effort", choices=["low", "medium", "high"], default=None, help="Effort level for optimization.")
|
||||
opt.add_argument(
|
||||
"--server",
|
||||
choices=["local", "prod"],
|
||||
default=None,
|
||||
help="AI service server: 'local' for localhost:8000, 'prod' for app.codeflash.ai.",
|
||||
)
|
||||
opt.add_argument(
|
||||
"--benchmark", action="store_true", help="Trace benchmark tests and calculate optimization impact."
|
||||
)
|
||||
opt.add_argument("--no-pr", action="store_true", help="Do not create a PR, only update code locally.")
|
||||
opt.add_argument("-y", "--yes", action="store_true", help="Skip confirmation prompts.")
|
||||
opt.add_argument(
|
||||
"--formatter-cmds", nargs="*", default=None, help="Override formatter commands (each applied to $file)."
|
||||
)
|
||||
opt.add_argument("--pytest-cmd", default=None, help="Override the pytest command to use.")
|
||||
opt.add_argument("--disable-telemetry", action="store_true", help="Disable telemetry.")
|
||||
opt.add_argument(
|
||||
"--strategy", choices=["default"], default="default", help="Optimization strategy to use (default: default)."
|
||||
)
|
||||
opt.add_argument("-v", "--verbose", action="store_true", help="Enable debug logging.")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def collect_files(config_project_root: Path, config_module_root: str) -> list[Path]:
|
||||
"""Collect all .py files under the module root."""
|
||||
module_dir = config_project_root / config_module_root if config_module_root else config_project_root
|
||||
if not module_dir.is_dir():
|
||||
return []
|
||||
return sorted(
|
||||
p for p in module_dir.rglob("*.py") if not p.name.startswith("test_") and not p.name.endswith("_test.py")
|
||||
)
|
||||
|
||||
|
||||
SERVER_URLS = {"local": "http://localhost:8000", "prod": "https://app.codeflash.ai"}
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
parser = build_parser()
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
# -- Top-level flags (no subcommand needed) ------------------------------
|
||||
if args.version:
|
||||
from importlib.metadata import version
|
||||
|
||||
print(f"cfnext {version('codeflash-core')}")
|
||||
return 0
|
||||
|
||||
# -- Config (needed for --show-config and optimize) ----------------------
|
||||
from codeflash_core.config import CoreConfig
|
||||
|
||||
start_dir = getattr(args, "project_root", None) or Path.cwd()
|
||||
config = CoreConfig.find_and_load(start_dir)
|
||||
|
||||
if args.show_config:
|
||||
import json
|
||||
|
||||
info = {
|
||||
"project_root": str(config.project_root),
|
||||
"module_root": config.module_root,
|
||||
"tests_root": config.tests_root,
|
||||
"benchmarks_root": config.benchmarks_root,
|
||||
"effort": config.effort,
|
||||
"formatter_cmds": config.formatter_cmds,
|
||||
"ignore_paths": config.ignore_paths,
|
||||
"ai_base_url": config.ai.base_url,
|
||||
"disable_telemetry": config.disable_telemetry,
|
||||
}
|
||||
print(json.dumps(info, indent=2))
|
||||
return 0
|
||||
|
||||
if args.command is None:
|
||||
parser.print_help()
|
||||
return 0
|
||||
|
||||
if args.command != "optimize":
|
||||
parser.print_help()
|
||||
return 1
|
||||
|
||||
# -- Logging -------------------------------------------------------------
|
||||
from codeflash_core.ui import setup_logging
|
||||
|
||||
setup_logging(level=logging.DEBUG if args.verbose else logging.WARNING)
|
||||
|
||||
# CLI overrides
|
||||
if args.project_root:
|
||||
config.project_root = args.project_root.resolve()
|
||||
if args.module_root is not None:
|
||||
config.module_root = args.module_root
|
||||
if args.tests_root is not None:
|
||||
config.tests_root = args.tests_root
|
||||
if args.benchmarks_root is not None:
|
||||
config.benchmarks_root = args.benchmarks_root
|
||||
if args.effort is not None:
|
||||
config.effort = args.effort
|
||||
if args.server is not None:
|
||||
config.ai.base_url = SERVER_URLS[args.server]
|
||||
if args.formatter_cmds is not None:
|
||||
config.formatter_cmds = args.formatter_cmds
|
||||
if args.disable_telemetry:
|
||||
config.disable_telemetry = True
|
||||
|
||||
# API key: CLI flag > env var > config file
|
||||
api_key = args.api_key or os.environ.get("CODEFLASH_API_KEY", "") or config.ai.api_key
|
||||
if not api_key:
|
||||
print("Error: No API key provided. Set CODEFLASH_API_KEY or pass --api-key.")
|
||||
return 1
|
||||
config.ai.api_key = api_key
|
||||
|
||||
# -- Telemetry -----------------------------------------------------------
|
||||
from codeflash_core.telemetry import PostHogClient, init_sentry
|
||||
|
||||
if not config.disable_telemetry:
|
||||
PostHogClient.initialize(config.telemetry.posthog_api_key, enabled=config.telemetry.enabled)
|
||||
init_sentry(config.telemetry.sentry_dsn, enabled=config.telemetry.enabled)
|
||||
|
||||
# -- Resolve files -------------------------------------------------------
|
||||
if args.target_file:
|
||||
files = [Path(args.target_file).resolve()]
|
||||
elif args.optimize_all:
|
||||
files = collect_files(config.project_root, config.module_root)
|
||||
elif args.files:
|
||||
files = [Path(f).resolve() for f in args.files]
|
||||
else:
|
||||
print("Error: Provide files to optimize, use --file, or use --all.")
|
||||
return 1
|
||||
|
||||
if not files:
|
||||
print("Error: No Python files found.")
|
||||
return 1
|
||||
|
||||
# -- Build plugin & validate environment ---------------------------------
|
||||
from codeflash_core.optimizer import Optimizer
|
||||
|
||||
try:
|
||||
from codeflash.plugin import PythonPlugin
|
||||
except ImportError:
|
||||
print("Error: codeflash package not installed. Install it to use the Python plugin.")
|
||||
return 1
|
||||
|
||||
plugin = PythonPlugin(config.project_root)
|
||||
|
||||
if hasattr(plugin, "validate_environment") and not plugin.validate_environment(config):
|
||||
return 1
|
||||
|
||||
# -- Resolve strategy ----------------------------------------------------
|
||||
strategy = _resolve_strategy(args.strategy)
|
||||
|
||||
optimizer = Optimizer(config, plugin, strategy=strategy)
|
||||
results = optimizer.run(files, function_filter=args.target_function)
|
||||
|
||||
# -- Shutdown telemetry --------------------------------------------------
|
||||
if PostHogClient.instance is not None:
|
||||
PostHogClient.instance.shutdown()
|
||||
|
||||
return 0 if results else 2
|
||||
|
||||
|
||||
def _resolve_strategy(name: str) -> OptimizationStrategy:
|
||||
"""Return the strategy instance for the given CLI name."""
|
||||
from codeflash_core.strategy import DefaultStrategy
|
||||
|
||||
return DefaultStrategy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
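As a usage sketch, this entry point can be exercised programmatically as well as from the shell (the argument values here are placeholders):

# Shell equivalent: cfnext optimize --file my_module.py --api-key $CODEFLASH_API_KEY
exit_code = main(["optimize", "--file", "my_module.py", "--api-key", "cf-..."])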
@@ -1,239 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import tomlkit
|
||||
|
||||
|
||||
class EffortLevel(str, Enum):
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
|
||||
|
||||
class EffortKeys(str, Enum):
|
||||
N_OPTIMIZER_CANDIDATES = "N_OPTIMIZER_CANDIDATES"
|
||||
N_OPTIMIZER_LP_CANDIDATES = "N_OPTIMIZER_LP_CANDIDATES"
|
||||
N_GENERATED_TESTS = "N_GENERATED_TESTS"
|
||||
MAX_CODE_REPAIRS_PER_TRACE = "MAX_CODE_REPAIRS_PER_TRACE"
|
||||
REPAIR_UNMATCHED_PERCENTAGE_LIMIT = "REPAIR_UNMATCHED_PERCENTAGE_LIMIT"
|
||||
TOP_VALID_CANDIDATES_FOR_REFINEMENT = "TOP_VALID_CANDIDATES_FOR_REFINEMENT"
|
||||
ADAPTIVE_OPTIMIZATION_THRESHOLD = "ADAPTIVE_OPTIMIZATION_THRESHOLD"
|
||||
MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE = "MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE"
|
||||
|
||||
|
||||
HIGH_EFFORT_TOP_N = 15
|
||||
MAX_TEST_REPAIR_CYCLES = 2
|
||||
|
||||
EFFORT_VALUES: dict[str, dict[EffortLevel, Any]] = {
|
||||
EffortKeys.N_OPTIMIZER_CANDIDATES.value: {EffortLevel.LOW: 3, EffortLevel.MEDIUM: 5, EffortLevel.HIGH: 6},
|
||||
EffortKeys.N_OPTIMIZER_LP_CANDIDATES.value: {EffortLevel.LOW: 4, EffortLevel.MEDIUM: 6, EffortLevel.HIGH: 7},
|
||||
EffortKeys.N_GENERATED_TESTS.value: {EffortLevel.LOW: 2, EffortLevel.MEDIUM: 2, EffortLevel.HIGH: 2},
|
||||
EffortKeys.MAX_CODE_REPAIRS_PER_TRACE.value: {EffortLevel.LOW: 2, EffortLevel.MEDIUM: 3, EffortLevel.HIGH: 5},
|
||||
EffortKeys.REPAIR_UNMATCHED_PERCENTAGE_LIMIT.value: {
|
||||
EffortLevel.LOW: 0.2,
|
||||
EffortLevel.MEDIUM: 0.3,
|
||||
EffortLevel.HIGH: 0.4,
|
||||
},
|
||||
EffortKeys.TOP_VALID_CANDIDATES_FOR_REFINEMENT.value: {
|
||||
EffortLevel.LOW: 2,
|
||||
EffortLevel.MEDIUM: 3,
|
||||
EffortLevel.HIGH: 4,
|
||||
},
|
||||
EffortKeys.ADAPTIVE_OPTIMIZATION_THRESHOLD.value: {EffortLevel.LOW: 0, EffortLevel.MEDIUM: 0, EffortLevel.HIGH: 2},
|
||||
EffortKeys.MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE.value: {
|
||||
EffortLevel.LOW: 0,
|
||||
EffortLevel.MEDIUM: 0,
|
||||
EffortLevel.HIGH: 4,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_effort_value(key: EffortKeys, effort: EffortLevel | str) -> Any:
|
||||
"""Look up an effort-dependent parameter value."""
|
||||
if isinstance(effort, str):
|
||||
effort = EffortLevel(effort)
|
||||
return EFFORT_VALUES[key.value][effort]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestConfig:
|
||||
tests_root: Path
|
||||
project_root: Path
|
||||
test_command: str = ""
|
||||
timeout: float = 60.0
|
||||
tests_project_rootdir: Path | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TelemetryConfig:
|
||||
enabled: bool = True
|
||||
posthog_api_key: str = ""
|
||||
sentry_dsn: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIConfig:
|
||||
base_url: str = "https://app.codeflash.ai"
|
||||
api_key: str = ""
|
||||
timeout: float = 120.0
|
||||
|
||||
|
||||
CONFIG_FILE_NAMES = ("pyproject.toml", "codeflash.toml")
|
||||
|
||||
GLOB_PATTERN_CHARS = frozenset("*?[")
|
||||
|
||||
|
||||
def is_glob_pattern(path_str: str) -> bool:
|
||||
"""Check if a path string contains glob pattern characters."""
|
||||
return any(char in path_str for char in GLOB_PATTERN_CHARS)
|
||||
|
||||
|
||||
def normalize_ignore_paths(paths: list[str], base_path: Path | None = None) -> list[Path]:
|
||||
if base_path is None:
|
||||
base_path = Path.cwd()
|
||||
|
||||
base_path = base_path.resolve()
|
||||
normalized: set[Path] = set()
|
||||
|
||||
for path_str in paths:
|
||||
if not path_str:
|
||||
continue
|
||||
|
||||
path_str = str(path_str)
|
||||
|
||||
if is_glob_pattern(path_str):
|
||||
path_str = path_str.removeprefix("./")
|
||||
if path_str.startswith("/"):
|
||||
path_str = path_str.lstrip("/")
|
||||
for matched_path in base_path.glob(path_str):
|
||||
normalized.add(matched_path.resolve())
|
||||
else:
|
||||
path_obj = Path(path_str)
|
||||
if not path_obj.is_absolute():
|
||||
path_obj = base_path / path_obj
|
||||
if path_obj.exists():
|
||||
normalized.add(path_obj.resolve())
|
||||
|
||||
return list(normalized)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CoreConfig:
|
||||
project_root: Path = field(default_factory=Path.cwd)
|
||||
module_root: str = ""
|
||||
tests_root: str = "tests"
|
||||
benchmarks_root: str = ""
|
||||
ignore_paths: list[str] = field(default_factory=list)
|
||||
formatter_cmds: list[str] = field(default_factory=list)
|
||||
disable_telemetry: bool = False
|
||||
effort: str = "medium"
|
||||
create_pr: bool = False
|
||||
pytest_cmd: str = "pytest"
|
||||
disable_imports_sorting: bool = False
|
||||
git_remote: str = "origin"
|
||||
override_fixtures: bool = False
|
||||
|
||||
ai: AIConfig = field(default_factory=AIConfig)
|
||||
telemetry: TelemetryConfig = field(default_factory=TelemetryConfig)
|
||||
|
||||
@property
|
||||
def resolved_ignore_paths(self) -> list[Path]:
|
||||
base = (self.project_root / self.module_root) if self.module_root else self.project_root
|
||||
return normalize_ignore_paths(self.ignore_paths, base_path=base)
|
||||
|
||||
@classmethod
|
||||
def from_toml(cls, path: Path) -> CoreConfig:
|
||||
"""Load config from a toml file containing a [tool.codeflash] section.
|
||||
|
||||
Works with both pyproject.toml and codeflash.toml.
|
||||
"""
|
||||
with path.open(encoding="utf-8") as f:
|
||||
data = tomlkit.load(f)
|
||||
|
||||
cf: dict[str, Any] = data.get("tool", {}).get("codeflash", {})
|
||||
if not cf:
|
||||
return cls(project_root=path.parent)
|
||||
|
||||
config = cls(
|
||||
project_root=path.parent,
|
||||
module_root=cf.get("module-root", ""),
|
||||
tests_root=cf.get("tests-root", "tests"),
|
||||
benchmarks_root=cf.get("benchmarks-root", ""),
|
||||
ignore_paths=cf.get("ignore-paths", []),
|
||||
formatter_cmds=cf.get("formatter-cmds", []),
|
||||
disable_telemetry=cf.get("disable_telemetry", False),
|
||||
effort=cf.get("effort", "medium"),
|
||||
create_pr=cf.get("create-pr", False),
|
||||
pytest_cmd=cf.get("pytest-cmd", "pytest"),
|
||||
disable_imports_sorting=cf.get("disable-imports-sorting", False),
|
||||
git_remote=cf.get("git-remote", "origin"),
|
||||
override_fixtures=cf.get("override-fixtures", False),
|
||||
)
|
||||
|
||||
if config.module_root and not (config.project_root / config.module_root).exists():
|
||||
msg = f"module-root '{config.module_root}' does not exist under {config.project_root}"
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def find_and_load(cls, start: Path | None = None) -> CoreConfig:
|
||||
"""Walk up from start (default cwd) looking for a config file.
|
||||
|
||||
Searches for pyproject.toml and codeflash.toml in each directory
|
||||
up to the filesystem root. Returns a default config if nothing is found.
|
||||
"""
|
||||
path = cls.find_config_file(start or Path.cwd())
|
||||
if path is None:
|
||||
return cls(project_root=start or Path.cwd())
|
||||
return cls.from_toml(path)
|
||||
|
||||
@staticmethod
|
||||
def find_config_file(start: Path) -> Path | None:
|
||||
"""Walk up directories looking for a config file with a [tool.codeflash] section."""
|
||||
current = start.resolve()
|
||||
while True:
|
||||
for name in CONFIG_FILE_NAMES:
|
||||
candidate = current / name
|
||||
if candidate.is_file():
|
||||
try:
|
||||
with candidate.open(encoding="utf-8") as f:
|
||||
data = tomlkit.load(f)
|
||||
if data.get("tool", {}).get("codeflash"):
|
||||
return candidate
|
||||
except Exception:
|
||||
continue
|
||||
parent = current.parent
|
||||
if parent == current:
|
||||
break
|
||||
current = parent
|
||||
return None
|
||||
|
||||
def resolve_test_config(self) -> TestConfig:
|
||||
tests_root = self.project_root / self.tests_root
|
||||
if not tests_root.is_dir():
|
||||
msg = f"tests-root '{self.tests_root}' does not exist under {self.project_root}"
|
||||
raise FileNotFoundError(msg)
|
||||
# Compute tests_project_rootdir by walking up from tests_root to find pyproject.toml
|
||||
# This is the pytest rootdir used for resolving test module paths
|
||||
tests_project_rootdir = self.find_tests_project_rootdir(tests_root)
|
||||
return TestConfig(
|
||||
tests_root=tests_root,
|
||||
project_root=self.project_root,
|
||||
tests_project_rootdir=tests_project_rootdir,
|
||||
test_command=self.pytest_cmd,
|
||||
)
|
||||
|
||||
def find_tests_project_rootdir(self, tests_root: Path) -> Path:
|
||||
"""Walk up from tests_root looking for a directory containing pyproject.toml or codeflash.toml."""
|
||||
current = tests_root.resolve()
|
||||
while current != current.parent:
|
||||
for name in CONFIG_FILE_NAMES:
|
||||
if (current / name).is_file():
|
||||
return current
|
||||
current = current.parent
|
||||
return self.project_root
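A minimal usage sketch of the loader above (the project path is a placeholder):

from pathlib import Path

config = CoreConfig.find_and_load(Path("/path/to/project"))  # walks up looking for pyproject.toml or codeflash.toml
test_config = config.resolve_test_config()  # raises FileNotFoundError if tests-root does not exist
n_candidates = get_effort_value(EffortKeys.N_OPTIMIZER_CANDIDATES, config.effort)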
@@ -1,21 +0,0 @@
from codeflash_core.danom.new_type import new_type
from codeflash_core.danom.result import Err, Ok, Result
from codeflash_core.danom.safe import safe, safe_method
from codeflash_core.danom.stream import Stream
from codeflash_core.danom.utils import all_of, any_of, compose, identity, invert, none_of

__all__ = [
    "Err",
    "Ok",
    "Result",
    "Stream",
    "all_of",
    "any_of",
    "compose",
    "identity",
    "invert",
    "new_type",
    "none_of",
    "safe",
    "safe_method",
]
@@ -1,97 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from collections.abc import Sequence
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
import attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from typing_extensions import ParamSpec, Self
|
||||
|
||||
P = ParamSpec("P")
|
||||
C = TypeVar("C", bound=Callable[P, object])
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def new_type(
|
||||
name: str,
|
||||
base_type: type,
|
||||
validators: Callable | Sequence[Callable] | None = None,
|
||||
converters: Callable | Sequence[Callable] | None = None,
|
||||
*,
|
||||
frozen: bool = True,
|
||||
) -> type:
|
||||
kwargs = _callables_to_kwargs(base_type, validators, converters)
|
||||
|
||||
@attrs.define(frozen=frozen, eq=True, hash=frozen)
|
||||
class _Wrapper:
|
||||
inner: T = attrs.field(**kwargs) # type: ignore[no-matching-overload]
|
||||
|
||||
def map(self, func: Callable[[T], T]) -> Self:
|
||||
return self.__class__(func(self.inner)) # type: ignore[invalid-argument-type]
|
||||
|
||||
locals().update(_create_forward_methods(base_type))
|
||||
|
||||
_Wrapper.__name__ = name
|
||||
_Wrapper.__qualname__ = name
|
||||
return _Wrapper
|
||||
|
||||
|
||||
def _create_forward_methods(base_type: type) -> dict[str, Callable]:
|
||||
methods: dict[str, Callable] = {}
|
||||
for attr_name, _ in inspect.getmembers(base_type, inspect.isroutine):
|
||||
if attr_name.startswith("_"):
|
||||
continue
|
||||
|
||||
def make_forwarder(name: str) -> Callable:
|
||||
def method(self, *args: tuple, **kwargs: dict) -> T:
|
||||
return getattr(self.inner, name)(*args, **kwargs)
|
||||
|
||||
method.__name__ = name
|
||||
method.__doc__ = getattr(base_type, name).__doc__
|
||||
return method
|
||||
|
||||
methods[attr_name] = make_forwarder(attr_name)
|
||||
return methods
|
||||
|
||||
|
||||
def _callables_to_kwargs(
|
||||
base_type: type, validators: Callable | Sequence[Callable] | None, converters: Callable | Sequence[Callable] | None
|
||||
) -> dict[str, Sequence[Callable]]:
|
||||
kwargs = {"validator": [attrs.validators.instance_of(base_type)], "converter": []}
|
||||
kwargs["validator"] += [_validate_bool_func(fn) for fn in _to_list(validators)]
|
||||
kwargs["converter"] += _to_list(converters)
|
||||
|
||||
return {k: v for k, v in kwargs.items() if v}
|
||||
|
||||
|
||||
def _validate_bool_func(bool_fn: Callable[[T], bool]) -> Callable[[attrs.AttrsInstance, attrs.Attribute, T], None]:
|
||||
if not callable(bool_fn):
|
||||
raise TypeError("provided boolean function must be callable")
|
||||
|
||||
@wraps(bool_fn)
|
||||
def wrapper(_instance: attrs.AttrsInstance, attribute: attrs.Attribute, value: T) -> None:
|
||||
if not bool_fn(value):
|
||||
msg = f"{attribute.name} does not return True for the given boolean function, received `{value}`."
|
||||
raise ValueError(msg)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _to_list(value: C | Sequence[C] | None) -> list[C]:
|
||||
if value is None:
|
||||
return []
|
||||
|
||||
if callable(value):
|
||||
return [value] # type: ignore[invalid-return-type]
|
||||
|
||||
if isinstance(value, Sequence) and not all(callable(fn) for fn in value):
|
||||
msg = f"Given items are not all callable: {value = }"
|
||||
raise TypeError(msg)
|
||||
|
||||
return list(value)
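A minimal sketch of the factory above (the UserId type and its validator are invented for illustration):

UserId = new_type("UserId", int, validators=lambda x: x > 0)

uid = UserId(42)                 # validated with instance_of(int) plus the boolean validator
uid2 = uid.map(lambda x: x + 1)  # returns a new UserId(43); instances are frozen by default
UserId(-1)                       # raises ValueError from the wrapped boolean validator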
@@ -1,157 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
|
||||
|
||||
try:
|
||||
from typing import Never # type: ignore[unresolved-import]
|
||||
except ImportError:
|
||||
from typing import NoReturn as Never
|
||||
|
||||
import attrs
|
||||
from attrs.validators import instance_of
|
||||
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
U_co = TypeVar("U_co", covariant=True)
|
||||
E_co = TypeVar("E_co", bound=object, covariant=True)
|
||||
F_co = TypeVar("F_co", bound=object, covariant=True)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from types import TracebackType
|
||||
|
||||
from typing_extensions import Concatenate, ParamSpec, Self
|
||||
|
||||
P = ParamSpec("P")
|
||||
Mappable = Callable[Concatenate[T_co, P], U_co]
|
||||
Bindable = Callable[Concatenate[T_co, P], "Result[U_co, E_co]"]
|
||||
|
||||
|
||||
@attrs.define(frozen=True)
|
||||
class Result(ABC, Generic[T_co, E_co]):
|
||||
"""`Result` monad. Consists of `Ok` and `Err` for successful and failed operations respectively.
|
||||
|
||||
Each monad is a frozen instance to prevent further mutation.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def unit(cls, inner: T_co) -> Ok[T_co]:
|
||||
return Ok(inner)
|
||||
|
||||
@abstractmethod
|
||||
def is_ok(self) -> bool: ...
|
||||
|
||||
@abstractmethod
|
||||
def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: ...
|
||||
|
||||
@abstractmethod
|
||||
def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: ...
|
||||
|
||||
@abstractmethod
|
||||
def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: ...
|
||||
|
||||
@abstractmethod
|
||||
def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: ...
|
||||
|
||||
@abstractmethod
|
||||
def unwrap(self) -> T_co: ...
|
||||
|
||||
@staticmethod
|
||||
def result_is_ok(result: Result[T_co, E_co]) -> bool:
|
||||
return result.is_ok()
|
||||
|
||||
@staticmethod
|
||||
def result_unwrap(result: Result[T_co, E_co]) -> T_co:
|
||||
return result.unwrap()
|
||||
|
||||
|
||||
@attrs.define(frozen=True, hash=True)
|
||||
class Ok(Result[T_co, Never]):
|
||||
inner: Any = attrs.field(default=None)
|
||||
|
||||
def is_ok(self) -> Literal[True]:
|
||||
return True
|
||||
|
||||
def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Ok[U_co]:
|
||||
return Ok(func(self.inner, *args, **kwargs))
|
||||
|
||||
def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Self:
|
||||
return self
|
||||
|
||||
def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]:
|
||||
return func(self.inner, *args, **kwargs)
|
||||
|
||||
def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Self:
|
||||
return self
|
||||
|
||||
def unwrap(self) -> T_co:
|
||||
return self.inner
|
||||
|
||||
|
||||
SafeArgs = tuple[tuple[Any, ...], dict[str, Any]]
|
||||
SafeMethodArgs = tuple[object, tuple[Any, ...], dict[str, Any]]
|
||||
|
||||
|
||||
@attrs.define(frozen=True)
|
||||
class Err(Result[Never, E_co]):
|
||||
error: Any = attrs.field(default=None)
|
||||
input_args: tuple[()] | SafeArgs | SafeMethodArgs = attrs.field(
|
||||
default=(), validator=instance_of(tuple), repr=False
|
||||
)
|
||||
traceback: str = attrs.field(default="", validator=instance_of(str))
|
||||
details: list[dict[str, Any]] = attrs.field(factory=list, init=False, repr=False)
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
if isinstance(self.error, Exception):
|
||||
object.__setattr__(self, "details", self._extract_details(self.error.__traceback__))
|
||||
|
||||
def _extract_details(self, tb: TracebackType | None) -> list[dict[str, Any]]:
|
||||
trace_info = []
|
||||
while tb:
|
||||
frame = tb.tb_frame
|
||||
trace_info.append(
|
||||
{
|
||||
"file": frame.f_code.co_filename,
|
||||
"func": frame.f_code.co_name,
|
||||
"line_no": tb.tb_lineno,
|
||||
"locals": frame.f_locals,
|
||||
}
|
||||
)
|
||||
tb = tb.tb_next
|
||||
return trace_info
|
||||
|
||||
def is_ok(self) -> Literal[False]:
|
||||
return False
|
||||
|
||||
def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Self:
|
||||
return self
|
||||
|
||||
def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Err[F_co]:
|
||||
return Err(func(self.error, *args, **kwargs))
|
||||
|
||||
def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Self:
|
||||
return self
|
||||
|
||||
def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]:
|
||||
return func(self.error, *args, **kwargs)
|
||||
|
||||
def unwrap(self) -> T_co:
|
||||
if isinstance(self.error, Exception):
|
||||
raise self.error
|
||||
msg = f"Err does not have a caught error to raise: {self.error = }"
|
||||
raise ValueError(msg)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Err):
|
||||
return False
|
||||
|
||||
return all(
|
||||
(
|
||||
type(self.error) is type(other.error),
|
||||
str(self.error) == str(other.error),
|
||||
self.input_args == other.input_args,
|
||||
)
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(f"{type(self.error)}{self.error}{self.input_args}")
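For reference, a small usage sketch of the Ok/Err API above (the divide helper is invented):

def divide(a: float, b: float) -> Result[float, Exception]:
    if b == 0:
        return Err(error=ZeroDivisionError("division by zero"))
    return Ok(a / b)

res = divide(10, 2).map(lambda x: x * 3)  # Ok(15.0)
if res.is_ok():
    print(res.unwrap())
fallback = divide(1, 0).or_else(lambda _err: Ok(0.0))  # recovers the error branch with Ok(0.0)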
@@ -1,42 +0,0 @@
from __future__ import annotations

import functools
import traceback
from typing import TYPE_CHECKING

from codeflash_core.danom.result import Err, Ok

if TYPE_CHECKING:
    from collections.abc import Callable
    from typing import TypeVar

    from typing_extensions import Concatenate, ParamSpec

    from codeflash_core.danom.result import Result

    T = TypeVar("T")
    P = ParamSpec("P")
    U = TypeVar("U")
    E = TypeVar("E")


def safe(func: Callable[P, U]) -> Callable[P, Result[U, Exception]]:
    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[U, Exception]:
        try:
            return Ok(func(*args, **kwargs))
        except Exception as e:
            return Err(error=e, input_args=(args, kwargs), traceback=traceback.format_exc())

    return wrapper


def safe_method(func: Callable[Concatenate[T, P], U]) -> Callable[Concatenate[T, P], Result[U, Exception]]:
    @functools.wraps(func)
    def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> Result[U, Exception]:
        try:
            return Ok(func(self, *args, **kwargs))
        except Exception as e:
            return Err(error=e, input_args=(self, args, kwargs), traceback=traceback.format_exc())

    return wrapper
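A minimal sketch of the decorators above in use (the parsing function is invented for illustration):

@safe
def parse_int(text: str) -> int:
    return int(text)

ok = parse_int("42")     # Ok(42)
err = parse_int("oops")  # Err wrapping the ValueError, with input_args and traceback captured
print(ok.is_ok(), err.is_ok())  # True False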
@@ -1,204 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
||||
from copy import deepcopy
|
||||
from enum import Enum, auto
|
||||
from functools import reduce
|
||||
from typing import TYPE_CHECKING, TypeVar, Union, cast
|
||||
|
||||
try:
|
||||
from itertools import batched # type: ignore[attr-defined]
|
||||
except ImportError:
|
||||
from itertools import islice
|
||||
|
||||
def batched(iterable, n, *, strict=False): # noqa: ANN201
|
||||
if n < 1:
|
||||
raise ValueError("n must be at least one")
|
||||
iterator = iter(iterable)
|
||||
while batch := tuple(islice(iterator, n)):
|
||||
if strict and len(batch) != n:
|
||||
raise ValueError("batched(): incomplete batch")
|
||||
yield batch
|
||||
|
||||
|
||||
import attrs
|
||||
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
E = TypeVar("E")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable, Generator
|
||||
|
||||
MapFn = Callable[[T], U]
|
||||
FilterFn = Callable[[T], bool]
|
||||
TapFn = Callable[[T], None]
|
||||
|
||||
AsyncMapFn = Callable[[T], Awaitable[U]]
|
||||
AsyncFilterFn = Callable[[T], Awaitable[bool]]
|
||||
AsyncTapFn = Callable[[T], Awaitable[None]]
|
||||
|
||||
StreamFn = Union[MapFn, FilterFn, TapFn]
|
||||
AsyncStreamFn = Union[AsyncMapFn, AsyncFilterFn, AsyncTapFn]
|
||||
|
||||
PlannedOps = tuple[str, StreamFn]
|
||||
AsyncPlannedOps = tuple[str, AsyncStreamFn]
|
||||
|
||||
|
||||
@attrs.define(frozen=True)
|
||||
class _BaseStream(ABC):
|
||||
seq: tuple = attrs.field(validator=attrs.validators.instance_of(tuple))
|
||||
ops: tuple = attrs.field(default=(), validator=attrs.validators.instance_of(tuple), repr=False)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_iterable(cls, it: Iterable) -> _BaseStream[T]: ...
|
||||
|
||||
@abstractmethod
|
||||
def map(self, *fns: MapFn | AsyncMapFn) -> _BaseStream[T]: ...
|
||||
|
||||
@abstractmethod
|
||||
def filter(self, *fns: FilterFn | AsyncFilterFn) -> _BaseStream[T]: ...
|
||||
|
||||
@abstractmethod
|
||||
def tap(self, *fns: TapFn | AsyncTapFn) -> _BaseStream[T]: ...
|
||||
|
||||
@abstractmethod
|
||||
def partition(self, fn: FilterFn) -> tuple[_BaseStream[T], _BaseStream[U]]: ...
|
||||
|
||||
@abstractmethod
|
||||
def fold(self, initial: T, fn: Callable[[T, U], T], *, workers: int = 1, use_threads: bool = False) -> T: ...
|
||||
|
||||
@abstractmethod
|
||||
def collect(self) -> tuple[U, ...]: ...
|
||||
|
||||
@abstractmethod
|
||||
def par_collect(self, workers: int = 4, *, use_threads: bool = False) -> tuple[U, ...]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def async_collect(self) -> Awaitable[tuple[U, ...]]: ...
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.seq)
|
||||
|
||||
|
||||
@attrs.define(frozen=True)
|
||||
class Stream(_BaseStream):
|
||||
@classmethod
|
||||
def from_iterable(cls, it: Iterable) -> Stream[T]:
|
||||
if not isinstance(it, Iterable):
|
||||
it = [it]
|
||||
return cls(seq=tuple(it))
|
||||
|
||||
def map(self, *fns: MapFn | AsyncMapFn) -> Stream[U]:
|
||||
plan = (*self.ops, *tuple((_MAP, fn) for fn in fns))
|
||||
object.__setattr__(self, "ops", plan)
|
||||
return self
|
||||
|
||||
def filter(self, *fns: FilterFn | AsyncFilterFn) -> Stream[T]:
|
||||
plan = (*self.ops, *tuple((_FILTER, fn) for fn in fns))
|
||||
object.__setattr__(self, "ops", plan)
|
||||
return self
|
||||
|
||||
def tap(self, *fns: TapFn | AsyncTapFn) -> Stream[T]:
|
||||
plan = (*self.ops, *tuple((_TAP, fn) for fn in fns))
|
||||
object.__setattr__(self, "ops", plan)
|
||||
return self
|
||||
|
||||
def partition(self, fn: FilterFn, *, workers: int = 1, use_threads: bool = False) -> tuple[Stream[T], Stream[U]]:
|
||||
if workers > 1:
|
||||
seq_tuple = self.par_collect(workers=workers, use_threads=use_threads)
|
||||
else:
|
||||
seq_tuple = self.collect()
|
||||
return (Stream(seq=tuple(x for x in seq_tuple if fn(x))), Stream(seq=tuple(x for x in seq_tuple if not fn(x))))
|
||||
|
||||
def fold(self, initial: T, fn: Callable[[T, U], T], *, workers: int = 1, use_threads: bool = False) -> T:
|
||||
if workers > 1:
|
||||
return reduce(fn, self.par_collect(workers=workers, use_threads=use_threads), initial)
|
||||
return reduce(fn, self.collect(), initial)
|
||||
|
||||
def collect(self) -> tuple[U, ...]:
|
||||
return tuple(_apply_fns(self.seq, self.ops))
|
||||
|
||||
def par_collect(self, workers: int = 4, *, use_threads: bool = False) -> tuple[U, ...]:
|
||||
if workers == -1:
|
||||
workers = (os.cpu_count() or 5) - 1
|
||||
|
||||
executor_cls = ThreadPoolExecutor if use_threads else ProcessPoolExecutor
|
||||
|
||||
batches = [(list(chunk), self.ops) for chunk in batched(self.seq, n=max(4, len(self.seq) // workers))]
|
||||
|
||||
with executor_cls(max_workers=workers) as ex:
|
||||
return cast("tuple[U, ...]", tuple(itertools.chain.from_iterable(ex.map(_apply_fns_worker, batches))))
|
||||
|
||||
async def async_collect(self) -> Awaitable[tuple[U, ...]]:
|
||||
if not self.ops:
|
||||
return cast("Awaitable[tuple[U, ...]]", self.collect())
|
||||
|
||||
res = await asyncio.gather(*(_async_apply_fns(x, self.ops) for x in self.seq))
|
||||
return cast("Awaitable[tuple[U, ...]]", tuple(elem for elem in res if elem != _Nothing.NOTHING))
|
||||
|
||||
|
||||
_MAP = 0
|
||||
_FILTER = 1
|
||||
_TAP = 2
|
||||
|
||||
|
||||
class _Nothing(Enum):
|
||||
NOTHING = auto()
|
||||
|
||||
|
||||
def _apply_fns_worker(args: tuple[tuple[T], tuple[PlannedOps, ...]]) -> tuple[T]:
|
||||
seq, ops = args
|
||||
return _par_apply_fns(seq, ops)
|
||||
|
||||
|
||||
def _apply_fns(elements: tuple[T], ops: tuple[PlannedOps, ...]) -> Generator[T, None, None]:
|
||||
for elem in elements:
|
||||
valid = True
|
||||
res = elem
|
||||
for op, op_fn in ops:
|
||||
if op == _MAP:
|
||||
res = op_fn(res)
|
||||
elif op == _FILTER and not op_fn(res):
|
||||
valid = False
|
||||
break
|
||||
elif op == _TAP:
|
||||
op_fn(deepcopy(res))
|
||||
if valid:
|
||||
yield res
|
||||
|
||||
|
||||
def _par_apply_fns(elements: tuple[T], ops: tuple[PlannedOps, ...]) -> tuple[T]:
|
||||
results = []
|
||||
for elem in elements:
|
||||
valid = True
|
||||
res = elem
|
||||
for op, op_fn in ops:
|
||||
if op == _MAP:
|
||||
res = op_fn(res)
|
||||
elif op == _FILTER and not op_fn(res):
|
||||
valid = False
|
||||
break
|
||||
elif op == _TAP:
|
||||
op_fn(deepcopy(res))
|
||||
if valid:
|
||||
results.append(res)
|
||||
return tuple(results)
|
||||
|
||||
|
||||
async def _async_apply_fns(elem: T, ops: tuple[AsyncPlannedOps, ...]) -> T | _Nothing:
|
||||
res = elem
|
||||
for op, op_fn in ops:
|
||||
if op == _MAP:
|
||||
res = await op_fn(res)
|
||||
elif op == _FILTER and not await op_fn(res):
|
||||
return _Nothing.NOTHING
|
||||
elif op == _TAP:
|
||||
await op_fn(deepcopy(res))
|
||||
return res
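A usage sketch of the lazy pipeline above (the data and functions are placeholders):

stream = Stream.from_iterable(range(10)).map(lambda x: x * x).filter(lambda x: x % 2 == 0)
print(stream.collect())  # (0, 4, 16, 36, 64)
# use_threads=True sidesteps the pickling that a process pool would require for these lambdas
print(stream.par_collect(workers=2, use_threads=True))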
@@ -1,68 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from functools import reduce
|
||||
from operator import not_
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
import attrs
|
||||
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
U_co = TypeVar("U_co", covariant=True)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
Composable = Callable[[T_co], T_co | U_co]
|
||||
Filterable = Callable[[T_co], bool]
|
||||
|
||||
|
||||
@attrs.define(frozen=True, hash=True, eq=True)
|
||||
class _Compose:
|
||||
fns: Sequence[Composable]
|
||||
|
||||
def __call__(self, initial: T_co) -> T_co | U_co:
|
||||
return reduce(_apply, self.fns, initial)
|
||||
|
||||
|
||||
def _apply(value: T_co, fn: Composable) -> T_co | U_co:
|
||||
return fn(value)
|
||||
|
||||
|
||||
def compose(*fns: Composable) -> Composable:
|
||||
return _Compose(fns)
|
||||
|
||||
|
||||
@attrs.define(frozen=True, hash=True, eq=True)
|
||||
class _AllOf:
|
||||
fns: Sequence[Filterable]
|
||||
|
||||
def __call__(self, item: T_co) -> bool:
|
||||
return all(fn(item) for fn in self.fns)
|
||||
|
||||
|
||||
def all_of(*fns: Filterable) -> Filterable:
|
||||
return _AllOf(fns)
|
||||
|
||||
|
||||
@attrs.define(frozen=True, hash=True, eq=True)
|
||||
class _AnyOf:
|
||||
fns: Sequence[Filterable]
|
||||
|
||||
def __call__(self, item: T_co) -> bool:
|
||||
return any(fn(item) for fn in self.fns)
|
||||
|
||||
|
||||
def any_of(*fns: Filterable) -> Filterable:
|
||||
return _AnyOf(fns)
|
||||
|
||||
|
||||
def none_of(*fns: Filterable) -> Filterable:
|
||||
return compose(_AnyOf(fns), not_)
|
||||
|
||||
|
||||
def identity(x: T_co) -> T_co:
|
||||
return x
|
||||
|
||||
|
||||
def invert(func: Filterable) -> Filterable:
|
||||
return compose(func, not_)
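A small sketch of the combinators above (the predicates are invented):

def is_positive(x: int) -> bool:
    return x > 0

def is_even(x: int) -> bool:
    return x % 2 == 0

check = all_of(is_positive, is_even)
print(check(4), check(3))                 # True False
print(none_of(is_positive, is_even)(-3))  # True: neither predicate matches
print(compose(abs, is_even)(-4))          # True: abs is applied first, then is_even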
@@ -1,18 +0,0 @@
from __future__ import annotations

import difflib
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from pathlib import Path


def unified_diff(original: str, optimized: str, file_path: Path, context_lines: int = 3) -> str:
    """Generate a unified diff between original and optimized code."""
    original_lines = original.splitlines(keepends=True)
    optimized_lines = optimized.splitlines(keepends=True)

    diff = difflib.unified_diff(
        original_lines, optimized_lines, fromfile=f"a/{file_path}", tofile=f"b/{file_path}", n=context_lines
    )
    return "".join(diff)
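A usage sketch of the helper above (the file name is a placeholder):

from pathlib import Path

before = "def f():\n    return 1\n"
after = "def f():\n    return 2\n"
print(unified_diff(before, after, Path("pkg/mod.py")))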
@@ -1,194 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestOutcomeStatus(Enum):
|
||||
PASSED = "passed"
|
||||
FAILED = "failed"
|
||||
ERROR = "error"
|
||||
SKIPPED = "skipped"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FunctionParent:
|
||||
name: str
|
||||
type: str = "ClassDef"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.type}:{self.name}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionToOptimize:
|
||||
function_name: str
|
||||
file_path: Path
|
||||
parents: list[FunctionParent] = field(default_factory=list)
|
||||
starting_line: int | None = None
|
||||
ending_line: int | None = None
|
||||
starting_col: int | None = None
|
||||
ending_col: int | None = None
|
||||
is_async: bool = False
|
||||
is_method: bool = False
|
||||
language: str = ""
|
||||
doc_start_line: int | None = None
|
||||
source_code: str = ""
|
||||
|
||||
@property
|
||||
def qualified_name(self) -> str:
|
||||
if not self.parents:
|
||||
return self.function_name
|
||||
parent_path = ".".join(parent.name for parent in self.parents)
|
||||
return f"{parent_path}.{self.function_name}"
|
||||
|
||||
@property
|
||||
def top_level_parent_name(self) -> str:
|
||||
return self.function_name if not self.parents else self.parents[0].name
|
||||
|
||||
@property
|
||||
def class_name(self) -> str | None:
|
||||
for parent in reversed(self.parents):
|
||||
if parent.type == "ClassDef":
|
||||
return parent.name
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class HelperFunction:
|
||||
name: str
|
||||
qualified_name: str
|
||||
file_path: Path
|
||||
source_code: str
|
||||
start_line: int
|
||||
end_line: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeContext:
|
||||
target_function: FunctionToOptimize
|
||||
target_code: str
|
||||
target_file: Path
|
||||
helper_functions: list[HelperFunction] = field(default_factory=list)
|
||||
read_only_context: str = ""
|
||||
imports: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Candidate:
|
||||
code: str
|
||||
explanation: str
|
||||
candidate_id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
|
||||
source: str = ""
|
||||
parent_id: str = ""
|
||||
code_markdown: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestOutcome:
|
||||
test_id: str
|
||||
status: TestOutcomeStatus
|
||||
output: Any = None
|
||||
duration: float = 0.0
|
||||
error_message: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestResults:
|
||||
passed: bool
|
||||
outcomes: list[TestOutcome] = field(default_factory=list)
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkResults:
|
||||
timings: dict[str, float] = field(default_factory=dict)
|
||||
total_time: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoredCandidate:
|
||||
candidate: Candidate
|
||||
test_results: TestResults
|
||||
benchmark_results: BenchmarkResults
|
||||
speedup: float
|
||||
score: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptimizationResult:
|
||||
function: FunctionToOptimize
|
||||
original_code: str
|
||||
optimized_code: str
|
||||
speedup: float
|
||||
candidate: Candidate
|
||||
test_results: TestResults
|
||||
benchmark_results: BenchmarkResults
|
||||
diff: str = ""
|
||||
explanation: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratedTestFile:
|
||||
behavior_test_path: Path
|
||||
perf_test_path: Path
|
||||
behavior_test_source: str
|
||||
perf_test_source: str
|
||||
original_test_source: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneratedTestSuite:
|
||||
test_files: list[GeneratedTestFile] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def behavior_test_paths(self) -> list[Path]:
|
||||
return [f.behavior_test_path for f in self.test_files]
|
||||
|
||||
@property
|
||||
def perf_test_paths(self) -> list[Path]:
|
||||
return [f.perf_test_path for f in self.test_files]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCoverage:
|
||||
name: str
|
||||
coverage: float
|
||||
executed_lines: list[int] = field(default_factory=list)
|
||||
unexecuted_lines: list[int] = field(default_factory=list)
|
||||
executed_branches: list[list[int]] = field(default_factory=list)
|
||||
unexecuted_branches: list[list[int]] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CoverageData:
|
||||
file_path: Path
|
||||
coverage: float
|
||||
function_name: str
|
||||
main_func_coverage: FunctionCoverage
|
||||
dependent_func_coverage: FunctionCoverage | None = None
|
||||
threshold_percentage: float = 60.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestDiff:
|
||||
test_id: str
|
||||
baseline_output: Any = None
|
||||
candidate_output: Any = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestRepairInfo:
|
||||
function_name: str
|
||||
reason: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestReviewResult:
|
||||
test_index: int
|
||||
functions_to_repair: list[TestRepairInfo] = field(default_factory=list)
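A construction sketch for the core model above (the path and names are placeholders):

from pathlib import Path

fn = FunctionToOptimize(
    function_name="bubble_sort",
    file_path=Path("src/algos.py"),
    parents=[FunctionParent(name="Sorter")],
)
print(fn.qualified_name)  # "Sorter.bubble_sort"
print(fn.class_name)      # "Sorter"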
@@ -1,160 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash_core.config import HIGH_EFFORT_TOP_N, EffortLevel
|
||||
from codeflash_core.strategy import DefaultStrategy
|
||||
from codeflash_core.strategy_utils import OptimizationRuntime
|
||||
from codeflash_core.ui import console, paneled_text, progress_bar
|
||||
from codeflash_core.ui import logger as ui_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash_core.config import CoreConfig
|
||||
from codeflash_core.models import FunctionToOptimize, OptimizationResult
|
||||
from codeflash_core.protocols import LanguagePlugin
|
||||
from codeflash_core.strategy_utils import OptimizationStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Optimizer:
|
||||
"""Core optimization orchestrator.
|
||||
|
||||
Drives the discover -> index -> rank -> per-function optimization loop.
|
||||
Delegates the actual optimization pipeline for each function to the active
|
||||
OptimizationStrategy (defaults to DefaultStrategy).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, config: CoreConfig, plugin: LanguagePlugin, strategy: OptimizationStrategy | None = None
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.plugin = plugin
|
||||
self.test_config = config.resolve_test_config()
|
||||
# Resolve the plugin's output comparator once (or None -> fallback to ==)
|
||||
self.output_comparator = getattr(plugin, "compare_outputs", None)
|
||||
self.cancel_event = threading.Event()
|
||||
# Share cancel event with plugin so it can abort long-running subprocess/HTTP calls
|
||||
if hasattr(plugin, "cancel_event"):
|
||||
object.__setattr__(plugin, "cancel_event", self.cancel_event)
|
||||
self.strategy = strategy or DefaultStrategy()
|
||||
|
||||
def cancel(self) -> None:
|
||||
self.cancel_event.set()
|
||||
|
||||
def is_cancelled(self) -> bool:
|
||||
return self.cancel_event.is_set()
|
||||
|
||||
def run(self, files: list[Path], function_filter: str | None = None) -> list[OptimizationResult]:
|
||||
"""Run the optimization pipeline on the given files.
|
||||
|
||||
Returns a list of successful optimization results.
|
||||
"""
|
||||
# Pre-run cleanup of leftover files from previous runs
|
||||
self.plugin_cleanup()
|
||||
self.cleanup_leftover_trace_files()
|
||||
|
||||
with progress_bar("Discovering functions...", transient=True):
|
||||
functions = self.plugin.discover_functions(files)
|
||||
|
||||
if function_filter:
|
||||
functions = [f for f in functions if function_filter in (f.function_name, f.qualified_name)]
|
||||
|
||||
if not functions:
|
||||
ui_logger.info("No optimizable functions found.")
|
||||
return []
|
||||
|
||||
ui_logger.info("Found %d functions to optimize.", len(functions))
|
||||
|
||||
# Pre-index source files for dependency analysis (call graph)
|
||||
source_files = list({f.file_path for f in functions})
|
||||
|
||||
def on_index_progress(result: object) -> None:
|
||||
pass
|
||||
|
||||
with progress_bar("Building call graph...", transient=True):
|
||||
self.plugin.build_index(source_files, on_progress=on_index_progress)
|
||||
|
||||
# Rank functions by impact (e.g. dependency count)
|
||||
functions = self.plugin.rank_functions(functions)
|
||||
|
||||
results: list[OptimizationResult] = []
|
||||
skipped = 0
|
||||
cancelled = False
|
||||
|
||||
try:
|
||||
for i, function in enumerate(functions):
|
||||
if self.is_cancelled():
|
||||
cancelled = True
|
||||
break
|
||||
|
||||
console.rule(f"[bold][{i + 1}/{len(functions)}] {function.qualified_name}[/bold]")
|
||||
|
||||
# Escalate top-N functions to HIGH effort when running at MEDIUM
|
||||
original_effort = self.config.effort
|
||||
if i < HIGH_EFFORT_TOP_N and self.config.effort == EffortLevel.MEDIUM.value:
|
||||
self.config.effort = EffortLevel.HIGH.value
|
||||
|
||||
result = self.optimize_function(function)
|
||||
self.config.effort = original_effort
|
||||
|
||||
if result is not None:
|
||||
results.append(result)
|
||||
ui_logger.info("Optimized %s — %.2fx speedup", function.qualified_name, result.speedup)
|
||||
else:
|
||||
skipped += 1
|
||||
ui_logger.info("No improvement found for %s", function.qualified_name)
|
||||
except KeyboardInterrupt:
|
||||
ui_logger.warning("Keyboard interrupt received. Cleaning up…")
|
||||
cancelled = True
|
||||
self.cancel_event.set()
|
||||
finally:
|
||||
self.plugin_cleanup()
|
||||
self.cleanup_leftover_trace_files()
|
||||
|
||||
console.rule()
|
||||
paneled_text(f"{len(functions)} analyzed, {len(results)} optimized, {skipped} skipped", title="Summary")
|
||||
return results
|
||||
|
||||
def optimize_function(self, function: FunctionToOptimize) -> OptimizationResult | None:
|
||||
"""Attempt to optimize a single function. Delegates to the active strategy."""
|
||||
if self.is_cancelled():
|
||||
return None
|
||||
|
||||
runtime = OptimizationRuntime(
|
||||
plugin=self.plugin,
|
||||
config=self.config,
|
||||
test_config=self.test_config,
|
||||
cancel_event=self.cancel_event,
|
||||
output_comparator=self.output_comparator,
|
||||
trace_id=str(uuid.uuid4()),
|
||||
)
|
||||
return self.strategy.optimize_function(function, runtime)
|
||||
|
||||
# -- Cleanup helpers (shared across all strategies) -------------------------
|
||||
|
||||
def cleanup_leftover_trace_files(self) -> None:
|
||||
"""Remove leftover .trace files from previous runs."""
|
||||
tests_root = self.test_config.tests_root
|
||||
if not tests_root.exists():
|
||||
return
|
||||
leftover = list(tests_root.glob("*.trace"))
|
||||
if leftover:
|
||||
logger.debug("Cleaning up %d leftover trace file(s)", len(leftover))
|
||||
for p in leftover:
|
||||
with contextlib.suppress(OSError):
|
||||
p.unlink(missing_ok=True)
|
||||
|
||||
def plugin_cleanup(self) -> None:
|
||||
"""Delegate cleanup of language-specific leftover files to the plugin."""
|
||||
if hasattr(self.plugin, "cleanup_run"):
|
||||
try:
|
||||
self.plugin.cleanup_run(self.test_config.tests_root)
|
||||
except Exception:
|
||||
logger.debug("Plugin cleanup_run failed", exc_info=True)
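For reference, the orchestrator above was wired up roughly as the CLI does it (a sketch; the file list and function name are placeholders):

from pathlib import Path

from codeflash.plugin import PythonPlugin  # language plugin shipped by the separate codeflash package
from codeflash_core.config import CoreConfig
from codeflash_core.optimizer import Optimizer

config = CoreConfig.find_and_load(Path.cwd())
optimizer = Optimizer(config, PythonPlugin(config.project_root))
results = optimizer.run([Path("src/algos.py")], function_filter="bubble_sort")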
@@ -1,290 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Literal, Protocol, overload, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash_core.config import TestConfig
|
||||
from codeflash_core.models import (
|
||||
BenchmarkResults,
|
||||
Candidate,
|
||||
CodeContext,
|
||||
CoverageData,
|
||||
FunctionToOptimize,
|
||||
    GeneratedTestSuite,
    OptimizationResult,
    ScoredCandidate,
    TestDiff,
    TestResults,
    TestReviewResult,
)


@runtime_checkable
class LanguagePlugin(Protocol):
    """Protocol that language packages must implement.

    A language plugin provides all language-specific functionality:
    discovering functions, extracting code context, running tests,
    replacing code, running benchmarks, and formatting.
    """

    def discover_functions(self, paths: list[Path]) -> list[FunctionToOptimize]:
        """Discover optimizable functions in the given file paths."""
        ...

    def build_index(self, files: list[Path], on_progress: Callable[[Any], None] | None = None) -> None:
        """Pre-index source files for dependency analysis (e.g. call graph).

        Called after discovery, before optimization begins. Plugins that
        maintain a dependency resolver should index the given files here.
        The optional on_progress callback is called once per file with an
        implementation-defined result object.
        """
        ...

    def rank_functions(
        self,
        functions: list[FunctionToOptimize],
        trace_file: Path | None = None,
        test_counts: dict[tuple[Path, str], int] | None = None,
    ) -> list[FunctionToOptimize]:
        """Rank functions in optimization order (most impactful first).

        Ranking priority:
        1. Trace-based addressable time (when trace_file is provided)
        2. Dependency count from the call graph (fallback)
        3. Existing unit test count as secondary sort key (when test_counts provided)

        Returns the functions unchanged if no ranking is possible.
        """
        ...

    def get_dependency_counts(self) -> dict[str, int]:
        """Return {qualified_name: callee_count} from the most recent ranking.

        Called after rank_functions() so the UI can display per-function
        dependency information. Plugins that don't track a call graph may
        return an empty dict.
        """
        ...

    def get_candidates(self, context: CodeContext, trace_id: str = "") -> list[Candidate]:
        """Request optimization candidates from the AI service."""
        ...

    def extract_context(self, function: FunctionToOptimize) -> CodeContext:
        """Extract all code context needed to optimize a function."""
        ...

    @overload
    def run_tests(
        self,
        test_config: TestConfig,
        test_files: list[Path] | None = ...,
        test_iteration: int = ...,
        enable_coverage: Literal[False] = ...,
    ) -> TestResults: ...

    @overload
    def run_tests(
        self,
        test_config: TestConfig,
        test_files: list[Path] | None = ...,
        test_iteration: int = ...,
        enable_coverage: Literal[True] = ...,
    ) -> tuple[TestResults, CoverageData | None]: ...

    def run_tests(
        self,
        test_config: TestConfig,
        test_files: list[Path] | None = None,
        test_iteration: int = 0,
        enable_coverage: bool = False,
    ) -> TestResults | tuple[TestResults, CoverageData | None]:
        """Run tests and return structured results.

        If test_files is provided, run only those files.
        Otherwise discover test files from test_config.
        test_iteration is passed as CODEFLASH_TEST_ITERATION env var.
        When enable_coverage is True, returns (TestResults, CoverageData | None).
        """
        ...

    def replace_function(self, file: Path, function: FunctionToOptimize, new_code: str) -> None:
        """Replace a function's source code in a file."""
        ...

    def restore_function(self, file: Path, function: FunctionToOptimize, original_code: str) -> None:
        """Restore a function's original source code in a file."""
        ...

    def run_benchmarks(
        self,
        function: FunctionToOptimize,
        test_config: TestConfig,
        test_files: list[Path] | None = None,
        test_iteration: int = 0,
    ) -> BenchmarkResults:
        """Run benchmarks for a function and return timing data.

        If test_files is provided, run only those files.
        Otherwise discover test files from test_config.
        test_iteration is passed as CODEFLASH_TEST_ITERATION env var.
        """
        ...

    def format_code(self, code: str, file: Path) -> str:
        """Format code according to the project's style."""
        ...

    def validate_candidate(self, code: str) -> bool:
        """Return True if the candidate code is syntactically valid."""
        ...

    def normalize_code(self, code: str) -> str:
        """Normalize code for deduplication (e.g. strip comments, whitespace, docstrings)."""
        ...

    # -- Phase 1: Test Generation -----------------------------------------------

    def generate_tests(
        self, function: FunctionToOptimize, context: CodeContext, test_config: TestConfig, trace_id: str = ""
    ) -> GeneratedTestSuite | None:
        """Generate regression tests for the target function."""
        ...

    # -- Phase 2: Split behavioral / performance test running --------------------

    def run_behavioral_tests(self, test_files: list[Path], test_config: TestConfig) -> TestResults:
        """Run behavioral tests and return pass/fail with captured outputs."""
        ...

    def run_performance_tests(
        self, test_files: list[Path], function: FunctionToOptimize, test_config: TestConfig
    ) -> BenchmarkResults:
        """Run performance-instrumented tests and return timing data."""
        ...

    # -- Phase 3: Multi-round candidate generation ------------------------------

    def run_line_profiler(
        self, function: FunctionToOptimize, test_config: TestConfig, test_files: list[Path] | None = None
    ) -> str:
        """Run line profiler on the function and return formatted profiler output.

        Returns an empty string if profiling is not possible (e.g. JIT-decorated code).
        """
        ...

    def get_line_profiler_candidates(
        self, context: CodeContext, line_profile_data: str, trace_id: str = ""
    ) -> list[Candidate]:
        """Generate candidates guided by line profiler hotspot data."""
        ...

    def repair_candidate(
        self, context: CodeContext, candidate: Candidate, test_diffs: list[TestDiff], trace_id: str = ""
    ) -> Candidate | None:
        """Fix a failing candidate using test failure info."""
        ...

    def refine_candidate(
        self, context: CodeContext, candidate: ScoredCandidate, baseline_bench: BenchmarkResults, trace_id: str = ""
    ) -> list[Candidate]:
        """Refine a passing candidate for further improvement."""
        ...

    def adaptive_optimize(
        self, context: CodeContext, scored: list[ScoredCandidate], trace_id: str = ""
    ) -> Candidate | None:
        """Combine insights from evaluated candidates."""
        ...

    # -- Phase 4: Test review & repair ------------------------------------------

    def review_generated_tests(
        self, suite: GeneratedTestSuite, context: CodeContext, test_results: TestResults, trace_id: str = ""
    ) -> list[TestReviewResult]:
        """Review generated tests for quality issues."""
        ...

    def repair_generated_tests(
        self,
        suite: GeneratedTestSuite,
        reviews: list[TestReviewResult],
        context: CodeContext,
        trace_id: str = "",
        previous_repair_errors: dict[str, str] | None = None,
        coverage_data: CoverageData | None = None,
    ) -> GeneratedTestSuite | None:
        """Repair generated tests based on review feedback."""
        ...

    # -- Phase 5: AI-assisted ranking & explanation -----------------------------

    def rank_candidates(
        self, scored: list[ScoredCandidate], context: CodeContext, trace_id: str = ""
    ) -> list[int] | None:
        """Rank candidates using AI. Returns indices in decreasing preference order, or None."""
        ...

    def generate_explanation(
        self, result: OptimizationResult, context: CodeContext, trace_id: str = "", annotated_tests: str = ""
    ) -> str:
        """Generate a human-readable explanation for the winning optimization."""
        ...

    # -- Cleanup & environment -------------------------------------------------

    def cleanup_run(self, tests_root: Path) -> None:
        """Clean up leftover files from previous or current runs.

        Called before and after the optimization loop. Implementations should
        remove instrumented test files, temporary return-value files, trace
        files, and any shared temp directories their tooling creates.
        """
        ...

    def compare_outputs(self, baseline_output: object, candidate_output: object) -> bool:
        """Compare two captured test outputs for equivalence.

        Called during verification to decide whether a candidate preserved
        behavior. Implementations may use deep/structural comparison
        (e.g. handling NaN, custom objects). The default core fallback is ``==``.
        """
        ...

    def validate_environment(self, config: Any) -> bool:
        """Validate that the environment is ready to run optimizations.

        Called before the optimization loop starts. Implementations should
        check that required tools (formatters, test runners, etc.) are
        installed and accessible. Return True if everything is OK.
        """
        ...

    # -- Phase 6: PR creation & result logging ----------------------------------

    def create_pr(
        self,
        result: OptimizationResult,
        context: CodeContext,
        trace_id: str = "",
        generated_tests: GeneratedTestSuite | None = None,
    ) -> str | None:
        """Create a pull request with the optimization. Returns the PR URL or None."""
        ...

    def log_results(
        self,
        result: OptimizationResult,
        trace_id: str,
        all_speedups: dict[str, float] | None = None,
        all_runtimes: dict[str, float] | None = None,
        all_correct: dict[str, bool] | None = None,
    ) -> None:
        """Log optimization results to the backend."""
        ...
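For orientation, here is a minimal sketch of a class that would structurally satisfy a few of the LanguagePlugin methods above; the class name and its stub behaviour are illustrative assumptions, not code from the removed package.

from __future__ import annotations

import ast
from pathlib import Path


class SketchPlugin:
    def discover_functions(self, paths: list[Path]) -> list:
        # Hypothetical stub: report no optimizable functions.
        return []

    def validate_candidate(self, code: str) -> bool:
        # Syntax check via the standard-library parser.
        try:
            ast.parse(code)
        except SyntaxError:
            return False
        return True

    def normalize_code(self, code: str) -> str:
        # Crude normalization for deduplication: drop blank lines and trailing whitespace.
        return "\n".join(line.rstrip() for line in code.splitlines() if line.strip())

Because LanguagePlugin is a runtime_checkable Protocol, isinstance() only verifies that the listed methods exist; a real plugin still has to honour the documented semantics of each method.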
@ -1,34 +0,0 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from codeflash_core.models import BenchmarkResults, ScoredCandidate


def compute_speedup(baseline: BenchmarkResults, candidate: BenchmarkResults) -> float:
    """Compute speedup as percentage improvement: (baseline - candidate) / candidate.

    Matches original codeflash performance_gain formula.
    Returns 0.0 if candidate time is zero (no improvement measurable).
    A positive value means the candidate is faster (e.g. 1.0 = 100% faster).
    """
    if candidate.total_time <= 0:
        return 0.0
    return (baseline.total_time - candidate.total_time) / candidate.total_time


def score_candidate(speedup: float) -> float:
    """Score a candidate based on its speedup.

    Currently score == speedup. This is the extension point for adding
    more signals (code complexity, diff size, etc.) in the future.
    """
    return speedup


def select_best(candidates: list[ScoredCandidate]) -> ScoredCandidate | None:
    """Select the best candidate by score. Returns None if list is empty."""
    if not candidates:
        return None
    return max(candidates, key=lambda c: c.score)
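A quick worked example of the formula above, with made-up numbers: a baseline total_time of 2.0 s and a candidate total_time of 0.5 s give

speedup = (2.0 - 0.5) / 0.5  # = 3.0, i.e. "300% faster" (the candidate runs in a quarter of the time)

so compute_speedup() returns 3.0 and, since score_candidate() is currently the identity, the candidate's score is also 3.0.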
@ -1,311 +0,0 @@
"""Optimization strategies for the core optimizer.

An OptimizationStrategy controls the full pipeline for optimizing a single
function: context extraction, candidate generation, test review, evaluation,
ranking, and explanation. The Optimizer delegates to the active strategy,
keeping discovery, indexing, and the per-function loop in the orchestrator.
"""

from __future__ import annotations

import logging
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, ClassVar

from codeflash_core.config import EffortKeys, get_effort_value
from codeflash_core.diff import unified_diff
from codeflash_core.models import OptimizationResult
from codeflash_core.strategy_evaluation import DefaultStrategyEvaluationMixin
from codeflash_core.strategy_utils import StageSpec, cleanup_generated_tests, log_optimization_run
from codeflash_core.ui import code_print, progress_bar
from codeflash_core.ui import logger as ui_logger

if TYPE_CHECKING:
    from pathlib import Path

    from codeflash_core.models import Candidate, CodeContext, FunctionToOptimize, GeneratedTestSuite, ScoredCandidate
    from codeflash_core.strategy_utils import OptimizationRuntime

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Default strategy
# ---------------------------------------------------------------------------


class DefaultStrategy(DefaultStrategyEvaluationMixin):
    """The default optimization strategy.

    Pipeline: context -> parallel testgen + candidates -> test review/repair ->
    baseline tests & benchmarks -> multi-round candidate evaluation ->
    AI-assisted ranking -> explanation.
    """

    stages: ClassVar[list[StageSpec]] = [
        StageSpec("context", "Extracting source code and dependencies to build the optimization context for the AI."),
        StageSpec("generating", "Generating unit tests for correctness validation and optimized candidates."),
        StageSpec(
            "test_review", "Reviewing generated tests for correctness and repairing any issues before validation."
        ),
        StageSpec(
            "baseline",
            "Running the original code against tests and benchmarks to establish a performance reference point.",
        ),
        StageSpec("evaluating", "Testing each candidate for correctness, then benchmarking the ones that pass."),
        StageSpec(
            "ranking", "Comparing candidate benchmarks against the baseline to find the fastest correct optimization."
        ),
        StageSpec(
            "explaining",
            "Generating a human-readable explanation of what the best optimization changed and why it's faster.",
        ),
    ]

    def optimize_function(
        self, function: FunctionToOptimize, runtime: OptimizationRuntime
    ) -> OptimizationResult | None:
        context: CodeContext | None = None
        candidates: list[Candidate] | None = None
        generated_tests: GeneratedTestSuite | None = None
        original_generated_tests: GeneratedTestSuite | None = None
        scored: list[ScoredCandidate] = []
        result: OptimizationResult | None = None
        stage = "context"
        exit_reason = ""

        try:
            head = self.extract_and_generate(function, runtime)
            if head is None:
                exit_reason = "cancelled" if runtime.is_cancelled() else "no_candidates"
                return None
            context, candidates, generated_tests = head
            original_generated_tests = generated_tests

            # Phase 2: Review and repair generated tests
            stage = "test_review"
            with progress_bar("Reviewing and repairing tests..."):
                generated_tests = self.review_and_repair_tests(generated_tests, context, runtime)
            if runtime.is_cancelled():
                exit_reason = "cancelled"
                return None

            # Determine test files
            behavior_files: list[Path] | None = None
            perf_files: list[Path] | None = None
            if generated_tests and generated_tests.test_files:
                behavior_files = generated_tests.behavior_test_paths
                perf_files = generated_tests.perf_test_paths

            # Run baseline tests
            stage = "baseline"
            with progress_bar("Running baseline tests..."):
                baseline_tests = runtime.plugin.run_tests(runtime.test_config, test_files=behavior_files)
            if runtime.is_cancelled():
                exit_reason = "cancelled"
                return None

            if not baseline_tests.passed:
                exit_reason = "baseline_failed"
                ui_logger.warning("Baseline tests failed for %s, skipping", function.qualified_name)
                return None

            # Run baseline benchmarks
            with progress_bar("Running baseline benchmarks..."):
                baseline_bench = runtime.plugin.run_benchmarks(function, runtime.test_config, test_files=perf_files)
            if runtime.is_cancelled():
                exit_reason = "cancelled"
                return None

            # Line profiler
            lp_candidates = self.get_line_profiler_candidates(function, context, runtime, perf_files)
            if lp_candidates:
                candidates.extend(lp_candidates)
            if runtime.is_cancelled():
                exit_reason = "cancelled"
                return None

            # Phase 3: Multi-round candidate evaluation
            stage = "evaluating"
            ui_logger.info("Evaluating %d candidates...", len(candidates))
            scored, all_speedups, all_runtimes, all_correct = self.evaluate_candidates(
                function,
                context,
                candidates,
                baseline_tests,
                baseline_bench,
                runtime=runtime,
                behavior_test_files=behavior_files,
                perf_test_files=perf_files,
            )
            if runtime.is_cancelled():
                exit_reason = "cancelled"
                return None

            # Phase 4+5: Rank, explain, log, finish
            stage = "ranking"
            with progress_bar("Ranking results..."):
                result = self.rank_explain_finish(
                    function, context, scored, generated_tests, runtime, all_speedups, all_runtimes, all_correct
                )
            if result:
                stage = "done"
                exit_reason = "optimized"
                code_print(result.diff)
            else:
                exit_reason = "cancelled" if runtime.is_cancelled() else "no_improvement"
            return result
        finally:
            log_optimization_run(
                function,
                runtime,
                context=context,
                candidates=candidates,
                generated_tests=generated_tests,
                scored=scored or None,
                result=result,
                stage_reached=stage,
                exit_reason=exit_reason,
            )
            if original_generated_tests:
                cleanup_generated_tests(original_generated_tests)

    # -- Building blocks (override individually or call from custom strategies) --

    def extract_and_generate(
        self, function: FunctionToOptimize, runtime: OptimizationRuntime
    ) -> tuple[CodeContext, list[Candidate], GeneratedTestSuite | None] | None:
        """Context extraction + parallel test/candidate generation.

        Returns (context, candidates, generated_tests) or None if cancelled/no candidates.
        """
        with progress_bar("Extracting context...", transient=True):
            context = runtime.plugin.extract_context(function)
        if runtime.is_cancelled():
            return None

        with progress_bar("Generating tests and candidates..."):
            generated_tests, candidates = self.generate_tests_and_candidates(function, context, runtime)

        if runtime.is_cancelled():
            return None
        if not candidates:
            ui_logger.info("No candidates returned for %s", function.qualified_name)
            return None

        ui_logger.info("Received %d candidates for %s", len(candidates), function.qualified_name)
        return context, candidates, generated_tests

    def rank_explain_finish(
        self,
        function: FunctionToOptimize,
        context: CodeContext,
        scored: list[ScoredCandidate],
        generated_tests: GeneratedTestSuite | None,
        runtime: OptimizationRuntime,
        all_speedups: dict[str, float],
        all_runtimes: dict[str, float],
        all_correct: dict[str, bool],
    ) -> OptimizationResult | None:
        """Rank candidates, explain the best, log results, and emit completion events."""
        best = self.select_best_with_ranking(scored, context, runtime)
        if runtime.is_cancelled():
            return None
        if best is None or best.speedup <= 0:
            return None

        diff = unified_diff(context.target_code, best.candidate.code, function.file_path)
        result = OptimizationResult(
            function=function,
            original_code=context.target_code,
            optimized_code=best.candidate.code,
            speedup=best.speedup,
            candidate=best.candidate,
            test_results=best.test_results,
            benchmark_results=best.benchmark_results,
            diff=diff,
        )

        annotated_tests = ""
        if generated_tests and generated_tests.test_files:
            annotated_tests = "\n\n".join(
                tf.original_test_source for tf in generated_tests.test_files if tf.original_test_source
            )

        explanation = runtime.plugin.generate_explanation(
            result, context, trace_id=runtime.trace_id, annotated_tests=annotated_tests
        )
        if explanation:
            result.explanation = explanation

        runtime.plugin.log_results(
            result, runtime.trace_id, all_speedups=all_speedups, all_runtimes=all_runtimes, all_correct=all_correct
        )
        if runtime.config.create_pr:
            try:
                runtime.plugin.create_pr(result, context, trace_id=runtime.trace_id, generated_tests=generated_tests)
            except Exception:
                logger.debug("PR creation failed", exc_info=True)

        return result

    def generate_tests_and_candidates(
        self, function: FunctionToOptimize, context: CodeContext, runtime: OptimizationRuntime
    ) -> tuple[GeneratedTestSuite | None, list[Candidate]]:
        """Generate tests and fetch candidates in parallel."""
        with ThreadPoolExecutor(max_workers=2) as executor:
            tests_future = executor.submit(
                runtime.plugin.generate_tests, function, context, runtime.test_config, runtime.trace_id
            )
            candidates_future = executor.submit(runtime.plugin.get_candidates, context, runtime.trace_id)

            while True:
                if runtime.is_cancelled():
                    tests_future.cancel()
                    candidates_future.cancel()
                    break
                if tests_future.done() and candidates_future.done():
                    break
                runtime.cancel_event.wait(timeout=0.1)

        try:
            generated_tests = (
                tests_future.result() if tests_future.done() and not tests_future.cancelled() else None
            )
        except Exception:
            logger.debug("Test generation failed", exc_info=True)
            generated_tests = None

        try:
            candidates = (
                candidates_future.result() if candidates_future.done() and not candidates_future.cancelled() else []
            )
        except Exception:
            logger.debug("Candidate generation failed", exc_info=True)
            candidates = []

        return generated_tests, candidates

    def get_line_profiler_candidates(
        self,
        function: FunctionToOptimize,
        context: CodeContext,
        runtime: OptimizationRuntime,
        perf_test_files: list[Path] | None,
    ) -> list[Candidate]:
        """Run line profiler on baseline and fetch LP-guided candidates."""
        n_lp = get_effort_value(EffortKeys.N_OPTIMIZER_LP_CANDIDATES, runtime.config.effort)
        if n_lp <= 0:
            return []

        try:
            lp_data = runtime.plugin.run_line_profiler(function, runtime.test_config, test_files=perf_test_files)
            if not lp_data:
                return []

            lp_candidates = runtime.plugin.get_line_profiler_candidates(context, lp_data, runtime.trace_id)
            for c in lp_candidates:
                c.source = "line_profiler"
            return lp_candidates
        except Exception:
            logger.debug("Line profiler step failed for %s", function.qualified_name, exc_info=True)
            return []
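Because the building blocks above are meant to be overridden individually, a custom strategy can stay very small. A hypothetical sketch (not part of the removed code) that reuses the whole DefaultStrategy pipeline but skips the line-profiler round:

class NoProfilerStrategy(DefaultStrategy):
    def get_line_profiler_candidates(self, function, context, runtime, perf_test_files):
        # Returning an empty list means no LP-guided candidates are queued.
        return []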
@ -1,283 +0,0 @@
"""Mixin: candidate evaluation, filtering, and test review/repair."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

from codeflash_core.config import MAX_TEST_REPAIR_CYCLES, EffortKeys, get_effort_value
from codeflash_core.models import GeneratedTestSuite, ScoredCandidate, TestOutcomeStatus
from codeflash_core.ranking import compute_speedup, score_candidate, select_best
from codeflash_core.strategy_utils import MIN_CORRECT_CANDIDATES, compute_test_diffs, restore_test_snapshots
from codeflash_core.ui import logger as ui_logger
from codeflash_core.ui import progress_bar
from codeflash_core.verification import is_equivalent

if TYPE_CHECKING:
    from pathlib import Path

    from codeflash_core.models import (
        BenchmarkResults,
        Candidate,
        CodeContext,
        CoverageData,
        FunctionToOptimize,
        TestResults,
    )
    from codeflash_core.strategy_utils import OptimizationRuntime

StrategyBase = object

logger = logging.getLogger(__name__)


class DefaultStrategyEvaluationMixin(StrategyBase):
    def evaluate_candidates(
        self,
        function: FunctionToOptimize,
        context: CodeContext,
        candidates: list[Candidate],
        baseline_tests: TestResults,
        baseline_bench: BenchmarkResults,
        runtime: OptimizationRuntime,
        behavior_test_files: list[Path] | None = None,
        perf_test_files: list[Path] | None = None,
    ) -> tuple[list[ScoredCandidate], dict[str, float], dict[str, float], dict[str, bool]]:
        """Evaluate candidates with multi-round repair, refinement, and adaptive optimization.

        Returns (scored_candidates, all_speedups, all_runtimes, all_correct) where the dicts
        accumulate data for ALL candidates (including failed ones) keyed by candidate_id.
        """
        effort = runtime.config.effort
        scored: list[ScoredCandidate] = []
        all_speedups: dict[str, float] = {}
        all_runtimes: dict[str, float] = {}
        all_correct: dict[str, bool] = {}

        queue = self.filter_candidates(function, context, candidates, runtime)

        processed = 0
        max_total = 30
        max_repairs: int = get_effort_value(EffortKeys.MAX_CODE_REPAIRS_PER_TRACE, effort)
        repair_unmatched_limit: float = get_effort_value(EffortKeys.REPAIR_UNMATCHED_PERCENTAGE_LIMIT, effort)
        top_refinement: int = get_effort_value(EffortKeys.TOP_VALID_CANDIDATES_FOR_REFINEMENT, effort)
        max_adaptive: int = get_effort_value(EffortKeys.MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE, effort)
        adaptive_threshold: int = get_effort_value(EffortKeys.ADAPTIVE_OPTIMIZATION_THRESHOLD, effort)
        repair_counter = 0
        adaptive_counter = 0

        file_snapshots = self.capture_file_snapshots(function, context)

        while queue and processed < max_total:
            if runtime.is_cancelled():
                break
            candidate = queue.pop(0)
            processed += 1
            ui_logger.info(
                "Testing candidate %d/%d [%s]",
                processed,
                len(candidates) + processed - len(queue) - 1,
                candidate.source,
            )

            if hasattr(runtime.plugin, "pending_code_markdown"):
                runtime.plugin.pending_code_markdown = candidate.code_markdown
            runtime.plugin.replace_function(function.file_path, function, candidate.code)

            try:
                with progress_bar(f"Running tests for candidate {processed}...", transient=True):
                    test_results = runtime.plugin.run_tests(
                        runtime.test_config, test_files=behavior_test_files, test_iteration=processed
                    )

                if not is_equivalent(baseline_tests, test_results, comparator=runtime.output_comparator):
                    ui_logger.info("Candidate %s failed equivalence check", candidate.candidate_id)
                    all_correct[candidate.candidate_id] = False
                    # Phase 3: Try repair if test diffs are manageable
                    # Only repair first-pass candidates (optimize, line_profiler)
                    successful_count = sum(1 for v in all_correct.values() if v)
                    if (
                        not runtime.is_cancelled()
                        and candidate.source in ("optimize", "line_profiler")
                        and repair_counter < max_repairs
                        and successful_count < MIN_CORRECT_CANDIDATES
                    ):
                        diffs = compute_test_diffs(baseline_tests, test_results)
                        total_tests = max(len(baseline_tests.outcomes), 1)
                        if diffs and len(diffs) / total_tests <= repair_unmatched_limit:
                            repaired = runtime.plugin.repair_candidate(
                                context, candidate, diffs, trace_id=runtime.trace_id
                            )
                            if repaired is not None:
                                repair_counter += 1
                                queue.append(repaired)
                    continue

                with progress_bar(f"Benchmarking candidate {processed}...", transient=True):
                    bench = runtime.plugin.run_benchmarks(
                        function, runtime.test_config, test_files=perf_test_files, test_iteration=processed
                    )
                speedup = compute_speedup(baseline_bench, bench)
                ui_logger.info("Candidate %s passed — %.2fx speedup", candidate.candidate_id, speedup)

                all_correct[candidate.candidate_id] = True
                all_speedups[candidate.candidate_id] = speedup
                all_runtimes[candidate.candidate_id] = bench.total_time

                scored.append(
                    ScoredCandidate(
                        candidate=candidate,
                        test_results=test_results,
                        benchmark_results=bench,
                        speedup=speedup,
                        score=score_candidate(speedup),
                    )
                )

                # Phase 3: Try refinement for candidates with speedup (skip already-refined)
                if not runtime.is_cancelled():
                    refinement_eligible = sum(1 for s in scored if s.speedup > 0 and s.candidate.source != "refine")
                    if speedup > 0 and candidate.source != "refine" and refinement_eligible <= top_refinement:
                        refined = runtime.plugin.refine_candidate(
                            context, scored[-1], baseline_bench, trace_id=runtime.trace_id
                        )
                        if refined:
                            queue.extend(refined)

                # Phase 3: Try adaptive after enough candidates scored
                if (
                    not runtime.is_cancelled()
                    and len(scored) >= 3
                    and adaptive_counter < max_adaptive
                    and adaptive_threshold > 0
                ):
                    adaptive_counter += 1
                    adaptive = runtime.plugin.adaptive_optimize(context, scored, trace_id=runtime.trace_id)
                    if adaptive is not None:
                        queue.append(adaptive)
            finally:
                for path, original_content in file_snapshots.items():
                    path.write_text(original_content, encoding="utf-8")

        return scored, all_speedups, all_runtimes, all_correct

    def filter_candidates(
        self,
        function: FunctionToOptimize,
        context: CodeContext,
        candidates: list[Candidate],
        runtime: OptimizationRuntime,
    ) -> list[Candidate]:
        """Validate syntax, format, and deduplicate candidates."""
        normalized_original = runtime.plugin.normalize_code(context.target_code.strip())
        seen_normalized: set[str] = {normalized_original}
        queue: list[Candidate] = []
        for c in candidates:
            if not runtime.plugin.validate_candidate(c.code):
                ui_logger.info("Candidate %s has invalid syntax, skipping", c.candidate_id)
                continue
            c.code = runtime.plugin.format_code(c.code, function.file_path)
            norm = runtime.plugin.normalize_code(c.code.strip())
            if norm in seen_normalized:
                ui_logger.info("Candidate %s is identical/duplicate, skipping", c.candidate_id)
                continue
            seen_normalized.add(norm)
            queue.append(c)
        return queue

    def capture_file_snapshots(self, function: FunctionToOptimize, context: CodeContext) -> dict[Path, str]:
        """Snapshot target + helper files for reliable restoration after code replacement."""
        snapshots: dict[Path, str] = {}
        if function.file_path.exists():
            snapshots[function.file_path] = function.file_path.read_text("utf-8")
        for h in context.helper_functions:
            if h.file_path.exists() and h.file_path not in snapshots:
                snapshots[h.file_path] = h.file_path.read_text("utf-8")
        return snapshots

    def review_and_repair_tests(
        self, generated_tests: GeneratedTestSuite | None, context: CodeContext, runtime: OptimizationRuntime
    ) -> GeneratedTestSuite | None:
        """Review and repair generated tests up to MAX_TEST_REPAIR_CYCLES iterations, then drop still-failing files."""
        if not generated_tests or not generated_tests.test_files:
            return generated_tests

        # Snapshot test file contents before repair so we can revert on failure
        pre_repair_snapshots: dict[int, tuple[str, str, str]] = {}
        for i, tf in enumerate(generated_tests.test_files):
            pre_repair_snapshots[i] = (tf.original_test_source, tf.behavior_test_source, tf.perf_test_source)

        previous_repair_errors: dict[str, str] = {}
        coverage_data: CoverageData | None = None

        for iteration in range(MAX_TEST_REPAIR_CYCLES):
            if runtime.is_cancelled():
                return generated_tests
            behavior_files = generated_tests.behavior_test_paths

            # Collect coverage on first iteration for repair guidance
            if iteration == 0:
                test_results, cov_data = runtime.plugin.run_tests(
                    runtime.test_config, test_files=behavior_files, enable_coverage=True
                )
                if cov_data is not None:
                    coverage_data = cov_data
            else:
                test_results = runtime.plugin.run_tests(runtime.test_config, test_files=behavior_files)

            if test_results.passed:
                return generated_tests

            # Collect error messages from failing tests for next repair attempt
            failing_outcomes = [o for o in test_results.outcomes if o.status != TestOutcomeStatus.PASSED]
            for outcome in failing_outcomes:
                if outcome.error_message:
                    previous_repair_errors[outcome.test_id] = outcome.error_message

            reviews = runtime.plugin.review_generated_tests(generated_tests, context, test_results, runtime.trace_id)
            if not reviews or all(not r.functions_to_repair for r in reviews):
                restore_test_snapshots(generated_tests, pre_repair_snapshots)
                break

            repaired = runtime.plugin.repair_generated_tests(
                generated_tests,
                reviews,
                context,
                runtime.trace_id,
                previous_repair_errors=previous_repair_errors or None,
                coverage_data=coverage_data,
            )
            if repaired is None:
                restore_test_snapshots(generated_tests, pre_repair_snapshots)
                break
            generated_tests = repaired

        # Final pass: drop individual test files that still fail
        passing_files = []
        for tf in generated_tests.test_files:
            if runtime.is_cancelled():
                return generated_tests
            run_result = runtime.plugin.run_tests(runtime.test_config, test_files=[tf.behavior_test_path])
            if run_result.passed:
                passing_files.append(tf)
            else:
                ui_logger.info("Dropping failing test file: %s", tf.behavior_test_path)

        if not passing_files:
            return None

        return GeneratedTestSuite(test_files=passing_files)

    def select_best_with_ranking(
        self, scored: list[ScoredCandidate], context: CodeContext, runtime: OptimizationRuntime
    ) -> ScoredCandidate | None:
        """Select best candidate, optionally using AI ranking for multiple candidates."""
        if not scored:
            return None

        if len(scored) > 1:
            ranking = runtime.plugin.rank_candidates(scored, context, trace_id=runtime.trace_id)
            if ranking and ranking[0] < len(scored):
                return scored[ranking[0]]

        return select_best(scored)
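The repair gate in evaluate_candidates() is driven by the share of mismatched tests. With invented numbers: if 2 of 20 baseline outcomes differ,

diff_ratio = 2 / 20  # 0.1; a repair is attempted only if 0.1 <= repair_unmatched_limit

and only while fewer than MIN_CORRECT_CANDIDATES (2) candidates have already passed and the per-trace repair budget is not yet exhausted.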
@ -1,184 +0,0 @@
"""Shared types and utilities for optimization strategies."""

from __future__ import annotations

import contextlib
import json
import logging
import threading  # noqa: TC003 - used at runtime by OptimizationRuntime.is_cancelled()
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable

from codeflash_core.models import TestDiff

if TYPE_CHECKING:
    from collections.abc import Callable

    from codeflash_core.config import CoreConfig, TestConfig
    from codeflash_core.models import (
        Candidate,
        CodeContext,
        FunctionToOptimize,
        GeneratedTestSuite,
        OptimizationResult,
        ScoredCandidate,
        TestResults,
    )
    from codeflash_core.protocols import LanguagePlugin

logger = logging.getLogger(__name__)

MIN_CORRECT_CANDIDATES = 2


@dataclass
class StageSpec:
    """Describes a single stage in the optimization pipeline."""

    key: str
    description: str


@dataclass
class OptimizationRuntime:
    """Everything a strategy needs to orchestrate an optimization."""

    plugin: LanguagePlugin
    config: CoreConfig
    test_config: TestConfig
    cancel_event: threading.Event
    output_comparator: Callable[..., bool] | None
    trace_id: str

    def is_cancelled(self) -> bool:
        return self.cancel_event.is_set()


@runtime_checkable
class OptimizationStrategy(Protocol):
    """Protocol for optimization strategies.

    Implement this to define a custom optimization pipeline for a single function.
    The Optimizer handles discovery, indexing, ranking, and the per-function loop;
    the strategy controls everything within a single function's optimization.
    """

    def optimize_function(
        self, function: FunctionToOptimize, runtime: OptimizationRuntime
    ) -> OptimizationResult | None: ...


# ---------------------------------------------------------------------------
# Shared utilities
# ---------------------------------------------------------------------------


def cleanup_generated_tests(generated_tests: GeneratedTestSuite | None) -> None:
    """Remove generated test files from disk."""
    if not generated_tests:
        return
    for tf in generated_tests.test_files:
        for path in (tf.behavior_test_path, tf.perf_test_path):
            with contextlib.suppress(OSError):
                path.unlink(missing_ok=True)


def restore_test_snapshots(generated_tests: GeneratedTestSuite, snapshots: dict[int, tuple[str, str, str]]) -> None:
    """Restore test file contents from pre-repair snapshots."""
    for i, tf in enumerate(generated_tests.test_files):
        if i not in snapshots:
            continue
        orig_source, orig_behavior, orig_perf = snapshots[i]
        tf.original_test_source = orig_source
        tf.behavior_test_source = orig_behavior
        tf.perf_test_source = orig_perf
        with contextlib.suppress(OSError):
            tf.behavior_test_path.write_text(orig_behavior, encoding="utf-8")
        with contextlib.suppress(OSError):
            tf.perf_test_path.write_text(orig_perf, encoding="utf-8")


def compute_test_diffs(baseline: TestResults, candidate: TestResults) -> list[TestDiff]:
    """Compute test outcome differences between baseline and candidate runs."""
    diffs: list[TestDiff] = []
    baseline_by_id = {o.test_id: o for o in baseline.outcomes}
    candidate_by_id = {o.test_id: o for o in candidate.outcomes}

    for test_id, base_outcome in baseline_by_id.items():
        cand_outcome = candidate_by_id.get(test_id)
        if (
            cand_outcome is None
            or base_outcome.status != cand_outcome.status
            or base_outcome.output != cand_outcome.output
        ):
            diffs.append(
                TestDiff(
                    test_id=test_id,
                    baseline_output=base_outcome.output,
                    candidate_output=cand_outcome.output if cand_outcome else None,
                )
            )

    return diffs


def log_optimization_run(
    function: FunctionToOptimize,
    runtime: OptimizationRuntime,
    context: CodeContext | None = None,
    candidates: list[Candidate] | None = None,
    generated_tests: GeneratedTestSuite | None = None,
    scored: list[ScoredCandidate] | None = None,
    result: OptimizationResult | None = None,
    stage_reached: str = "",
    exit_reason: str = "",
) -> None:
    """Append a JSON record of the optimization run to codeflash_trace.log."""
    record: dict[str, Any] = {
        "trace_id": runtime.trace_id,
        "function": function.qualified_name,
        "file": str(function.file_path),
        "stage_reached": stage_reached,
        "exit_reason": exit_reason,
    }

    if context:
        record["context"] = {
            "target_code_lines": len(context.target_code.splitlines()),
            "read_only_context_lines": len(context.read_only_context.splitlines()) if context.read_only_context else 0,
            "imports": len(context.imports),
            "helpers": [{"name": h.qualified_name, "file": str(h.file_path)} for h in context.helper_functions],
        }

    record["candidates_count"] = len(candidates) if candidates else 0
    if candidates:
        record["candidates"] = [{"id": c.candidate_id, "source": c.source} for c in candidates]

    record["tests_count"] = len(generated_tests.test_files) if generated_tests and generated_tests.test_files else 0

    if scored:
        record["evaluation"] = [
            {
                "id": s.candidate.candidate_id,
                "source": s.candidate.source,
                "speedup": s.speedup,
                "score": s.score,
                "passed": s.test_results.passed,
            }
            for s in scored
        ]

    if result:
        record["result"] = {
            "speedup": result.speedup,
            "candidate_id": result.candidate.candidate_id,
            "explanation": result.explanation,
            "diff": result.diff,
        }

    log_path = runtime.config.project_root / "codeflash_trace.log"
    try:
        with log_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")
    except OSError:
        logger.debug("Failed to write optimization trace log", exc_info=True)
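Each call to log_optimization_run() appends one JSON object per line to codeflash_trace.log. Given the fields assembled above, a record could look roughly like this (every value here is invented for illustration):

{"trace_id": "abc123", "function": "mypkg.parsing.tokenize", "file": "/repo/mypkg/parsing.py", "stage_reached": "done", "exit_reason": "optimized", "candidates_count": 6, "tests_count": 2, "result": {"speedup": 1.42, "candidate_id": "c-04", "explanation": "...", "diff": "..."}}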
@ -1,4 +0,0 @@
from codeflash_core.telemetry.posthog_cf import PostHogClient
from codeflash_core.telemetry.sentry_cf import init_sentry

__all__ = ["PostHogClient", "init_sentry"]
@ -1,44 +0,0 @@
from __future__ import annotations

import contextlib
import logging
from typing import Any

logger = logging.getLogger(__name__)


class PostHogClient:
    """Wrapper around the PostHog analytics client."""

    instance: PostHogClient | None = None

    def __init__(self, api_key: str, enabled: bool = True) -> None:
        self.enabled = enabled and bool(api_key)
        self.ph: Any = None

        if self.enabled:
            try:
                from posthog import Posthog

                self.ph = Posthog(api_key, host="https://us.i.posthog.com")
            except Exception:
                logger.debug("Failed to initialize PostHog", exc_info=True)
                self.enabled = False

    @classmethod
    def initialize(cls, api_key: str, enabled: bool = True) -> PostHogClient:
        cls.instance = cls(api_key, enabled=enabled)
        return cls.instance

    def capture(self, distinct_id: str, event: str, properties: dict[str, Any] | None = None) -> None:
        if not self.enabled or self.ph is None:
            return
        try:
            self.ph.capture(distinct_id, event, properties=properties or {})
        except Exception:
            logger.debug("PostHog capture failed", exc_info=True)

    def shutdown(self) -> None:
        if self.ph is not None:
            with contextlib.suppress(Exception):
                self.ph.shutdown()
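Typical usage would look roughly like this; the API key, distinct id, and event name are placeholders, not values from the codebase:

client = PostHogClient.initialize("phc_example_key", enabled=True)
client.capture("user-123", "optimization_completed", {"speedup": 1.8})
client.shutdown()

Since __init__ swallows import and initialization failures, callers can use the client unconditionally: capture() silently becomes a no-op when PostHog is unavailable or disabled.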
@ -1,18 +0,0 @@
from __future__ import annotations

import logging

logger = logging.getLogger(__name__)


def init_sentry(dsn: str, enabled: bool = True) -> None:
    """Initialize Sentry error tracking."""
    if not enabled or not dsn:
        return

    try:
        import sentry_sdk

        sentry_sdk.init(dsn=dsn, traces_sample_rate=0.0)
    except Exception:
        logger.debug("Failed to initialize Sentry", exc_info=True)
@ -1,21 +0,0 @@
from __future__ import annotations

from codeflash_core.ui.console import (
    code_print,
    console,
    logger,
    paneled_text,
    progress_bar,
    setup_logging,
    test_files_progress_bar,
)

__all__ = [
    "code_print",
    "console",
    "logger",
    "paneled_text",
    "progress_bar",
    "setup_logging",
    "test_files_progress_bar",
]
@ -1,171 +0,0 @@
"""Rich-based console UI for codeflash_core.

Provides spinners, progress bars, panels, code display, and logging.
No LSP or subagent concerns — this is the core CLI output layer.
"""

from __future__ import annotations

import logging
from contextlib import contextmanager
from itertools import cycle
from typing import TYPE_CHECKING

from rich.console import Console
from rich.highlighter import NullHighlighter
from rich.logging import RichHandler
from rich.panel import Panel
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TaskID,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from rich.text import Text

if TYPE_CHECKING:
    from collections.abc import Generator

# ---------------------------------------------------------------------------
# Console and logging
# ---------------------------------------------------------------------------

console = Console(highlighter=NullHighlighter())
logger = logging.getLogger("codeflash")


def setup_logging(level: int = logging.INFO) -> None:
    """Configure root logger with Rich handler. Call from CLI, not at import time."""
    logging.basicConfig(
        level=level,
        handlers=[
            RichHandler(
                rich_tracebacks=True,
                markup=False,
                highlighter=NullHighlighter(),
                console=console,
                show_path=False,
                show_time=False,
            )
        ],
        format="%(message)s",
    )


# ---------------------------------------------------------------------------
# Spinners
# ---------------------------------------------------------------------------

SPINNER_TYPES = [
    "dots",
    "dots2",
    "dots3",
    "dots4",
    "dots5",
    "line",
    "line2",
    "arc",
    "circle",
    "star",
    "star2",
    "moon",
    "bouncingBar",
    "bouncingBall",
    "flip",
    "growVertical",
    "growHorizontal",
    "balloon",
    "noise",
    "bounce",
    "point",
    "layer",
    "betaWave",
]

spinner_cycle = cycle(SPINNER_TYPES)

# ---------------------------------------------------------------------------
# Dummy types for fallback
# ---------------------------------------------------------------------------


class DummyTask:
    def __init__(self) -> None:
        self.id: TaskID = TaskID(0)


class DummyProgress:
    def advance(self, task_id: TaskID, advance: int = 1) -> None:
        pass


# ---------------------------------------------------------------------------
# Progress bars
# ---------------------------------------------------------------------------

progress_bar_active = False


@contextmanager
def progress_bar(message: str, *, transient: bool = False) -> Generator[TaskID, None, None]:
    """Spinner with elapsed time. Avoids nesting Rich Live displays."""
    global progress_bar_active

    if progress_bar_active:
        yield DummyTask().id
        return

    progress_bar_active = True
    try:
        progress = Progress(
            SpinnerColumn(next(spinner_cycle)),
            *Progress.get_default_columns(),
            TimeElapsedColumn(),
            console=console,
            transient=transient,
        )
        task = progress.add_task(message, total=None)
        with progress:
            yield task
    finally:
        progress_bar_active = False


@contextmanager
def test_files_progress_bar(total: int, description: str) -> Generator[tuple[Progress, TaskID], None, None]:
    """Progress bar with M/N counter for test files."""
    with Progress(
        SpinnerColumn(next(spinner_cycle)),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(complete_style="cyan", finished_style="green", pulse_style="yellow"),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TimeRemainingColumn(),
        console=console,
        transient=True,
    ) as progress:
        task_id = progress.add_task(description, total=total)
        yield progress, task_id


# ---------------------------------------------------------------------------
# Display helpers
# ---------------------------------------------------------------------------


def paneled_text(text: str, *, title: str = "", border_style: str = "cyan") -> None:
    """Print text inside a bordered panel."""
    console.print(Panel(Text(text), title=title or None, border_style=border_style))


def code_print(code_str: str, *, language: str = "python") -> None:
    """Print code with syntax highlighting."""
    from rich.syntax import Syntax

    console.rule()
    console.print(Syntax(code_str, language, line_numbers=True, theme="github-dark"))
    console.rule()
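Typical usage of the two context managers, with the work inside the blocks being hypothetical placeholders:

with progress_bar("Running baseline tests...", transient=True):
    run_baseline_tests()  # placeholder for real work

with test_files_progress_bar(total=len(test_files), description="Instrumenting tests") as (progress, task_id):
    for test_file in test_files:  # placeholder iterable
        instrument(test_file)
        progress.advance(task_id)

The module-level progress_bar_active flag is what prevents nested Rich Live displays: an inner progress_bar() call just yields a DummyTask id and renders nothing.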
@ -1,54 +0,0 @@
from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from codeflash_core.models import TestOutcome, TestResults

# Type alias for a custom output comparator function.
OutputComparator = Callable[[object, object], bool]


def is_equivalent(baseline: TestResults, candidate: TestResults, comparator: OutputComparator | None = None) -> bool:
    """Check if candidate test results are equivalent to baseline.

    Both must pass, have the same number of outcomes, and each outcome
    must match on test_id, status, and output.

    If *comparator* is provided it is used instead of ``==`` to compare
    captured test outputs.
    """
    if not baseline.passed or not candidate.passed:
        return False

    if len(baseline.outcomes) != len(candidate.outcomes):
        return False

    baseline_by_id = {o.test_id: o for o in baseline.outcomes}
    candidate_by_id = {o.test_id: o for o in candidate.outcomes}

    if baseline_by_id.keys() != candidate_by_id.keys():
        return False

    return all(
        outcomes_match(baseline_by_id[test_id], candidate_by_id[test_id], comparator=comparator)
        for test_id in baseline_by_id
    )


def outcomes_match(baseline: TestOutcome, candidate: TestOutcome, comparator: OutputComparator | None = None) -> bool:
    if baseline.status != candidate.status:
        return False
    if baseline.output is not None and candidate.output is not None:
        return compare_outputs(baseline.output, candidate.output, comparator=comparator)
    return True


def compare_outputs(
    baseline_output: object, candidate_output: object, comparator: OutputComparator | None = None
) -> bool:
    """Compare return-value outputs using *comparator*, falling back to ``==``."""
    if comparator is not None:
        return comparator(baseline_output, candidate_output)
    return baseline_output == candidate_output
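A custom comparator plugs in through the OutputComparator hook. For example (illustrative only), a float-tolerant comparison that also treats two NaNs as equal:

import math


def float_tolerant(a: object, b: object) -> bool:
    if isinstance(a, float) and isinstance(b, float):
        return math.isclose(a, b, rel_tol=1e-9) or (math.isnan(a) and math.isnan(b))
    return a == b


# baseline_results / candidate_results are whatever TestResults the test runner produced.
equivalent = is_equivalent(baseline_results, candidate_results, comparator=float_tolerant)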