Merge pull request #1915 from codeflash-ai/cf-remove-codeflash-core

chore: remove src/codeflash_core package
Kevin Turcios 2026-03-27 04:21:20 -05:00 committed by GitHub
commit 1c5404f156
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 1 addition and 2905 deletions


@@ -104,7 +104,7 @@ tests = [
]
[tool.hatch.build.targets.sdist]
-include = ["codeflash", "src/codeflash_core"]
+include = ["codeflash"]
exclude = [
"docs/*",
"experiments/*",


@@ -1,32 +0,0 @@
from codeflash_core.config import CoreConfig, TestConfig
from codeflash_core.models import (
BenchmarkResults,
Candidate,
CodeContext,
FunctionToOptimize,
HelperFunction,
OptimizationResult,
ScoredCandidate,
TestOutcome,
TestOutcomeStatus,
TestResults,
)
from codeflash_core.optimizer import Optimizer
from codeflash_core.protocols import LanguagePlugin
__all__ = [
"BenchmarkResults",
"Candidate",
"CodeContext",
"CoreConfig",
"FunctionToOptimize",
"HelperFunction",
"LanguagePlugin",
"OptimizationResult",
"Optimizer",
"ScoredCandidate",
"TestConfig",
"TestOutcome",
"TestOutcomeStatus",
"TestResults",
]


@@ -1,3 +0,0 @@
from codeflash_core.ai.client import AIClient
__all__ = ["AIClient"]


@@ -1,54 +0,0 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
import requests
if TYPE_CHECKING:
from codeflash_core.config import AIConfig
from codeflash_core.models import Candidate, CodeContext
logger = logging.getLogger(__name__)
class AIClient:
"""Client for the Codeflash AI optimization service."""
def __init__(self, config: AIConfig) -> None:
self.base_url = config.base_url.rstrip("/")
self.api_key = config.api_key
self.timeout = config.timeout
self.session = requests.Session()
if self.api_key:
self.session.headers["Authorization"] = f"Bearer {self.api_key}"
def get_candidates(self, context: CodeContext) -> list[Candidate]:
"""Request optimization candidates from the AI service."""
from codeflash_core.models import Candidate
payload = {
"function_name": context.target_function.qualified_name,
"source_code": context.target_code,
"helper_functions": [
{"name": h.qualified_name, "source_code": h.source_code} for h in context.helper_functions
],
"read_only_context": context.read_only_context,
"imports": context.imports,
}
try:
resp = self.session.post(f"{self.base_url}/ai/optimize", json=payload, timeout=self.timeout)
resp.raise_for_status()
data = resp.json()
except requests.RequestException:
logger.exception("AI service request failed")
return []
candidates = []
for item in data.get("candidates", []):
candidates.append(Candidate(code=item["code"], explanation=item.get("explanation", "")))
return candidates
def close(self) -> None:
self.session.close()
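
For reference, the removed client was built from an AIConfig and exposed a single get_candidates call that returns an empty list on any request failure; a minimal usage sketch (the key, file path, and function source below are illustrative placeholders):

from pathlib import Path
from codeflash_core.ai.client import AIClient
from codeflash_core.config import AIConfig
from codeflash_core.models import CodeContext, FunctionToOptimize

fn = FunctionToOptimize(
    function_name="slow_fn",
    file_path=Path("src/app.py"),
    source_code="def slow_fn(xs):\n    return sorted(xs)",
)
context = CodeContext(target_function=fn, target_code=fn.source_code, target_file=fn.file_path)

config = AIConfig(api_key="YOUR_API_KEY")  # placeholder; base_url defaults to https://app.codeflash.ai
client = AIClient(config)
try:
    # Posts to {base_url}/ai/optimize and returns [] if the request fails.
    for candidate in client.get_candidates(context):
        print(candidate.candidate_id, candidate.explanation)
finally:
    client.close()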


@@ -1,201 +0,0 @@
"""CLI entry point for codeflash."""
from __future__ import annotations
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from codeflash_core.strategy_utils import OptimizationStrategy
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(prog="cfnext", description="Optimize Python code with AI.")
parser.add_argument("--version", action="store_true", help="Print the version and exit.")
parser.add_argument(
"--show-config", action="store_true", help="Show current or auto-detected configuration and exit."
)
sub = parser.add_subparsers(dest="command")
opt = sub.add_parser("optimize", help="Optimize functions in the given files.")
opt.add_argument("files", nargs="*", help="Python files to optimize.")
opt.add_argument("--file", dest="target_file", default=None, help="Single file to optimize.")
opt.add_argument("--function", dest="target_function", default=None, help="Specific function name to optimize.")
opt.add_argument("--all", action="store_true", dest="optimize_all", help="Optimize all files in module-root.")
opt.add_argument("--project-root", type=Path, default=None, help="Override project root directory.")
opt.add_argument("--module-root", default=None, help="Override module root (relative to project root).")
opt.add_argument("--tests-root", default=None, help="Override tests root (relative to project root).")
opt.add_argument("--benchmarks-root", default=None, help="Override benchmarks root (relative to project root).")
opt.add_argument("--api-key", default=None, help="Codeflash API key (default: $CODEFLASH_API_KEY).")
opt.add_argument("--effort", choices=["low", "medium", "high"], default=None, help="Effort level for optimization.")
opt.add_argument(
"--server",
choices=["local", "prod"],
default=None,
help="AI service server: 'local' for localhost:8000, 'prod' for app.codeflash.ai.",
)
opt.add_argument(
"--benchmark", action="store_true", help="Trace benchmark tests and calculate optimization impact."
)
opt.add_argument("--no-pr", action="store_true", help="Do not create a PR, only update code locally.")
opt.add_argument("-y", "--yes", action="store_true", help="Skip confirmation prompts.")
opt.add_argument(
"--formatter-cmds", nargs="*", default=None, help="Override formatter commands (each applied to $file)."
)
opt.add_argument("--pytest-cmd", default=None, help="Override the pytest command to use.")
opt.add_argument("--disable-telemetry", action="store_true", help="Disable telemetry.")
opt.add_argument(
"--strategy", choices=["default"], default="default", help="Optimization strategy to use (default: default)."
)
opt.add_argument("-v", "--verbose", action="store_true", help="Enable debug logging.")
return parser
def collect_files(config_project_root: Path, config_module_root: str) -> list[Path]:
"""Collect all .py files under the module root."""
module_dir = config_project_root / config_module_root if config_module_root else config_project_root
if not module_dir.is_dir():
return []
return sorted(
p for p in module_dir.rglob("*.py") if not p.name.startswith("test_") and not p.name.endswith("_test.py")
)
SERVER_URLS = {"local": "http://localhost:8000", "prod": "https://app.codeflash.ai"}
def main(argv: list[str] | None = None) -> int:
parser = build_parser()
args = parser.parse_args(argv)
# -- Top-level flags (no subcommand needed) ------------------------------
if args.version:
from importlib.metadata import version
print(f"cfnext {version('codeflash-core')}")
return 0
# -- Config (needed for --show-config and optimize) ----------------------
from codeflash_core.config import CoreConfig
start_dir = getattr(args, "project_root", None) or Path.cwd()
config = CoreConfig.find_and_load(start_dir)
if args.show_config:
import json
info = {
"project_root": str(config.project_root),
"module_root": config.module_root,
"tests_root": config.tests_root,
"benchmarks_root": config.benchmarks_root,
"effort": config.effort,
"formatter_cmds": config.formatter_cmds,
"ignore_paths": config.ignore_paths,
"ai_base_url": config.ai.base_url,
"disable_telemetry": config.disable_telemetry,
}
print(json.dumps(info, indent=2))
return 0
if args.command is None:
parser.print_help()
return 0
if args.command != "optimize":
parser.print_help()
return 1
# -- Logging -------------------------------------------------------------
from codeflash_core.ui import setup_logging
setup_logging(level=logging.DEBUG if args.verbose else logging.WARNING)
# CLI overrides
if args.project_root:
config.project_root = args.project_root.resolve()
if args.module_root is not None:
config.module_root = args.module_root
if args.tests_root is not None:
config.tests_root = args.tests_root
if args.benchmarks_root is not None:
config.benchmarks_root = args.benchmarks_root
if args.effort is not None:
config.effort = args.effort
if args.server is not None:
config.ai.base_url = SERVER_URLS[args.server]
if args.formatter_cmds is not None:
config.formatter_cmds = args.formatter_cmds
if args.disable_telemetry:
config.disable_telemetry = True
# API key: CLI flag > env var > config file
api_key = args.api_key or os.environ.get("CODEFLASH_API_KEY", "") or config.ai.api_key
if not api_key:
print("Error: No API key provided. Set CODEFLASH_API_KEY or pass --api-key.")
return 1
config.ai.api_key = api_key
# -- Telemetry -----------------------------------------------------------
from codeflash_core.telemetry import PostHogClient, init_sentry
if not config.disable_telemetry:
PostHogClient.initialize(config.telemetry.posthog_api_key, enabled=config.telemetry.enabled)
init_sentry(config.telemetry.sentry_dsn, enabled=config.telemetry.enabled)
# -- Resolve files -------------------------------------------------------
if args.target_file:
files = [Path(args.target_file).resolve()]
elif args.optimize_all:
files = collect_files(config.project_root, config.module_root)
elif args.files:
files = [Path(f).resolve() for f in args.files]
else:
print("Error: Provide files to optimize, use --file, or use --all.")
return 1
if not files:
print("Error: No Python files found.")
return 1
# -- Build plugin & validate environment ---------------------------------
from codeflash_core.optimizer import Optimizer
try:
from codeflash.plugin import PythonPlugin
except ImportError:
print("Error: codeflash package not installed. Install it to use the Python plugin.")
return 1
plugin = PythonPlugin(config.project_root)
if hasattr(plugin, "validate_environment") and not plugin.validate_environment(config):
return 1
# -- Resolve strategy ----------------------------------------------------
strategy = _resolve_strategy(args.strategy)
optimizer = Optimizer(config, plugin, strategy=strategy)
results = optimizer.run(files, function_filter=args.target_function)
# -- Shutdown telemetry --------------------------------------------------
if PostHogClient.instance is not None:
PostHogClient.instance.shutdown()
return 0 if results else 2
def _resolve_strategy(name: str) -> OptimizationStrategy:
"""Return the strategy instance for the given CLI name."""
from codeflash_core.strategy import DefaultStrategy
return DefaultStrategy()
if __name__ == "__main__":
sys.exit(main())
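
As a usage sketch, the removed entry point could be invoked as the cfnext console script or driven programmatically through main(); the module path, file, and function names below are assumed for illustration:

import sys

from codeflash_core.cli import main  # module path assumed; installed as the `cfnext` script

# Equivalent to: cfnext optimize --file src/app.py --function slow_fn --effort high --no-pr
# Requires an API key via --api-key, $CODEFLASH_API_KEY, or the config file.
exit_code = main(["optimize", "--file", "src/app.py", "--function", "slow_fn", "--effort", "high", "--no-pr"])
sys.exit(exit_code)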


@@ -1,239 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any
import tomlkit
class EffortLevel(str, Enum):
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
class EffortKeys(str, Enum):
N_OPTIMIZER_CANDIDATES = "N_OPTIMIZER_CANDIDATES"
N_OPTIMIZER_LP_CANDIDATES = "N_OPTIMIZER_LP_CANDIDATES"
N_GENERATED_TESTS = "N_GENERATED_TESTS"
MAX_CODE_REPAIRS_PER_TRACE = "MAX_CODE_REPAIRS_PER_TRACE"
REPAIR_UNMATCHED_PERCENTAGE_LIMIT = "REPAIR_UNMATCHED_PERCENTAGE_LIMIT"
TOP_VALID_CANDIDATES_FOR_REFINEMENT = "TOP_VALID_CANDIDATES_FOR_REFINEMENT"
ADAPTIVE_OPTIMIZATION_THRESHOLD = "ADAPTIVE_OPTIMIZATION_THRESHOLD"
MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE = "MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE"
HIGH_EFFORT_TOP_N = 15
MAX_TEST_REPAIR_CYCLES = 2
EFFORT_VALUES: dict[str, dict[EffortLevel, Any]] = {
EffortKeys.N_OPTIMIZER_CANDIDATES.value: {EffortLevel.LOW: 3, EffortLevel.MEDIUM: 5, EffortLevel.HIGH: 6},
EffortKeys.N_OPTIMIZER_LP_CANDIDATES.value: {EffortLevel.LOW: 4, EffortLevel.MEDIUM: 6, EffortLevel.HIGH: 7},
EffortKeys.N_GENERATED_TESTS.value: {EffortLevel.LOW: 2, EffortLevel.MEDIUM: 2, EffortLevel.HIGH: 2},
EffortKeys.MAX_CODE_REPAIRS_PER_TRACE.value: {EffortLevel.LOW: 2, EffortLevel.MEDIUM: 3, EffortLevel.HIGH: 5},
EffortKeys.REPAIR_UNMATCHED_PERCENTAGE_LIMIT.value: {
EffortLevel.LOW: 0.2,
EffortLevel.MEDIUM: 0.3,
EffortLevel.HIGH: 0.4,
},
EffortKeys.TOP_VALID_CANDIDATES_FOR_REFINEMENT.value: {
EffortLevel.LOW: 2,
EffortLevel.MEDIUM: 3,
EffortLevel.HIGH: 4,
},
EffortKeys.ADAPTIVE_OPTIMIZATION_THRESHOLD.value: {EffortLevel.LOW: 0, EffortLevel.MEDIUM: 0, EffortLevel.HIGH: 2},
EffortKeys.MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE.value: {
EffortLevel.LOW: 0,
EffortLevel.MEDIUM: 0,
EffortLevel.HIGH: 4,
},
}
def get_effort_value(key: EffortKeys, effort: EffortLevel | str) -> Any:
"""Look up an effort-dependent parameter value."""
if isinstance(effort, str):
effort = EffortLevel(effort)
return EFFORT_VALUES[key.value][effort]
@dataclass
class TestConfig:
tests_root: Path
project_root: Path
test_command: str = ""
timeout: float = 60.0
tests_project_rootdir: Path | None = None
@dataclass
class TelemetryConfig:
enabled: bool = True
posthog_api_key: str = ""
sentry_dsn: str = ""
@dataclass
class AIConfig:
base_url: str = "https://app.codeflash.ai"
api_key: str = ""
timeout: float = 120.0
CONFIG_FILE_NAMES = ("pyproject.toml", "codeflash.toml")
GLOB_PATTERN_CHARS = frozenset("*?[")
def is_glob_pattern(path_str: str) -> bool:
"""Check if a path string contains glob pattern characters."""
return any(char in path_str for char in GLOB_PATTERN_CHARS)
def normalize_ignore_paths(paths: list[str], base_path: Path | None = None) -> list[Path]:
if base_path is None:
base_path = Path.cwd()
base_path = base_path.resolve()
normalized: set[Path] = set()
for path_str in paths:
if not path_str:
continue
path_str = str(path_str)
if is_glob_pattern(path_str):
path_str = path_str.removeprefix("./")
if path_str.startswith("/"):
path_str = path_str.lstrip("/")
for matched_path in base_path.glob(path_str):
normalized.add(matched_path.resolve())
else:
path_obj = Path(path_str)
if not path_obj.is_absolute():
path_obj = base_path / path_obj
if path_obj.exists():
normalized.add(path_obj.resolve())
return list(normalized)
@dataclass
class CoreConfig:
project_root: Path = field(default_factory=Path.cwd)
module_root: str = ""
tests_root: str = "tests"
benchmarks_root: str = ""
ignore_paths: list[str] = field(default_factory=list)
formatter_cmds: list[str] = field(default_factory=list)
disable_telemetry: bool = False
effort: str = "medium"
create_pr: bool = False
pytest_cmd: str = "pytest"
disable_imports_sorting: bool = False
git_remote: str = "origin"
override_fixtures: bool = False
ai: AIConfig = field(default_factory=AIConfig)
telemetry: TelemetryConfig = field(default_factory=TelemetryConfig)
@property
def resolved_ignore_paths(self) -> list[Path]:
base = (self.project_root / self.module_root) if self.module_root else self.project_root
return normalize_ignore_paths(self.ignore_paths, base_path=base)
@classmethod
def from_toml(cls, path: Path) -> CoreConfig:
"""Load config from a toml file containing a [tool.codeflash] section.
Works with both pyproject.toml and codeflash.toml.
"""
with path.open(encoding="utf-8") as f:
data = tomlkit.load(f)
cf: dict[str, Any] = data.get("tool", {}).get("codeflash", {})
if not cf:
return cls(project_root=path.parent)
config = cls(
project_root=path.parent,
module_root=cf.get("module-root", ""),
tests_root=cf.get("tests-root", "tests"),
benchmarks_root=cf.get("benchmarks-root", ""),
ignore_paths=cf.get("ignore-paths", []),
formatter_cmds=cf.get("formatter-cmds", []),
disable_telemetry=cf.get("disable_telemetry", False),
effort=cf.get("effort", "medium"),
create_pr=cf.get("create-pr", False),
pytest_cmd=cf.get("pytest-cmd", "pytest"),
disable_imports_sorting=cf.get("disable-imports-sorting", False),
git_remote=cf.get("git-remote", "origin"),
override_fixtures=cf.get("override-fixtures", False),
)
if config.module_root and not (config.project_root / config.module_root).exists():
msg = f"module-root '{config.module_root}' does not exist under {config.project_root}"
raise FileNotFoundError(msg)
return config
@classmethod
def find_and_load(cls, start: Path | None = None) -> CoreConfig:
"""Walk up from start (default cwd) looking for a config file.
Searches for pyproject.toml and codeflash.toml in each directory
up to the filesystem root. Returns a default config if nothing is found.
"""
path = cls.find_config_file(start or Path.cwd())
if path is None:
return cls(project_root=start or Path.cwd())
return cls.from_toml(path)
@staticmethod
def find_config_file(start: Path) -> Path | None:
"""Walk up directories looking for a config file with a [tool.codeflash] section."""
current = start.resolve()
while True:
for name in CONFIG_FILE_NAMES:
candidate = current / name
if candidate.is_file():
try:
with candidate.open(encoding="utf-8") as f:
data = tomlkit.load(f)
if data.get("tool", {}).get("codeflash"):
return candidate
except Exception:
continue
parent = current.parent
if parent == current:
break
current = parent
return None
def resolve_test_config(self) -> TestConfig:
tests_root = self.project_root / self.tests_root
if not tests_root.is_dir():
msg = f"tests-root '{self.tests_root}' does not exist under {self.project_root}"
raise FileNotFoundError(msg)
# Compute tests_project_rootdir by walking up from tests_root to find pyproject.toml
# This is the pytest rootdir used for resolving test module paths
tests_project_rootdir = self.find_tests_project_rootdir(tests_root)
return TestConfig(
tests_root=tests_root,
project_root=self.project_root,
tests_project_rootdir=tests_project_rootdir,
test_command=self.pytest_cmd,
)
def find_tests_project_rootdir(self, tests_root: Path) -> Path:
"""Walk up from tests_root looking for a directory containing pyproject.toml or codeflash.toml."""
current = tests_root.resolve()
while current != current.parent:
for name in CONFIG_FILE_NAMES:
if (current / name).is_file():
return current
current = current.parent
return self.project_root
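
A short sketch of how the removed configuration layer was driven: find_and_load walks up from a starting directory looking for a [tool.codeflash] table, and get_effort_value resolves the effort-dependent knobs defined in EFFORT_VALUES (the printed values depend on the project, so treat them as illustrative):

from pathlib import Path

from codeflash_core.config import CoreConfig, EffortKeys, get_effort_value

config = CoreConfig.find_and_load(Path.cwd())  # falls back to defaults when no config file is found
print(config.project_root, config.module_root, config.effort)

# Effort-dependent parameter lookup, e.g. 5 candidates at "medium" and 6 at "high".
n_candidates = get_effort_value(EffortKeys.N_OPTIMIZER_CANDIDATES, config.effort)
print(n_candidates)

# Raises FileNotFoundError if tests-root does not exist under the project root.
test_config = config.resolve_test_config()
print(test_config.tests_root, test_config.test_command)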


@@ -1,21 +0,0 @@
from codeflash_core.danom.new_type import new_type
from codeflash_core.danom.result import Err, Ok, Result
from codeflash_core.danom.safe import safe, safe_method
from codeflash_core.danom.stream import Stream
from codeflash_core.danom.utils import all_of, any_of, compose, identity, invert, none_of
__all__ = [
"Err",
"Ok",
"Result",
"Stream",
"all_of",
"any_of",
"compose",
"identity",
"invert",
"new_type",
"none_of",
"safe",
"safe_method",
]


@@ -1,97 +0,0 @@
from __future__ import annotations
import inspect
from collections.abc import Sequence
from functools import wraps
from typing import TYPE_CHECKING, TypeVar
import attrs
if TYPE_CHECKING:
from collections.abc import Callable
from typing_extensions import ParamSpec, Self
P = ParamSpec("P")
C = TypeVar("C", bound=Callable[P, object])
T = TypeVar("T")
def new_type(
name: str,
base_type: type,
validators: Callable | Sequence[Callable] | None = None,
converters: Callable | Sequence[Callable] | None = None,
*,
frozen: bool = True,
) -> type:
kwargs = _callables_to_kwargs(base_type, validators, converters)
@attrs.define(frozen=frozen, eq=True, hash=frozen)
class _Wrapper:
inner: T = attrs.field(**kwargs) # type: ignore[no-matching-overload]
def map(self, func: Callable[[T], T]) -> Self:
return self.__class__(func(self.inner)) # type: ignore[invalid-argument-type]
locals().update(_create_forward_methods(base_type))
_Wrapper.__name__ = name
_Wrapper.__qualname__ = name
return _Wrapper
def _create_forward_methods(base_type: type) -> dict[str, Callable]:
methods: dict[str, Callable] = {}
for attr_name, _ in inspect.getmembers(base_type, inspect.isroutine):
if attr_name.startswith("_"):
continue
def make_forwarder(name: str) -> Callable:
def method(self, *args: tuple, **kwargs: dict) -> T:
return getattr(self.inner, name)(*args, **kwargs)
method.__name__ = name
method.__doc__ = getattr(base_type, name).__doc__
return method
methods[attr_name] = make_forwarder(attr_name)
return methods
def _callables_to_kwargs(
base_type: type, validators: Callable | Sequence[Callable] | None, converters: Callable | Sequence[Callable] | None
) -> dict[str, Sequence[Callable]]:
kwargs = {"validator": [attrs.validators.instance_of(base_type)], "converter": []}
kwargs["validator"] += [_validate_bool_func(fn) for fn in _to_list(validators)]
kwargs["converter"] += _to_list(converters)
return {k: v for k, v in kwargs.items() if v}
def _validate_bool_func(bool_fn: Callable[[T], bool]) -> Callable[[attrs.AttrsInstance, attrs.Attribute, T], None]:
if not callable(bool_fn):
raise TypeError("provided boolean function must be callable")
@wraps(bool_fn)
def wrapper(_instance: attrs.AttrsInstance, attribute: attrs.Attribute, value: T) -> None:
if not bool_fn(value):
msg = f"{attribute.name} does not return True for the given boolean function, received `{value}`."
raise ValueError(msg)
return wrapper
def _to_list(value: C | Sequence[C] | None) -> list[C]:
if value is None:
return []
if callable(value):
return [value] # type: ignore[invalid-return-type]
if isinstance(value, Sequence) and not all(callable(fn) for fn in value):
msg = f"Given items are not all callable: {value = }"
raise TypeError(msg)
return list(value)
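
As a usage sketch, new_type builds a small frozen attrs wrapper whose value must satisfy the given validators; the UserId type below is illustrative:

from codeflash_core.danom.new_type import new_type

# Frozen wrapper around int that must be positive; equality is value-based.
UserId = new_type("UserId", int, validators=lambda x: x > 0)

uid = UserId(42)
print(uid.inner)                        # 42
print(uid.map(lambda n: n + 1).inner)   # 43; map returns a new UserId
print(UserId(42) == UserId(42))         # True
# UserId(-1) raises ValueError (boolean validator); UserId("42") raises TypeError (instance_of check).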


@@ -1,157 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
try:
from typing import Never # type: ignore[unresolved-import]
except ImportError:
from typing import NoReturn as Never
import attrs
from attrs.validators import instance_of
T_co = TypeVar("T_co", covariant=True)
U_co = TypeVar("U_co", covariant=True)
E_co = TypeVar("E_co", bound=object, covariant=True)
F_co = TypeVar("F_co", bound=object, covariant=True)
if TYPE_CHECKING:
from collections.abc import Callable
from types import TracebackType
from typing_extensions import Concatenate, ParamSpec, Self
P = ParamSpec("P")
Mappable = Callable[Concatenate[T_co, P], U_co]
Bindable = Callable[Concatenate[T_co, P], "Result[U_co, E_co]"]
@attrs.define(frozen=True)
class Result(ABC, Generic[T_co, E_co]):
"""`Result` monad. Consists of `Ok` and `Err` for successful and failed operations respectively.
Each monad is a frozen instance to prevent further mutation.
"""
@classmethod
def unit(cls, inner: T_co) -> Ok[T_co]:
return Ok(inner)
@abstractmethod
def is_ok(self) -> bool: ...
@abstractmethod
def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: ...
@abstractmethod
def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: ...
@abstractmethod
def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: ...
@abstractmethod
def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: ...
@abstractmethod
def unwrap(self) -> T_co: ...
@staticmethod
def result_is_ok(result: Result[T_co, E_co]) -> bool:
return result.is_ok()
@staticmethod
def result_unwrap(result: Result[T_co, E_co]) -> T_co:
return result.unwrap()
@attrs.define(frozen=True, hash=True)
class Ok(Result[T_co, Never]):
inner: Any = attrs.field(default=None)
def is_ok(self) -> Literal[True]:
return True
def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Ok[U_co]:
return Ok(func(self.inner, *args, **kwargs))
def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Self:
return self
def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]:
return func(self.inner, *args, **kwargs)
def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Self:
return self
def unwrap(self) -> T_co:
return self.inner
SafeArgs = tuple[tuple[Any, ...], dict[str, Any]]
SafeMethodArgs = tuple[object, tuple[Any, ...], dict[str, Any]]
@attrs.define(frozen=True)
class Err(Result[Never, E_co]):
error: Any = attrs.field(default=None)
input_args: tuple[()] | SafeArgs | SafeMethodArgs = attrs.field(
default=(), validator=instance_of(tuple), repr=False
)
traceback: str = attrs.field(default="", validator=instance_of(str))
details: list[dict[str, Any]] = attrs.field(factory=list, init=False, repr=False)
def __attrs_post_init__(self) -> None:
if isinstance(self.error, Exception):
object.__setattr__(self, "details", self._extract_details(self.error.__traceback__))
def _extract_details(self, tb: TracebackType | None) -> list[dict[str, Any]]:
trace_info = []
while tb:
frame = tb.tb_frame
trace_info.append(
{
"file": frame.f_code.co_filename,
"func": frame.f_code.co_name,
"line_no": tb.tb_lineno,
"locals": frame.f_locals,
}
)
tb = tb.tb_next
return trace_info
def is_ok(self) -> Literal[False]:
return False
def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Self:
return self
def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Err[F_co]:
return Err(func(self.error, *args, **kwargs))
def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Self:
return self
def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]:
return func(self.error, *args, **kwargs)
def unwrap(self) -> T_co:
if isinstance(self.error, Exception):
raise self.error
msg = f"Err does not have a caught error to raise: {self.error = }"
raise ValueError(msg)
def __eq__(self, other: object) -> bool:
if not isinstance(other, Err):
return False
return all(
(
type(self.error) is type(other.error),
str(self.error) == str(other.error),
self.input_args == other.input_args,
)
)
def __hash__(self) -> int:
return hash(f"{type(self.error)}{self.error}{self.input_args}")
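
A brief usage sketch of the removed Result monad: Ok threads a value through map/and_then, while Err short-circuits them and re-raises on unwrap; the divide helper is illustrative:

from codeflash_core.danom.result import Err, Ok, Result

def divide(a: float, b: float) -> Result[float, Exception]:
    # A "bindable" helper: Ok on success, Err on failure.
    if b == 0:
        return Err(error=ZeroDivisionError("division by zero"))
    return Ok(a / b)

res = Ok(10).map(lambda x: x * 2).and_then(divide, 4)  # Ok(5.0)
print(res.is_ok(), res.unwrap())                       # True 5.0

bad = Ok(10).and_then(divide, 0)                       # Err carrying the ZeroDivisionError
print(bad.is_ok())                                     # False
recovered = bad.or_else(lambda _err: Ok(0.0))          # recover with a fallback value
print(recovered.unwrap())                              # 0.0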


@@ -1,42 +0,0 @@
from __future__ import annotations
import functools
import traceback
from typing import TYPE_CHECKING
from codeflash_core.danom.result import Err, Ok
if TYPE_CHECKING:
from collections.abc import Callable
from typing import TypeVar
from typing_extensions import Concatenate, ParamSpec
from codeflash_core.danom.result import Result
T = TypeVar("T")
P = ParamSpec("P")
U = TypeVar("U")
E = TypeVar("E")
def safe(func: Callable[P, U]) -> Callable[P, Result[U, Exception]]:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[U, Exception]:
try:
return Ok(func(*args, **kwargs))
except Exception as e:
return Err(error=e, input_args=(args, kwargs), traceback=traceback.format_exc())
return wrapper
def safe_method(func: Callable[Concatenate[T, P], U]) -> Callable[Concatenate[T, P], Result[U, Exception]]:
@functools.wraps(func)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> Result[U, Exception]:
try:
return Ok(func(self, *args, **kwargs))
except Exception as e:
return Err(error=e, input_args=(self, args, kwargs), traceback=traceback.format_exc())
return wrapper
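
A usage sketch for the removed decorators: safe (and safe_method for bound methods) converts a raised exception into an Err that records the arguments and traceback; parse_port is illustrative:

from codeflash_core.danom.safe import safe

@safe
def parse_port(value: str) -> int:
    return int(value)

ok = parse_port("8080")          # Ok(8080)
err = parse_port("not-a-port")   # Err wrapping the ValueError, with input args and traceback captured
print(ok.is_ok(), ok.unwrap())   # True 8080
print(err.is_ok())               # False
# err.unwrap() would re-raise the original ValueError.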


@@ -1,204 +0,0 @@
from __future__ import annotations
import asyncio
import itertools
import os
from abc import ABC, abstractmethod
from collections.abc import Iterable
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from copy import deepcopy
from enum import Enum, auto
from functools import reduce
from typing import TYPE_CHECKING, TypeVar, Union, cast
try:
from itertools import batched # type: ignore[attr-defined]
except ImportError:
from itertools import islice
def batched(iterable, n, *, strict=False): # noqa: ANN201
if n < 1:
raise ValueError("n must be at least one")
iterator = iter(iterable)
while batch := tuple(islice(iterator, n)):
if strict and len(batch) != n:
raise ValueError("batched(): incomplete batch")
yield batch
import attrs
T = TypeVar("T")
U = TypeVar("U")
E = TypeVar("E")
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Generator
MapFn = Callable[[T], U]
FilterFn = Callable[[T], bool]
TapFn = Callable[[T], None]
AsyncMapFn = Callable[[T], Awaitable[U]]
AsyncFilterFn = Callable[[T], Awaitable[bool]]
AsyncTapFn = Callable[[T], Awaitable[None]]
StreamFn = Union[MapFn, FilterFn, TapFn]
AsyncStreamFn = Union[AsyncMapFn, AsyncFilterFn, AsyncTapFn]
PlannedOps = tuple[str, StreamFn]
AsyncPlannedOps = tuple[str, AsyncStreamFn]
@attrs.define(frozen=True)
class _BaseStream(ABC):
seq: tuple = attrs.field(validator=attrs.validators.instance_of(tuple))
ops: tuple = attrs.field(default=(), validator=attrs.validators.instance_of(tuple), repr=False)
@classmethod
@abstractmethod
def from_iterable(cls, it: Iterable) -> _BaseStream[T]: ...
@abstractmethod
def map(self, *fns: MapFn | AsyncMapFn) -> _BaseStream[T]: ...
@abstractmethod
def filter(self, *fns: FilterFn | AsyncFilterFn) -> _BaseStream[T]: ...
@abstractmethod
def tap(self, *fns: TapFn | AsyncTapFn) -> _BaseStream[T]: ...
@abstractmethod
def partition(self, fn: FilterFn) -> tuple[_BaseStream[T], _BaseStream[U]]: ...
@abstractmethod
def fold(self, initial: T, fn: Callable[[T, U], T], *, workers: int = 1, use_threads: bool = False) -> T: ...
@abstractmethod
def collect(self) -> tuple[U, ...]: ...
@abstractmethod
def par_collect(self, workers: int = 4, *, use_threads: bool = False) -> tuple[U, ...]: ...
@abstractmethod
async def async_collect(self) -> Awaitable[tuple[U, ...]]: ...
def __bool__(self) -> bool:
return bool(self.seq)
@attrs.define(frozen=True)
class Stream(_BaseStream):
@classmethod
def from_iterable(cls, it: Iterable) -> Stream[T]:
if not isinstance(it, Iterable):
it = [it]
return cls(seq=tuple(it))
def map(self, *fns: MapFn | AsyncMapFn) -> Stream[U]:
plan = (*self.ops, *tuple((_MAP, fn) for fn in fns))
object.__setattr__(self, "ops", plan)
return self
def filter(self, *fns: FilterFn | AsyncFilterFn) -> Stream[T]:
plan = (*self.ops, *tuple((_FILTER, fn) for fn in fns))
object.__setattr__(self, "ops", plan)
return self
def tap(self, *fns: TapFn | AsyncTapFn) -> Stream[T]:
plan = (*self.ops, *tuple((_TAP, fn) for fn in fns))
object.__setattr__(self, "ops", plan)
return self
def partition(self, fn: FilterFn, *, workers: int = 1, use_threads: bool = False) -> tuple[Stream[T], Stream[U]]:
if workers > 1:
seq_tuple = self.par_collect(workers=workers, use_threads=use_threads)
else:
seq_tuple = self.collect()
return (Stream(seq=tuple(x for x in seq_tuple if fn(x))), Stream(seq=tuple(x for x in seq_tuple if not fn(x))))
def fold(self, initial: T, fn: Callable[[T, U], T], *, workers: int = 1, use_threads: bool = False) -> T:
if workers > 1:
return reduce(fn, self.par_collect(workers=workers, use_threads=use_threads), initial)
return reduce(fn, self.collect(), initial)
def collect(self) -> tuple[U, ...]:
return tuple(_apply_fns(self.seq, self.ops))
def par_collect(self, workers: int = 4, *, use_threads: bool = False) -> tuple[U, ...]:
if workers == -1:
workers = (os.cpu_count() or 5) - 1
executor_cls = ThreadPoolExecutor if use_threads else ProcessPoolExecutor
batches = [(list(chunk), self.ops) for chunk in batched(self.seq, n=max(4, len(self.seq) // workers))]
with executor_cls(max_workers=workers) as ex:
return cast("tuple[U, ...]", tuple(itertools.chain.from_iterable(ex.map(_apply_fns_worker, batches))))
async def async_collect(self) -> Awaitable[tuple[U, ...]]:
if not self.ops:
return cast("Awaitable[tuple[U, ...]]", self.collect())
res = await asyncio.gather(*(_async_apply_fns(x, self.ops) for x in self.seq))
return cast("Awaitable[tuple[U, ...]]", tuple(elem for elem in res if elem != _Nothing.NOTHING))
_MAP = 0
_FILTER = 1
_TAP = 2
class _Nothing(Enum):
NOTHING = auto()
def _apply_fns_worker(args: tuple[tuple[T], tuple[PlannedOps, ...]]) -> tuple[T]:
seq, ops = args
return _par_apply_fns(seq, ops)
def _apply_fns(elements: tuple[T], ops: tuple[PlannedOps, ...]) -> Generator[T, None, None]:
for elem in elements:
valid = True
res = elem
for op, op_fn in ops:
if op == _MAP:
res = op_fn(res)
elif op == _FILTER and not op_fn(res):
valid = False
break
elif op == _TAP:
op_fn(deepcopy(res))
if valid:
yield res
def _par_apply_fns(elements: tuple[T], ops: tuple[PlannedOps, ...]) -> tuple[T]:
results = []
for elem in elements:
valid = True
res = elem
for op, op_fn in ops:
if op == _MAP:
res = op_fn(res)
elif op == _FILTER and not op_fn(res):
valid = False
break
elif op == _TAP:
op_fn(deepcopy(res))
if valid:
results.append(res)
return tuple(results)
async def _async_apply_fns(elem: T, ops: tuple[AsyncPlannedOps, ...]) -> T | _Nothing:
res = elem
for op, op_fn in ops:
if op == _MAP:
res = await op_fn(res)
elif op == _FILTER and not await op_fn(res):
return _Nothing.NOTHING
elif op == _TAP:
await op_fn(deepcopy(res))
return res
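
A usage sketch of the removed Stream helper: map/filter/tap only record a plan, which runs when collect(), par_collect(), or async_collect() is called; the numbers are illustrative:

from codeflash_core.danom.stream import Stream

squares_of_evens = (
    Stream.from_iterable(range(10))
    .filter(lambda x: x % 2 == 0)
    .map(lambda x: x * x)
    .collect()  # the planned ops are applied here
)
print(squares_of_evens)  # (0, 4, 16, 36, 64)

evens, odds = Stream.from_iterable(range(10)).partition(lambda x: x % 2 == 0)
total = Stream.from_iterable(range(10)).fold(0, lambda acc, x: acc + x)  # 45
# par_collect(workers=4) runs the same plan across processes (or threads with use_threads=True).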


@@ -1,68 +0,0 @@
from __future__ import annotations
from functools import reduce
from operator import not_
from typing import TYPE_CHECKING, TypeVar
import attrs
T_co = TypeVar("T_co", covariant=True)
U_co = TypeVar("U_co", covariant=True)
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
Composable = Callable[[T_co], T_co | U_co]
Filterable = Callable[[T_co], bool]
@attrs.define(frozen=True, hash=True, eq=True)
class _Compose:
fns: Sequence[Composable]
def __call__(self, initial: T_co) -> T_co | U_co:
return reduce(_apply, self.fns, initial)
def _apply(value: T_co, fn: Composable) -> T_co | U_co:
return fn(value)
def compose(*fns: Composable) -> Composable:
return _Compose(fns)
@attrs.define(frozen=True, hash=True, eq=True)
class _AllOf:
fns: Sequence[Filterable]
def __call__(self, item: T_co) -> bool:
return all(fn(item) for fn in self.fns)
def all_of(*fns: Filterable) -> Filterable:
return _AllOf(fns)
@attrs.define(frozen=True, hash=True, eq=True)
class _AnyOf:
fns: Sequence[Filterable]
def __call__(self, item: T_co) -> bool:
return any(fn(item) for fn in self.fns)
def any_of(*fns: Filterable) -> Filterable:
return _AnyOf(fns)
def none_of(*fns: Filterable) -> Filterable:
return compose(_AnyOf(fns), not_)
def identity(x: T_co) -> T_co:
return x
def invert(func: Filterable) -> Filterable:
return compose(func, not_)
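
A usage sketch of the removed functional helpers, which wrap predicates and pipelines in small frozen callables; the predicates are illustrative:

from codeflash_core.danom.utils import all_of, compose, invert, none_of

is_positive = lambda x: x > 0
is_even = lambda x: x % 2 == 0

positive_even = all_of(is_positive, is_even)
print(positive_even(4), positive_even(3))  # True False

neither = none_of(is_positive, is_even)    # same as invert(any_of(is_positive, is_even))
print(neither(-3), neither(4))             # True False

add_one_then_double = compose(lambda x: x + 1, lambda x: x * 2)
print(add_one_then_double(3))              # 8; functions are applied left to right

print(invert(is_even)(3))                  # True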


@@ -1,18 +0,0 @@
from __future__ import annotations
import difflib
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from pathlib import Path
def unified_diff(original: str, optimized: str, file_path: Path, context_lines: int = 3) -> str:
"""Generate a unified diff between original and optimized code."""
original_lines = original.splitlines(keepends=True)
optimized_lines = optimized.splitlines(keepends=True)
diff = difflib.unified_diff(
original_lines, optimized_lines, fromfile=f"a/{file_path}", tofile=f"b/{file_path}", n=context_lines
)
return "".join(diff)


@@ -1,194 +0,0 @@
from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from pathlib import Path
class TestOutcomeStatus(Enum):
PASSED = "passed"
FAILED = "failed"
ERROR = "error"
SKIPPED = "skipped"
@dataclass(frozen=True)
class FunctionParent:
name: str
type: str = "ClassDef"
def __str__(self) -> str:
return f"{self.type}:{self.name}"
@dataclass
class FunctionToOptimize:
function_name: str
file_path: Path
parents: list[FunctionParent] = field(default_factory=list)
starting_line: int | None = None
ending_line: int | None = None
starting_col: int | None = None
ending_col: int | None = None
is_async: bool = False
is_method: bool = False
language: str = ""
doc_start_line: int | None = None
source_code: str = ""
@property
def qualified_name(self) -> str:
if not self.parents:
return self.function_name
parent_path = ".".join(parent.name for parent in self.parents)
return f"{parent_path}.{self.function_name}"
@property
def top_level_parent_name(self) -> str:
return self.function_name if not self.parents else self.parents[0].name
@property
def class_name(self) -> str | None:
for parent in reversed(self.parents):
if parent.type == "ClassDef":
return parent.name
return None
@dataclass
class HelperFunction:
name: str
qualified_name: str
file_path: Path
source_code: str
start_line: int
end_line: int
@dataclass
class CodeContext:
target_function: FunctionToOptimize
target_code: str
target_file: Path
helper_functions: list[HelperFunction] = field(default_factory=list)
read_only_context: str = ""
imports: list[str] = field(default_factory=list)
@dataclass
class Candidate:
code: str
explanation: str
candidate_id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
source: str = ""
parent_id: str = ""
code_markdown: str = ""
@dataclass
class TestOutcome:
test_id: str
status: TestOutcomeStatus
output: Any = None
duration: float = 0.0
error_message: str = ""
@dataclass
class TestResults:
passed: bool
outcomes: list[TestOutcome] = field(default_factory=list)
error: str | None = None
@dataclass
class BenchmarkResults:
timings: dict[str, float] = field(default_factory=dict)
total_time: float = 0.0
@dataclass
class ScoredCandidate:
candidate: Candidate
test_results: TestResults
benchmark_results: BenchmarkResults
speedup: float
score: float
@dataclass
class OptimizationResult:
function: FunctionToOptimize
original_code: str
optimized_code: str
speedup: float
candidate: Candidate
test_results: TestResults
benchmark_results: BenchmarkResults
diff: str = ""
explanation: str = ""
@dataclass
class GeneratedTestFile:
behavior_test_path: Path
perf_test_path: Path
behavior_test_source: str
perf_test_source: str
original_test_source: str
@dataclass
class GeneratedTestSuite:
test_files: list[GeneratedTestFile] = field(default_factory=list)
@property
def behavior_test_paths(self) -> list[Path]:
return [f.behavior_test_path for f in self.test_files]
@property
def perf_test_paths(self) -> list[Path]:
return [f.perf_test_path for f in self.test_files]
@dataclass
class FunctionCoverage:
name: str
coverage: float
executed_lines: list[int] = field(default_factory=list)
unexecuted_lines: list[int] = field(default_factory=list)
executed_branches: list[list[int]] = field(default_factory=list)
unexecuted_branches: list[list[int]] = field(default_factory=list)
@dataclass
class CoverageData:
file_path: Path
coverage: float
function_name: str
main_func_coverage: FunctionCoverage
dependent_func_coverage: FunctionCoverage | None = None
threshold_percentage: float = 60.0
@dataclass
class TestDiff:
test_id: str
baseline_output: Any = None
candidate_output: Any = None
@dataclass
class TestRepairInfo:
function_name: str
reason: str
@dataclass
class TestReviewResult:
test_index: int
functions_to_repair: list[TestRepairInfo] = field(default_factory=list)
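
A usage sketch of the removed data model, showing how qualified_name and class_name are derived from the parents chain; the names are illustrative:

from pathlib import Path

from codeflash_core.models import FunctionParent, FunctionToOptimize

fn = FunctionToOptimize(
    function_name="fetch",
    file_path=Path("src/app/client.py"),
    parents=[FunctionParent(name="ApiClient", type="ClassDef")],
    is_method=True,
)
print(fn.qualified_name)         # ApiClient.fetch
print(fn.class_name)             # ApiClient
print(fn.top_level_parent_name)  # ApiClient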


@@ -1,160 +0,0 @@
from __future__ import annotations
import contextlib
import logging
import threading
import uuid
from typing import TYPE_CHECKING
from codeflash_core.config import HIGH_EFFORT_TOP_N, EffortLevel
from codeflash_core.strategy import DefaultStrategy
from codeflash_core.strategy_utils import OptimizationRuntime
from codeflash_core.ui import console, paneled_text, progress_bar
from codeflash_core.ui import logger as ui_logger
if TYPE_CHECKING:
from pathlib import Path
from codeflash_core.config import CoreConfig
from codeflash_core.models import FunctionToOptimize, OptimizationResult
from codeflash_core.protocols import LanguagePlugin
from codeflash_core.strategy_utils import OptimizationStrategy
logger = logging.getLogger(__name__)
class Optimizer:
"""Core optimization orchestrator.
Drives the discover -> index -> rank -> per-function optimization loop.
Delegates the actual optimization pipeline for each function to the active
OptimizationStrategy (defaults to DefaultStrategy).
"""
def __init__(
self, config: CoreConfig, plugin: LanguagePlugin, strategy: OptimizationStrategy | None = None
) -> None:
self.config = config
self.plugin = plugin
self.test_config = config.resolve_test_config()
# Resolve the plugin's output comparator once (or None -> fallback to ==)
self.output_comparator = getattr(plugin, "compare_outputs", None)
self.cancel_event = threading.Event()
# Share cancel event with plugin so it can abort long-running subprocess/HTTP calls
if hasattr(plugin, "cancel_event"):
object.__setattr__(plugin, "cancel_event", self.cancel_event)
self.strategy = strategy or DefaultStrategy()
def cancel(self) -> None:
self.cancel_event.set()
def is_cancelled(self) -> bool:
return self.cancel_event.is_set()
def run(self, files: list[Path], function_filter: str | None = None) -> list[OptimizationResult]:
"""Run the optimization pipeline on the given files.
Returns a list of successful optimization results.
"""
# Pre-run cleanup of leftover files from previous runs
self.plugin_cleanup()
self.cleanup_leftover_trace_files()
with progress_bar("Discovering functions...", transient=True):
functions = self.plugin.discover_functions(files)
if function_filter:
functions = [f for f in functions if function_filter in (f.function_name, f.qualified_name)]
if not functions:
ui_logger.info("No optimizable functions found.")
return []
ui_logger.info("Found %d functions to optimize.", len(functions))
# Pre-index source files for dependency analysis (call graph)
source_files = list({f.file_path for f in functions})
def on_index_progress(result: object) -> None:
pass
with progress_bar("Building call graph...", transient=True):
self.plugin.build_index(source_files, on_progress=on_index_progress)
# Rank functions by impact (e.g. dependency count)
functions = self.plugin.rank_functions(functions)
results: list[OptimizationResult] = []
skipped = 0
cancelled = False
try:
for i, function in enumerate(functions):
if self.is_cancelled():
cancelled = True
break
console.rule(f"[bold][{i + 1}/{len(functions)}] {function.qualified_name}[/bold]")
# Escalate top-N functions to HIGH effort when running at MEDIUM
original_effort = self.config.effort
if i < HIGH_EFFORT_TOP_N and self.config.effort == EffortLevel.MEDIUM.value:
self.config.effort = EffortLevel.HIGH.value
result = self.optimize_function(function)
self.config.effort = original_effort
if result is not None:
results.append(result)
ui_logger.info("Optimized %s%.2fx speedup", function.qualified_name, result.speedup)
else:
skipped += 1
ui_logger.info("No improvement found for %s", function.qualified_name)
except KeyboardInterrupt:
ui_logger.warning("Keyboard interrupt received. Cleaning up…")
cancelled = True
self.cancel_event.set()
finally:
self.plugin_cleanup()
self.cleanup_leftover_trace_files()
console.rule()
paneled_text(f"{len(functions)} analyzed, {len(results)} optimized, {skipped} skipped", title="Summary")
return results
def optimize_function(self, function: FunctionToOptimize) -> OptimizationResult | None:
"""Attempt to optimize a single function. Delegates to the active strategy."""
if self.is_cancelled():
return None
runtime = OptimizationRuntime(
plugin=self.plugin,
config=self.config,
test_config=self.test_config,
cancel_event=self.cancel_event,
output_comparator=self.output_comparator,
trace_id=str(uuid.uuid4()),
)
return self.strategy.optimize_function(function, runtime)
# -- Cleanup helpers (shared across all strategies) -------------------------
def cleanup_leftover_trace_files(self) -> None:
"""Remove leftover .trace files from previous runs."""
tests_root = self.test_config.tests_root
if not tests_root.exists():
return
leftover = list(tests_root.glob("*.trace"))
if leftover:
logger.debug("Cleaning up %d leftover trace file(s)", len(leftover))
for p in leftover:
with contextlib.suppress(OSError):
p.unlink(missing_ok=True)
def plugin_cleanup(self) -> None:
"""Delegate cleanup of language-specific leftover files to the plugin."""
if hasattr(self.plugin, "cleanup_run"):
try:
self.plugin.cleanup_run(self.test_config.tests_root)
except Exception:
logger.debug("Plugin cleanup_run failed", exc_info=True)


@@ -1,290 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Literal, Protocol, overload, runtime_checkable
if TYPE_CHECKING:
from collections.abc import Callable
from pathlib import Path
from codeflash_core.config import TestConfig
from codeflash_core.models import (
BenchmarkResults,
Candidate,
CodeContext,
CoverageData,
FunctionToOptimize,
GeneratedTestSuite,
OptimizationResult,
ScoredCandidate,
TestDiff,
TestResults,
TestReviewResult,
)
@runtime_checkable
class LanguagePlugin(Protocol):
"""Protocol that language packages must implement.
A language plugin provides all language-specific functionality:
discovering functions, extracting code context, running tests,
replacing code, running benchmarks, and formatting.
"""
def discover_functions(self, paths: list[Path]) -> list[FunctionToOptimize]:
"""Discover optimizable functions in the given file paths."""
...
def build_index(self, files: list[Path], on_progress: Callable[[Any], None] | None = None) -> None:
"""Pre-index source files for dependency analysis (e.g. call graph).
Called after discovery, before optimization begins. Plugins that
maintain a dependency resolver should index the given files here.
The optional on_progress callback is called once per file with an
implementation-defined result object.
"""
...
def rank_functions(
self,
functions: list[FunctionToOptimize],
trace_file: Path | None = None,
test_counts: dict[tuple[Path, str], int] | None = None,
) -> list[FunctionToOptimize]:
"""Rank functions in optimization order (most impactful first).
Ranking priority:
1. Trace-based addressable time (when trace_file is provided)
2. Dependency count from the call graph (fallback)
3. Existing unit test count as secondary sort key (when test_counts provided)
Returns the functions unchanged if no ranking is possible.
"""
...
def get_dependency_counts(self) -> dict[str, int]:
"""Return {qualified_name: callee_count} from the most recent ranking.
Called after rank_functions() so the UI can display per-function
dependency information. Plugins that don't track a call graph may
return an empty dict.
"""
...
def get_candidates(self, context: CodeContext, trace_id: str = "") -> list[Candidate]:
"""Request optimization candidates from the AI service."""
...
def extract_context(self, function: FunctionToOptimize) -> CodeContext:
"""Extract all code context needed to optimize a function."""
...
@overload
def run_tests(
self,
test_config: TestConfig,
test_files: list[Path] | None = ...,
test_iteration: int = ...,
enable_coverage: Literal[False] = ...,
) -> TestResults: ...
@overload
def run_tests(
self,
test_config: TestConfig,
test_files: list[Path] | None = ...,
test_iteration: int = ...,
enable_coverage: Literal[True] = ...,
) -> tuple[TestResults, CoverageData | None]: ...
def run_tests(
self,
test_config: TestConfig,
test_files: list[Path] | None = None,
test_iteration: int = 0,
enable_coverage: bool = False,
) -> TestResults | tuple[TestResults, CoverageData | None]:
"""Run tests and return structured results.
If test_files is provided, run only those files.
Otherwise discover test files from test_config.
test_iteration is passed as CODEFLASH_TEST_ITERATION env var.
When enable_coverage is True, returns (TestResults, CoverageData | None).
"""
...
def replace_function(self, file: Path, function: FunctionToOptimize, new_code: str) -> None:
"""Replace a function's source code in a file."""
...
def restore_function(self, file: Path, function: FunctionToOptimize, original_code: str) -> None:
"""Restore a function's original source code in a file."""
...
def run_benchmarks(
self,
function: FunctionToOptimize,
test_config: TestConfig,
test_files: list[Path] | None = None,
test_iteration: int = 0,
) -> BenchmarkResults:
"""Run benchmarks for a function and return timing data.
If test_files is provided, run only those files.
Otherwise discover test files from test_config.
test_iteration is passed as CODEFLASH_TEST_ITERATION env var.
"""
...
def format_code(self, code: str, file: Path) -> str:
"""Format code according to the project's style."""
...
def validate_candidate(self, code: str) -> bool:
"""Return True if the candidate code is syntactically valid."""
...
def normalize_code(self, code: str) -> str:
"""Normalize code for deduplication (e.g. strip comments, whitespace, docstrings)."""
...
# -- Phase 1: Test Generation -----------------------------------------------
def generate_tests(
self, function: FunctionToOptimize, context: CodeContext, test_config: TestConfig, trace_id: str = ""
) -> GeneratedTestSuite | None:
"""Generate regression tests for the target function."""
...
# -- Phase 2: Split behavioral / performance test running --------------------
def run_behavioral_tests(self, test_files: list[Path], test_config: TestConfig) -> TestResults:
"""Run behavioral tests and return pass/fail with captured outputs."""
...
def run_performance_tests(
self, test_files: list[Path], function: FunctionToOptimize, test_config: TestConfig
) -> BenchmarkResults:
"""Run performance-instrumented tests and return timing data."""
...
# -- Phase 3: Multi-round candidate generation ------------------------------
def run_line_profiler(
self, function: FunctionToOptimize, test_config: TestConfig, test_files: list[Path] | None = None
) -> str:
"""Run line profiler on the function and return formatted profiler output.
Returns an empty string if profiling is not possible (e.g. JIT-decorated code).
"""
...
def get_line_profiler_candidates(
self, context: CodeContext, line_profile_data: str, trace_id: str = ""
) -> list[Candidate]:
"""Generate candidates guided by line profiler hotspot data."""
...
def repair_candidate(
self, context: CodeContext, candidate: Candidate, test_diffs: list[TestDiff], trace_id: str = ""
) -> Candidate | None:
"""Fix a failing candidate using test failure info."""
...
def refine_candidate(
self, context: CodeContext, candidate: ScoredCandidate, baseline_bench: BenchmarkResults, trace_id: str = ""
) -> list[Candidate]:
"""Refine a passing candidate for further improvement."""
...
def adaptive_optimize(
self, context: CodeContext, scored: list[ScoredCandidate], trace_id: str = ""
) -> Candidate | None:
"""Combine insights from evaluated candidates."""
...
# -- Phase 4: Test review & repair ------------------------------------------
def review_generated_tests(
self, suite: GeneratedTestSuite, context: CodeContext, test_results: TestResults, trace_id: str = ""
) -> list[TestReviewResult]:
"""Review generated tests for quality issues."""
...
def repair_generated_tests(
self,
suite: GeneratedTestSuite,
reviews: list[TestReviewResult],
context: CodeContext,
trace_id: str = "",
previous_repair_errors: dict[str, str] | None = None,
coverage_data: CoverageData | None = None,
) -> GeneratedTestSuite | None:
"""Repair generated tests based on review feedback."""
...
# -- Phase 5: AI-assisted ranking & explanation -----------------------------
def rank_candidates(
self, scored: list[ScoredCandidate], context: CodeContext, trace_id: str = ""
) -> list[int] | None:
"""Rank candidates using AI. Returns indices in decreasing preference order, or None."""
...
def generate_explanation(
self, result: OptimizationResult, context: CodeContext, trace_id: str = "", annotated_tests: str = ""
) -> str:
"""Generate a human-readable explanation for the winning optimization."""
...
# -- Cleanup & environment -------------------------------------------------
def cleanup_run(self, tests_root: Path) -> None:
"""Clean up leftover files from previous or current runs.
Called before and after the optimization loop. Implementations should
remove instrumented test files, temporary return-value files, trace
files, and any shared temp directories their tooling creates.
"""
...
def compare_outputs(self, baseline_output: object, candidate_output: object) -> bool:
"""Compare two captured test outputs for equivalence.
Called during verification to decide whether a candidate preserved
behavior. Implementations may use deep/structural comparison
(e.g. handling NaN, custom objects). The default core fallback is ``==``.
"""
...
def validate_environment(self, config: Any) -> bool:
"""Validate that the environment is ready to run optimizations.
Called before the optimization loop starts. Implementations should
check that required tools (formatters, test runners, etc.) are
installed and accessible. Return True if everything is OK.
"""
...
# -- Phase 6: PR creation & result logging ----------------------------------
def create_pr(
self,
result: OptimizationResult,
context: CodeContext,
trace_id: str = "",
generated_tests: GeneratedTestSuite | None = None,
) -> str | None:
"""Create a pull request with the optimization. Returns the PR URL or None."""
...
def log_results(
self,
result: OptimizationResult,
trace_id: str,
all_speedups: dict[str, float] | None = None,
all_runtimes: dict[str, float] | None = None,
all_correct: dict[str, bool] | None = None,
) -> None:
"""Log optimization results to the backend."""
...
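
Because LanguagePlugin is a runtime_checkable Protocol, isinstance only checks that every listed method exists, not its signature or behavior; a minimal structural sketch (the stub below is illustrative and deliberately incomplete):

from codeflash_core.protocols import LanguagePlugin

class NullPlugin:
    """Illustrative stub; a real plugin implements every protocol method."""

    def discover_functions(self, paths):
        return []

    def validate_environment(self, config):
        return True

    # ... all remaining protocol methods omitted here ...

# The incomplete stub fails the structural check until every method is present.
print(isinstance(NullPlugin(), LanguagePlugin))  # False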


@@ -1,34 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from codeflash_core.models import BenchmarkResults, ScoredCandidate
def compute_speedup(baseline: BenchmarkResults, candidate: BenchmarkResults) -> float:
"""Compute speedup as percentage improvement: (baseline - candidate) / candidate.
Matches original codeflash performance_gain formula.
Returns 0.0 if candidate time is zero (no improvement measurable).
A positive value means the candidate is faster (e.g. 1.0 = 100% faster).
"""
if candidate.total_time <= 0:
return 0.0
return (baseline.total_time - candidate.total_time) / candidate.total_time
def score_candidate(speedup: float) -> float:
"""Score a candidate based on its speedup.
Currently score == speedup. This is the extension point for adding
more signals (code complexity, diff size, etc.) in the future.
"""
return speedup
def select_best(candidates: list[ScoredCandidate]) -> ScoredCandidate | None:
"""Select the best candidate by score. Returns None if list is empty."""
if not candidates:
return None
return max(candidates, key=lambda c: c.score)
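
A worked sketch of the removed scoring helpers: with a 2.0 s baseline and a 0.8 s candidate, compute_speedup gives (2.0 - 0.8) / 0.8 = 1.5, i.e. 150% faster; the module name and the candidate payload are assumed for illustration:

from codeflash_core.models import BenchmarkResults, Candidate, ScoredCandidate, TestResults
from codeflash_core.scoring import compute_speedup, score_candidate, select_best  # module name assumed

baseline = BenchmarkResults(total_time=2.0)
candidate_bench = BenchmarkResults(total_time=0.8)

speedup = compute_speedup(baseline, candidate_bench)  # (2.0 - 0.8) / 0.8 == 1.5
scored = ScoredCandidate(
    candidate=Candidate(code="...", explanation="illustrative"),
    test_results=TestResults(passed=True),
    benchmark_results=candidate_bench,
    speedup=speedup,
    score=score_candidate(speedup),  # currently score == speedup
)
print(select_best([scored]).speedup)  # 1.5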


@@ -1,311 +0,0 @@
"""Optimization strategies for the core optimizer.
An OptimizationStrategy controls the full pipeline for optimizing a single
function: context extraction, candidate generation, test review, evaluation,
ranking, and explanation. The Optimizer delegates to the active strategy,
keeping discovery, indexing, and the per-function loop in the orchestrator.
"""
from __future__ import annotations
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, ClassVar
from codeflash_core.config import EffortKeys, get_effort_value
from codeflash_core.diff import unified_diff
from codeflash_core.models import OptimizationResult
from codeflash_core.strategy_evaluation import DefaultStrategyEvaluationMixin
from codeflash_core.strategy_utils import StageSpec, cleanup_generated_tests, log_optimization_run
from codeflash_core.ui import code_print, progress_bar
from codeflash_core.ui import logger as ui_logger
if TYPE_CHECKING:
from pathlib import Path
from codeflash_core.models import Candidate, CodeContext, FunctionToOptimize, GeneratedTestSuite, ScoredCandidate
from codeflash_core.strategy_utils import OptimizationRuntime
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Default strategy
# ---------------------------------------------------------------------------
class DefaultStrategy(DefaultStrategyEvaluationMixin):
"""The default optimization strategy.
Pipeline: context -> parallel testgen + candidates -> test review/repair ->
baseline tests & benchmarks -> multi-round candidate evaluation ->
AI-assisted ranking -> explanation.
"""
stages: ClassVar[list[StageSpec]] = [
StageSpec("context", "Extracting source code and dependencies to build the optimization context for the AI."),
StageSpec("generating", "Generating unit tests for correctness validation and optimized candidates."),
StageSpec(
"test_review", "Reviewing generated tests for correctness and repairing any issues before validation."
),
StageSpec(
"baseline",
"Running the original code against tests and benchmarks to establish a performance reference point.",
),
StageSpec("evaluating", "Testing each candidate for correctness, then benchmarking the ones that pass."),
StageSpec(
"ranking", "Comparing candidate benchmarks against the baseline to find the fastest correct optimization."
),
StageSpec(
"explaining",
"Generating a human-readable explanation of what the best optimization changed and why it's faster.",
),
]
def optimize_function(
self, function: FunctionToOptimize, runtime: OptimizationRuntime
) -> OptimizationResult | None:
context: CodeContext | None = None
candidates: list[Candidate] | None = None
generated_tests: GeneratedTestSuite | None = None
original_generated_tests: GeneratedTestSuite | None = None
scored: list[ScoredCandidate] = []
result: OptimizationResult | None = None
stage = "context"
exit_reason = ""
try:
head = self.extract_and_generate(function, runtime)
if head is None:
exit_reason = "cancelled" if runtime.is_cancelled() else "no_candidates"
return None
context, candidates, generated_tests = head
original_generated_tests = generated_tests
# Phase 2: Review and repair generated tests
stage = "test_review"
with progress_bar("Reviewing and repairing tests..."):
generated_tests = self.review_and_repair_tests(generated_tests, context, runtime)
if runtime.is_cancelled():
exit_reason = "cancelled"
return None
# Determine test files
behavior_files: list[Path] | None = None
perf_files: list[Path] | None = None
if generated_tests and generated_tests.test_files:
behavior_files = generated_tests.behavior_test_paths
perf_files = generated_tests.perf_test_paths
# Run baseline tests
stage = "baseline"
with progress_bar("Running baseline tests..."):
baseline_tests = runtime.plugin.run_tests(runtime.test_config, test_files=behavior_files)
if runtime.is_cancelled():
exit_reason = "cancelled"
return None
if not baseline_tests.passed:
exit_reason = "baseline_failed"
ui_logger.warning("Baseline tests failed for %s, skipping", function.qualified_name)
return None
# Run baseline benchmarks
with progress_bar("Running baseline benchmarks..."):
baseline_bench = runtime.plugin.run_benchmarks(function, runtime.test_config, test_files=perf_files)
if runtime.is_cancelled():
exit_reason = "cancelled"
return None
# Line profiler
lp_candidates = self.get_line_profiler_candidates(function, context, runtime, perf_files)
if lp_candidates:
candidates.extend(lp_candidates)
if runtime.is_cancelled():
exit_reason = "cancelled"
return None
# Phase 3: Multi-round candidate evaluation
stage = "evaluating"
ui_logger.info("Evaluating %d candidates...", len(candidates))
scored, all_speedups, all_runtimes, all_correct = self.evaluate_candidates(
function,
context,
candidates,
baseline_tests,
baseline_bench,
runtime=runtime,
behavior_test_files=behavior_files,
perf_test_files=perf_files,
)
if runtime.is_cancelled():
exit_reason = "cancelled"
return None
# Phase 4+5: Rank, explain, log, finish
stage = "ranking"
with progress_bar("Ranking results..."):
result = self.rank_explain_finish(
function, context, scored, generated_tests, runtime, all_speedups, all_runtimes, all_correct
)
if result:
stage = "done"
exit_reason = "optimized"
code_print(result.diff)
else:
exit_reason = "cancelled" if runtime.is_cancelled() else "no_improvement"
return result
finally:
log_optimization_run(
function,
runtime,
context=context,
candidates=candidates,
generated_tests=generated_tests,
scored=scored or None,
result=result,
stage_reached=stage,
exit_reason=exit_reason,
)
if original_generated_tests:
cleanup_generated_tests(original_generated_tests)
# -- Building blocks (override individually or call from custom strategies) --
def extract_and_generate(
self, function: FunctionToOptimize, runtime: OptimizationRuntime
) -> tuple[CodeContext, list[Candidate], GeneratedTestSuite | None] | None:
"""Context extraction + parallel test/candidate generation.
Returns (context, candidates, generated_tests) or None if cancelled/no candidates.
"""
with progress_bar("Extracting context...", transient=True):
context = runtime.plugin.extract_context(function)
if runtime.is_cancelled():
return None
with progress_bar("Generating tests and candidates..."):
generated_tests, candidates = self.generate_tests_and_candidates(function, context, runtime)
if runtime.is_cancelled():
return None
if not candidates:
ui_logger.info("No candidates returned for %s", function.qualified_name)
return None
ui_logger.info("Received %d candidates for %s", len(candidates), function.qualified_name)
return context, candidates, generated_tests
def rank_explain_finish(
self,
function: FunctionToOptimize,
context: CodeContext,
scored: list[ScoredCandidate],
generated_tests: GeneratedTestSuite | None,
runtime: OptimizationRuntime,
all_speedups: dict[str, float],
all_runtimes: dict[str, float],
all_correct: dict[str, bool],
) -> OptimizationResult | None:
"""Rank candidates, explain the best, log results, and emit completion events."""
best = self.select_best_with_ranking(scored, context, runtime)
if runtime.is_cancelled():
return None
if best is None or best.speedup <= 0:
return None
diff = unified_diff(context.target_code, best.candidate.code, function.file_path)
result = OptimizationResult(
function=function,
original_code=context.target_code,
optimized_code=best.candidate.code,
speedup=best.speedup,
candidate=best.candidate,
test_results=best.test_results,
benchmark_results=best.benchmark_results,
diff=diff,
)
annotated_tests = ""
if generated_tests and generated_tests.test_files:
annotated_tests = "\n\n".join(
tf.original_test_source for tf in generated_tests.test_files if tf.original_test_source
)
explanation = runtime.plugin.generate_explanation(
result, context, trace_id=runtime.trace_id, annotated_tests=annotated_tests
)
if explanation:
result.explanation = explanation
runtime.plugin.log_results(
result, runtime.trace_id, all_speedups=all_speedups, all_runtimes=all_runtimes, all_correct=all_correct
)
if runtime.config.create_pr:
try:
runtime.plugin.create_pr(result, context, trace_id=runtime.trace_id, generated_tests=generated_tests)
except Exception:
logger.debug("PR creation failed", exc_info=True)
return result
def generate_tests_and_candidates(
self, function: FunctionToOptimize, context: CodeContext, runtime: OptimizationRuntime
) -> tuple[GeneratedTestSuite | None, list[Candidate]]:
"""Generate tests and fetch candidates in parallel."""
with ThreadPoolExecutor(max_workers=2) as executor:
tests_future = executor.submit(
runtime.plugin.generate_tests, function, context, runtime.test_config, runtime.trace_id
)
candidates_future = executor.submit(runtime.plugin.get_candidates, context, runtime.trace_id)
while True:
if runtime.is_cancelled():
tests_future.cancel()
candidates_future.cancel()
break
if tests_future.done() and candidates_future.done():
break
runtime.cancel_event.wait(timeout=0.1)
try:
generated_tests = (
tests_future.result() if tests_future.done() and not tests_future.cancelled() else None
)
except Exception:
logger.debug("Test generation failed", exc_info=True)
generated_tests = None
try:
candidates = (
candidates_future.result() if candidates_future.done() and not candidates_future.cancelled() else []
)
except Exception:
logger.debug("Candidate generation failed", exc_info=True)
candidates = []
return generated_tests, candidates
def get_line_profiler_candidates(
self,
function: FunctionToOptimize,
context: CodeContext,
runtime: OptimizationRuntime,
perf_test_files: list[Path] | None,
) -> list[Candidate]:
"""Run line profiler on baseline and fetch LP-guided candidates."""
n_lp = get_effort_value(EffortKeys.N_OPTIMIZER_LP_CANDIDATES, runtime.config.effort)
if n_lp <= 0:
return []
try:
lp_data = runtime.plugin.run_line_profiler(function, runtime.test_config, test_files=perf_test_files)
if not lp_data:
return []
lp_candidates = runtime.plugin.get_line_profiler_candidates(context, lp_data, runtime.trace_id)
for c in lp_candidates:
c.source = "line_profiler"
return lp_candidates
except Exception:
logger.debug("Line profiler step failed for %s", function.qualified_name, exc_info=True)
return []
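# A minimal sketch of customizing this pipeline, assuming the building blocks
# above are overridden individually as intended. The LeanStrategy name and the
# max_candidates cap are illustrative, not part of codeflash_core.
class LeanStrategy(DefaultStrategy):
    """Illustrative variant: skip the line-profiler round and cap candidate count."""

    max_candidates = 5  # hypothetical cap, not a real effort setting

    def get_line_profiler_candidates(self, function, context, runtime, perf_test_files):
        # Skip the extra line-profiler round entirely.
        return []

    def extract_and_generate(self, function, runtime):
        head = super().extract_and_generate(function, runtime)
        if head is None:
            return None
        context, candidates, generated_tests = head
        # Keep only the first few candidates to shorten evaluation.
        return context, candidates[: self.max_candidates], generated_tests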

View file

@ -1,283 +0,0 @@
"""Mixin: candidate evaluation, filtering, and test review/repair."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from codeflash_core.config import MAX_TEST_REPAIR_CYCLES, EffortKeys, get_effort_value
from codeflash_core.models import GeneratedTestSuite, ScoredCandidate, TestOutcomeStatus
from codeflash_core.ranking import compute_speedup, score_candidate, select_best
from codeflash_core.strategy_utils import MIN_CORRECT_CANDIDATES, compute_test_diffs, restore_test_snapshots
from codeflash_core.ui import logger as ui_logger
from codeflash_core.ui import progress_bar
from codeflash_core.verification import is_equivalent
if TYPE_CHECKING:
from pathlib import Path
from codeflash_core.models import (
BenchmarkResults,
Candidate,
CodeContext,
CoverageData,
FunctionToOptimize,
TestResults,
)
from codeflash_core.strategy_utils import OptimizationRuntime
StrategyBase = object
logger = logging.getLogger(__name__)
class DefaultStrategyEvaluationMixin(StrategyBase):
def evaluate_candidates(
self,
function: FunctionToOptimize,
context: CodeContext,
candidates: list[Candidate],
baseline_tests: TestResults,
baseline_bench: BenchmarkResults,
runtime: OptimizationRuntime,
behavior_test_files: list[Path] | None = None,
perf_test_files: list[Path] | None = None,
) -> tuple[list[ScoredCandidate], dict[str, float], dict[str, float], dict[str, bool]]:
"""Evaluate candidates with multi-round repair, refinement, and adaptive optimization.
Returns (scored_candidates, all_speedups, all_runtimes, all_correct) where the dicts
accumulate data for ALL candidates (including failed ones) keyed by candidate_id.
"""
effort = runtime.config.effort
scored: list[ScoredCandidate] = []
all_speedups: dict[str, float] = {}
all_runtimes: dict[str, float] = {}
all_correct: dict[str, bool] = {}
queue = self.filter_candidates(function, context, candidates, runtime)
processed = 0
max_total = 30
max_repairs: int = get_effort_value(EffortKeys.MAX_CODE_REPAIRS_PER_TRACE, effort)
repair_unmatched_limit: float = get_effort_value(EffortKeys.REPAIR_UNMATCHED_PERCENTAGE_LIMIT, effort)
top_refinement: int = get_effort_value(EffortKeys.TOP_VALID_CANDIDATES_FOR_REFINEMENT, effort)
max_adaptive: int = get_effort_value(EffortKeys.MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE, effort)
adaptive_threshold: int = get_effort_value(EffortKeys.ADAPTIVE_OPTIMIZATION_THRESHOLD, effort)
repair_counter = 0
adaptive_counter = 0
file_snapshots = self.capture_file_snapshots(function, context)
while queue and processed < max_total:
if runtime.is_cancelled():
break
candidate = queue.pop(0)
processed += 1
ui_logger.info(
"Testing candidate %d/%d [%s]",
processed,
len(candidates) + processed - len(queue) - 1,
candidate.source,
)
if hasattr(runtime.plugin, "pending_code_markdown"):
runtime.plugin.pending_code_markdown = candidate.code_markdown
runtime.plugin.replace_function(function.file_path, function, candidate.code)
try:
with progress_bar(f"Running tests for candidate {processed}...", transient=True):
test_results = runtime.plugin.run_tests(
runtime.test_config, test_files=behavior_test_files, test_iteration=processed
)
if not is_equivalent(baseline_tests, test_results, comparator=runtime.output_comparator):
ui_logger.info("Candidate %s failed equivalence check", candidate.candidate_id)
all_correct[candidate.candidate_id] = False
# Phase 3: Try repair if test diffs are manageable
# Only repair first-pass candidates (optimize, line_profiler)
successful_count = sum(1 for v in all_correct.values() if v)
if (
not runtime.is_cancelled()
and candidate.source in ("optimize", "line_profiler")
and repair_counter < max_repairs
and successful_count < MIN_CORRECT_CANDIDATES
):
diffs = compute_test_diffs(baseline_tests, test_results)
total_tests = max(len(baseline_tests.outcomes), 1)
if diffs and len(diffs) / total_tests <= repair_unmatched_limit:
repaired = runtime.plugin.repair_candidate(
context, candidate, diffs, trace_id=runtime.trace_id
)
if repaired is not None:
repair_counter += 1
queue.append(repaired)
continue
with progress_bar(f"Benchmarking candidate {processed}...", transient=True):
bench = runtime.plugin.run_benchmarks(
function, runtime.test_config, test_files=perf_test_files, test_iteration=processed
)
speedup = compute_speedup(baseline_bench, bench)
ui_logger.info("Candidate %s passed — %.2fx speedup", candidate.candidate_id, speedup)
all_correct[candidate.candidate_id] = True
all_speedups[candidate.candidate_id] = speedup
all_runtimes[candidate.candidate_id] = bench.total_time
scored.append(
ScoredCandidate(
candidate=candidate,
test_results=test_results,
benchmark_results=bench,
speedup=speedup,
score=score_candidate(speedup),
)
)
# Phase 3: Try refinement for candidates with speedup (skip already-refined)
if not runtime.is_cancelled():
refinement_eligible = sum(1 for s in scored if s.speedup > 0 and s.candidate.source != "refine")
if speedup > 0 and candidate.source != "refine" and refinement_eligible <= top_refinement:
refined = runtime.plugin.refine_candidate(
context, scored[-1], baseline_bench, trace_id=runtime.trace_id
)
if refined:
queue.extend(refined)
# Phase 3: Try adaptive after enough candidates scored
if (
not runtime.is_cancelled()
and len(scored) >= 3
and adaptive_counter < max_adaptive
and adaptive_threshold > 0
):
adaptive_counter += 1
adaptive = runtime.plugin.adaptive_optimize(context, scored, trace_id=runtime.trace_id)
if adaptive is not None:
queue.append(adaptive)
finally:
for path, original_content in file_snapshots.items():
path.write_text(original_content, encoding="utf-8")
return scored, all_speedups, all_runtimes, all_correct
def filter_candidates(
self,
function: FunctionToOptimize,
context: CodeContext,
candidates: list[Candidate],
runtime: OptimizationRuntime,
) -> list[Candidate]:
"""Validate syntax, format, and deduplicate candidates."""
normalized_original = runtime.plugin.normalize_code(context.target_code.strip())
seen_normalized: set[str] = {normalized_original}
queue: list[Candidate] = []
for c in candidates:
if not runtime.plugin.validate_candidate(c.code):
ui_logger.info("Candidate %s has invalid syntax, skipping", c.candidate_id)
continue
c.code = runtime.plugin.format_code(c.code, function.file_path)
norm = runtime.plugin.normalize_code(c.code.strip())
if norm in seen_normalized:
ui_logger.info("Candidate %s is identical/duplicate, skipping", c.candidate_id)
continue
seen_normalized.add(norm)
queue.append(c)
return queue
def capture_file_snapshots(self, function: FunctionToOptimize, context: CodeContext) -> dict[Path, str]:
"""Snapshot target + helper files for reliable restoration after code replacement."""
snapshots: dict[Path, str] = {}
if function.file_path.exists():
snapshots[function.file_path] = function.file_path.read_text("utf-8")
for h in context.helper_functions:
if h.file_path.exists() and h.file_path not in snapshots:
snapshots[h.file_path] = h.file_path.read_text("utf-8")
return snapshots
def review_and_repair_tests(
self, generated_tests: GeneratedTestSuite | None, context: CodeContext, runtime: OptimizationRuntime
) -> GeneratedTestSuite | None:
"""Review and repair generated tests up to MAX_TEST_REPAIR_CYCLES iterations, then drop still-failing files."""
if not generated_tests or not generated_tests.test_files:
return generated_tests
# Snapshot test file contents before repair so we can revert on failure
pre_repair_snapshots: dict[int, tuple[str, str, str]] = {}
for i, tf in enumerate(generated_tests.test_files):
pre_repair_snapshots[i] = (tf.original_test_source, tf.behavior_test_source, tf.perf_test_source)
previous_repair_errors: dict[str, str] = {}
coverage_data: CoverageData | None = None
for iteration in range(MAX_TEST_REPAIR_CYCLES):
if runtime.is_cancelled():
return generated_tests
behavior_files = generated_tests.behavior_test_paths
# Collect coverage on first iteration for repair guidance
if iteration == 0:
test_results, cov_data = runtime.plugin.run_tests(
runtime.test_config, test_files=behavior_files, enable_coverage=True
)
if cov_data is not None:
coverage_data = cov_data
else:
test_results = runtime.plugin.run_tests(runtime.test_config, test_files=behavior_files)
if test_results.passed:
return generated_tests
# Collect error messages from failing tests for next repair attempt
failing_outcomes = [o for o in test_results.outcomes if o.status != TestOutcomeStatus.PASSED]
for outcome in failing_outcomes:
if outcome.error_message:
previous_repair_errors[outcome.test_id] = outcome.error_message
reviews = runtime.plugin.review_generated_tests(generated_tests, context, test_results, runtime.trace_id)
if not reviews or all(not r.functions_to_repair for r in reviews):
restore_test_snapshots(generated_tests, pre_repair_snapshots)
break
repaired = runtime.plugin.repair_generated_tests(
generated_tests,
reviews,
context,
runtime.trace_id,
previous_repair_errors=previous_repair_errors or None,
coverage_data=coverage_data,
)
if repaired is None:
restore_test_snapshots(generated_tests, pre_repair_snapshots)
break
generated_tests = repaired
# Final pass: drop individual test files that still fail
passing_files = []
for tf in generated_tests.test_files:
if runtime.is_cancelled():
return generated_tests
run_result = runtime.plugin.run_tests(runtime.test_config, test_files=[tf.behavior_test_path])
if run_result.passed:
passing_files.append(tf)
else:
ui_logger.info("Dropping failing test file: %s", tf.behavior_test_path)
if not passing_files:
return None
return GeneratedTestSuite(test_files=passing_files)
def select_best_with_ranking(
self, scored: list[ScoredCandidate], context: CodeContext, runtime: OptimizationRuntime
) -> ScoredCandidate | None:
"""Select best candidate, optionally using AI ranking for multiple candidates."""
if not scored:
return None
if len(scored) > 1:
ranking = runtime.plugin.rank_candidates(scored, context, trace_id=runtime.trace_id)
if ranking and ranking[0] < len(scored):
return scored[ranking[0]]
return select_best(scored)
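# The dedup in filter_candidates above delegates to the plugin's normalize_code.
# A minimal sketch of such a normalizer for Python source, assuming an AST
# round-trip is acceptable; illustrative only, not the actual plugin code.
import ast


def normalize_python_source(code: str) -> str:
    """Illustrative normalizer: formatting- and comment-only variants collapse to one form."""
    try:
        return ast.unparse(ast.parse(code))
    except SyntaxError:
        return code.strip()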

View file

@ -1,184 +0,0 @@
"""Shared types and utilities for optimization strategies."""
from __future__ import annotations
import contextlib
import json
import logging
import threading # noqa: TC003 - used at runtime by OptimizationRuntime.is_cancelled()
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
from codeflash_core.models import TestDiff
if TYPE_CHECKING:
from collections.abc import Callable
from codeflash_core.config import CoreConfig, TestConfig
from codeflash_core.models import (
Candidate,
CodeContext,
FunctionToOptimize,
GeneratedTestSuite,
OptimizationResult,
ScoredCandidate,
TestResults,
)
from codeflash_core.protocols import LanguagePlugin
logger = logging.getLogger(__name__)
MIN_CORRECT_CANDIDATES = 2
@dataclass
class StageSpec:
"""Describes a single stage in the optimization pipeline."""
key: str
description: str
@dataclass
class OptimizationRuntime:
"""Everything a strategy needs to orchestrate an optimization."""
plugin: LanguagePlugin
config: CoreConfig
test_config: TestConfig
cancel_event: threading.Event
output_comparator: Callable[..., bool] | None
trace_id: str
def is_cancelled(self) -> bool:
return self.cancel_event.is_set()
@runtime_checkable
class OptimizationStrategy(Protocol):
"""Protocol for optimization strategies.
Implement this to define a custom optimization pipeline for a single function.
The Optimizer handles discovery, indexing, ranking, and the per-function loop;
the strategy controls everything within a single function's optimization.
"""
def optimize_function(
self, function: FunctionToOptimize, runtime: OptimizationRuntime
) -> OptimizationResult | None: ...
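# A minimal sketch of a strategy satisfying this protocol; no inheritance is
# required, only an optimize_function method. The NoopStrategy name is
# illustrative and it never proposes a change.
class NoopStrategy:
    """Illustrative strategy: run the baseline tests, report, and bail out."""

    def optimize_function(self, function, runtime):
        if runtime.is_cancelled():
            return None
        results = runtime.plugin.run_tests(runtime.test_config, test_files=None)
        logger.info("Baseline for %s passed=%s", function.qualified_name, results.passed)
        return None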
# ---------------------------------------------------------------------------
# Shared utilities
# ---------------------------------------------------------------------------
def cleanup_generated_tests(generated_tests: GeneratedTestSuite | None) -> None:
"""Remove generated test files from disk."""
if not generated_tests:
return
for tf in generated_tests.test_files:
for path in (tf.behavior_test_path, tf.perf_test_path):
with contextlib.suppress(OSError):
path.unlink(missing_ok=True)
def restore_test_snapshots(generated_tests: GeneratedTestSuite, snapshots: dict[int, tuple[str, str, str]]) -> None:
"""Restore test file contents from pre-repair snapshots."""
for i, tf in enumerate(generated_tests.test_files):
if i not in snapshots:
continue
orig_source, orig_behavior, orig_perf = snapshots[i]
tf.original_test_source = orig_source
tf.behavior_test_source = orig_behavior
tf.perf_test_source = orig_perf
with contextlib.suppress(OSError):
tf.behavior_test_path.write_text(orig_behavior, encoding="utf-8")
with contextlib.suppress(OSError):
tf.perf_test_path.write_text(orig_perf, encoding="utf-8")
def compute_test_diffs(baseline: TestResults, candidate: TestResults) -> list[TestDiff]:
"""Compute test outcome differences between baseline and candidate runs."""
diffs: list[TestDiff] = []
baseline_by_id = {o.test_id: o for o in baseline.outcomes}
candidate_by_id = {o.test_id: o for o in candidate.outcomes}
for test_id, base_outcome in baseline_by_id.items():
cand_outcome = candidate_by_id.get(test_id)
if (
cand_outcome is None
or base_outcome.status != cand_outcome.status
or base_outcome.output != cand_outcome.output
):
diffs.append(
TestDiff(
test_id=test_id,
baseline_output=base_outcome.output,
candidate_output=cand_outcome.output if cand_outcome else None,
)
)
return diffs
def log_optimization_run(
function: FunctionToOptimize,
runtime: OptimizationRuntime,
context: CodeContext | None = None,
candidates: list[Candidate] | None = None,
generated_tests: GeneratedTestSuite | None = None,
scored: list[ScoredCandidate] | None = None,
result: OptimizationResult | None = None,
stage_reached: str = "",
exit_reason: str = "",
) -> None:
"""Append a JSON record of the optimization run to codeflash_trace.log."""
record: dict[str, Any] = {
"trace_id": runtime.trace_id,
"function": function.qualified_name,
"file": str(function.file_path),
"stage_reached": stage_reached,
"exit_reason": exit_reason,
}
if context:
record["context"] = {
"target_code_lines": len(context.target_code.splitlines()),
"read_only_context_lines": len(context.read_only_context.splitlines()) if context.read_only_context else 0,
"imports": len(context.imports),
"helpers": [{"name": h.qualified_name, "file": str(h.file_path)} for h in context.helper_functions],
}
record["candidates_count"] = len(candidates) if candidates else 0
if candidates:
record["candidates"] = [{"id": c.candidate_id, "source": c.source} for c in candidates]
record["tests_count"] = len(generated_tests.test_files) if generated_tests and generated_tests.test_files else 0
if scored:
record["evaluation"] = [
{
"id": s.candidate.candidate_id,
"source": s.candidate.source,
"speedup": s.speedup,
"score": s.score,
"passed": s.test_results.passed,
}
for s in scored
]
if result:
record["result"] = {
"speedup": result.speedup,
"candidate_id": result.candidate.candidate_id,
"explanation": result.explanation,
"diff": result.diff,
}
log_path = runtime.config.project_root / "codeflash_trace.log"
try:
with log_path.open("a", encoding="utf-8") as f:
f.write(json.dumps(record) + "\n")
except OSError:
logger.debug("Failed to write optimization trace log", exc_info=True)

View file

@ -1,4 +0,0 @@
from codeflash_core.telemetry.posthog_cf import PostHogClient
from codeflash_core.telemetry.sentry_cf import init_sentry
__all__ = ["PostHogClient", "init_sentry"]

View file

@ -1,44 +0,0 @@
from __future__ import annotations
import contextlib
import logging
from typing import Any
logger = logging.getLogger(__name__)
class PostHogClient:
"""Wrapper around the PostHog analytics client."""
instance: PostHogClient | None = None
def __init__(self, api_key: str, enabled: bool = True) -> None:
self.enabled = enabled and bool(api_key)
self.ph: Any = None
if self.enabled:
try:
from posthog import Posthog
self.ph = Posthog(api_key, host="https://us.i.posthog.com")
except Exception:
logger.debug("Failed to initialize PostHog", exc_info=True)
self.enabled = False
@classmethod
def initialize(cls, api_key: str, enabled: bool = True) -> PostHogClient:
cls.instance = cls(api_key, enabled=enabled)
return cls.instance
def capture(self, distinct_id: str, event: str, properties: dict[str, Any] | None = None) -> None:
if not self.enabled or self.ph is None:
return
try:
self.ph.capture(distinct_id, event, properties=properties or {})
except Exception:
logger.debug("PostHog capture failed", exc_info=True)
def shutdown(self) -> None:
if self.ph is not None:
with contextlib.suppress(Exception):
self.ph.shutdown()
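# Minimal usage sketch; the API key, distinct_id, and event name below are
# placeholders, not real codeflash values.
if __name__ == "__main__":
    client = PostHogClient.initialize("phc_example_key", enabled=True)
    client.capture("user-123", "optimization_completed", {"speedup": 1.8})
    client.shutdown()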

View file

@ -1,18 +0,0 @@
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
def init_sentry(dsn: str, enabled: bool = True) -> None:
"""Initialize Sentry error tracking."""
if not enabled or not dsn:
return
try:
import sentry_sdk
sentry_sdk.init(dsn=dsn, traces_sample_rate=0.0)
except Exception:
logger.debug("Failed to initialize Sentry", exc_info=True)

View file

@ -1,21 +0,0 @@
from __future__ import annotations
from codeflash_core.ui.console import (
code_print,
console,
logger,
paneled_text,
progress_bar,
setup_logging,
test_files_progress_bar,
)
__all__ = [
"code_print",
"console",
"logger",
"paneled_text",
"progress_bar",
"setup_logging",
"test_files_progress_bar",
]

View file

@ -1,171 +0,0 @@
"""Rich-based console UI for codeflash_core.
Provides spinners, progress bars, panels, code display, and logging.
No LSP or subagent concerns; this is the core CLI output layer.
"""
from __future__ import annotations
import logging
from contextlib import contextmanager
from itertools import cycle
from typing import TYPE_CHECKING
from rich.console import Console
from rich.highlighter import NullHighlighter
from rich.logging import RichHandler
from rich.panel import Panel
from rich.progress import (
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
TaskID,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from rich.text import Text
if TYPE_CHECKING:
from collections.abc import Generator
# ---------------------------------------------------------------------------
# Console and logging
# ---------------------------------------------------------------------------
console = Console(highlighter=NullHighlighter())
logger = logging.getLogger("codeflash")
def setup_logging(level: int = logging.INFO) -> None:
"""Configure root logger with Rich handler. Call from CLI, not at import time."""
logging.basicConfig(
level=level,
handlers=[
RichHandler(
rich_tracebacks=True,
markup=False,
highlighter=NullHighlighter(),
console=console,
show_path=False,
show_time=False,
)
],
format="%(message)s",
)
# ---------------------------------------------------------------------------
# Spinners
# ---------------------------------------------------------------------------
SPINNER_TYPES = [
"dots",
"dots2",
"dots3",
"dots4",
"dots5",
"line",
"line2",
"arc",
"circle",
"star",
"star2",
"moon",
"bouncingBar",
"bouncingBall",
"flip",
"growVertical",
"growHorizontal",
"balloon",
"noise",
"bounce",
"point",
"layer",
"betaWave",
]
spinner_cycle = cycle(SPINNER_TYPES)
# ---------------------------------------------------------------------------
# Dummy types for fallback
# ---------------------------------------------------------------------------
class DummyTask:
def __init__(self) -> None:
self.id: TaskID = TaskID(0)
class DummyProgress:
def advance(self, task_id: TaskID, advance: int = 1) -> None:
pass
# ---------------------------------------------------------------------------
# Progress bars
# ---------------------------------------------------------------------------
progress_bar_active = False
@contextmanager
def progress_bar(message: str, *, transient: bool = False) -> Generator[TaskID, None, None]:
"""Spinner with elapsed time. Avoids nesting Rich Live displays."""
global progress_bar_active
if progress_bar_active:
yield DummyTask().id
return
progress_bar_active = True
try:
progress = Progress(
SpinnerColumn(next(spinner_cycle)),
*Progress.get_default_columns(),
TimeElapsedColumn(),
console=console,
transient=transient,
)
task = progress.add_task(message, total=None)
with progress:
yield task
finally:
progress_bar_active = False
@contextmanager
def test_files_progress_bar(total: int, description: str) -> Generator[tuple[Progress, TaskID], None, None]:
"""Progress bar with M/N counter for test files."""
with Progress(
SpinnerColumn(next(spinner_cycle)),
TextColumn("[progress.description]{task.description}"),
BarColumn(complete_style="cyan", finished_style="green", pulse_style="yellow"),
MofNCompleteColumn(),
TimeElapsedColumn(),
TimeRemainingColumn(),
console=console,
transient=True,
) as progress:
task_id = progress.add_task(description, total=total)
yield progress, task_id
# ---------------------------------------------------------------------------
# Display helpers
# ---------------------------------------------------------------------------
def paneled_text(text: str, *, title: str = "", border_style: str = "cyan") -> None:
"""Print text inside a bordered panel."""
console.print(Panel(Text(text), title=title or None, border_style=border_style))
def code_print(code_str: str, *, language: str = "python") -> None:
"""Print code with syntax highlighting."""
from rich.syntax import Syntax
console.rule()
console.print(Syntax(code_str, language, line_numbers=True, theme="github-dark"))
console.rule()
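# Minimal usage sketch tying the console helpers together; the messages, the
# sleeps, and the sample snippet are illustrative only.
if __name__ == "__main__":
    import time

    setup_logging(logging.INFO)
    logger.info("Starting demo")
    with progress_bar("Doing some work...", transient=True):
        time.sleep(1)
    with test_files_progress_bar(total=3, description="Processing test files") as (progress, task_id):
        for _ in range(3):
            time.sleep(0.2)
            progress.advance(task_id)
    paneled_text("All done", title="Summary", border_style="green")
    code_print("def fast_sum(xs):\n    return sum(xs)")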

View file

@ -1,54 +0,0 @@
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from codeflash_core.models import TestOutcome, TestResults
# Type alias for a custom output comparator function.
OutputComparator = Callable[[object, object], bool]
def is_equivalent(baseline: TestResults, candidate: TestResults, comparator: OutputComparator | None = None) -> bool:
"""Check if candidate test results are equivalent to baseline.
Both must pass, have the same number of outcomes, and each outcome
must match on test_id, status, and output.
If *comparator* is provided, it is used instead of ``==`` to compare
captured test outputs.
"""
if not baseline.passed or not candidate.passed:
return False
if len(baseline.outcomes) != len(candidate.outcomes):
return False
baseline_by_id = {o.test_id: o for o in baseline.outcomes}
candidate_by_id = {o.test_id: o for o in candidate.outcomes}
if baseline_by_id.keys() != candidate_by_id.keys():
return False
return all(
outcomes_match(baseline_by_id[test_id], candidate_by_id[test_id], comparator=comparator)
for test_id in baseline_by_id
)
def outcomes_match(baseline: TestOutcome, candidate: TestOutcome, comparator: OutputComparator | None = None) -> bool:
if baseline.status != candidate.status:
return False
if baseline.output is not None and candidate.output is not None:
return compare_outputs(baseline.output, candidate.output, comparator=comparator)
return True
def compare_outputs(
baseline_output: object, candidate_output: object, comparator: OutputComparator | None = None
) -> bool:
"""Compare return-value outputs using *comparator*, falling back to ``==``."""
if comparator is not None:
return comparator(baseline_output, candidate_output)
return baseline_output == candidate_output
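# A minimal sketch of a custom OutputComparator, assuming numeric outputs where
# tiny floating-point drift should not fail equivalence; illustrative, not part
# of codeflash_core.
import math


def approx_comparator(baseline_output: object, candidate_output: object) -> bool:
    """Illustrative comparator: floats compare with a relative tolerance, everything else with ==."""
    if isinstance(baseline_output, float) and isinstance(candidate_output, float):
        return math.isclose(baseline_output, candidate_output, rel_tol=1e-9)
    return baseline_output == candidate_output


# Example: is_equivalent(baseline_results, candidate_results, comparator=approx_comparator)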