feat: add --script mode to codeflash compare

Allows running arbitrary benchmark scripts on both git refs and
rendering a styled comparison table. Supports optional --memory
via memray wrapping. No codeflash config required for script mode.
This commit is contained in:
Kevin Turcios 2026-04-02 11:36:54 -05:00
parent ca198ce5ab
commit 6965e9871d
4 changed files with 491 additions and 30 deletions

View file

@ -145,6 +145,63 @@ class CompareResult:
return "\n\n".join(sections)
@dataclass
class ScriptCompareResult:
    """Timing (and optional memory) results from running a benchmark script on two git refs."""

    # Git refs being compared; truncated to 12 chars for display.
    base_ref: str
    head_ref: str
    # Mapping of benchmark key -> seconds; the "__total__" key is an aggregate
    # pseudo-entry rendered as its own TOTAL row.
    base_results: dict[str, float] = field(default_factory=dict)
    head_results: dict[str, float] = field(default_factory=dict)
    # Aggregate memray stats per ref; None when memory profiling was not requested
    # or the capture could not be read.
    base_memory: Optional[MemoryStats] = None
    head_memory: Optional[MemoryStats] = None

    def format_markdown(self) -> str:
        """Render the comparison as a markdown report (timing table plus optional memory table)."""
        # Nothing collected at all -> short placeholder instead of an empty table.
        if not self.base_results and not self.head_results and not self.base_memory and not self.head_memory:
            return "_No benchmark results to compare._"
        base_short = self.base_ref[:12]
        head_short = self.head_ref[:12]
        lines: list[str] = [f"## Benchmark: `{base_short}` vs `{head_short}`"]
        # Per-key rows exclude the "__total__" aggregate; it gets a bold TOTAL row below.
        all_keys = sorted((set(self.base_results) | set(self.head_results)) - {"__total__"})
        has_total = "__total__" in self.base_results or "__total__" in self.head_results
        lines.extend(["", "| Key | Base | Head | Delta | Speedup |", "|:---|---:|---:|:---|---:|"])
        for key in all_keys:
            # A key present on only one side renders "-" in the missing columns.
            b = self.base_results.get(key)
            h = self.head_results.get(key)
            lines.append(
                f"| `{key}` | {_fmt_seconds(b)} | {_fmt_seconds(h)} | {_md_delta_s(b, h)} | {md_speedup(b, h)} |"
            )
        if has_total:
            b = self.base_results.get("__total__")
            h = self.head_results.get("__total__")
            lines.append(
                f"| **TOTAL** | **{_fmt_seconds(b)}** | **{_fmt_seconds(h)}** | {_md_delta_s(b, h)} | {md_speedup(b, h)} |"
            )
        if self.base_memory or self.head_memory:
            lines.extend(
                ["", "#### Memory", "", "| Ref | Peak Memory | Allocations | Delta |", "|:---|---:|---:|:---|"]
            )
            if self.base_memory:
                lines.append(
                    f"| `{base_short}` (base) | {md_bytes(self.base_memory.peak_memory_bytes)}"
                    f" | {self.base_memory.total_allocations:,} | |"
                )
            if self.head_memory:
                # Delta is relative to the base peak when a base capture exists, else "-".
                delta = md_memory_delta(
                    self.base_memory.peak_memory_bytes if self.base_memory else None, self.head_memory.peak_memory_bytes
                )
                lines.append(
                    f"| `{head_short}` (head) | {md_bytes(self.head_memory.peak_memory_bytes)}"
                    f" | {self.head_memory.total_allocations:,} | {delta} |"
                )
        lines.extend(["", "---", "*Generated by codeflash optimization agent*"])
        return "\n".join(lines)
def compare_branches(
base_ref: str,
head_ref: str,
@ -837,3 +894,289 @@ def has_meaningful_memory_change(
if alloc_pct > threshold_pct:
return True
return False
# --- Script-mode comparison ---
def _fmt_seconds(s: Optional[float]) -> str:
if s is None:
return "-"
if s >= 60:
return f"{s / 60:,.1f}m"
return f"{s:,.2f}s"
def _fmt_delta_s(before: Optional[float], after: Optional[float]) -> str:
if before is None or after is None:
return "-"
pct = ((after - before) / before) * 100 if before != 0 else 0
if pct < 0:
return _GREEN_TPL % pct
return _RED_TPL % pct
def _md_delta_s(before: Optional[float], after: Optional[float]) -> str:
if before is None or after is None or before == 0:
return "-"
pct = ((after - before) / before) * 100
emoji = "\U0001f7e2" if pct <= 0 else "\U0001f534"
return f"{emoji} {pct:+.1f}%"
def _speedup_s(before: Optional[float], after: Optional[float]) -> str:
if before is None or after is None or after == 0:
return "-"
ratio = before / after
if ratio >= 1:
return f"[green]{ratio:.2f}x[/green]"
return f"[red]{ratio:.2f}x[/red]"
def compare_with_script(
    base_ref: str,
    head_ref: str,
    project_root: Path,
    script_cmd: str,
    script_output: str,
    timeout: int = 600,
    memory: bool = False,
) -> ScriptCompareResult:
    """Compare benchmark performance between two git refs using a custom script.

    The script is run in each worktree with CWD set to the worktree root.
    It must produce a JSON file at script_output (relative to worktree root)
    mapping keys to seconds, e.g. {"test1": 1.23, "__total__": 4.56}.

    Creates a throwaway git worktree per ref, runs the script in each (wrapped
    in memray when `memory` is set), renders the comparison to the console,
    and always cleans up worktrees and memray captures before returning.
    """
    import sys

    # memray (used for memory profiling) does not support Windows; bail out early
    # with an empty result instead of failing mid-run.
    if memory and sys.platform == "win32":
        logger.error("--memory requires memray which is not available on Windows")
        return ScriptCompareResult(base_ref=base_ref, head_ref=head_ref)
    repo = git.Repo(project_root, search_parent_directories=True)
    from codeflash.code_utils.git_worktree_utils import worktree_dirs

    worktree_dirs.mkdir(parents=True, exist_ok=True)
    # Timestamped names keep repeated runs from colliding on worktree or capture paths.
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    base_worktree = worktree_dirs / f"compare-base-{timestamp}"
    head_worktree = worktree_dirs / f"compare-head-{timestamp}"
    base_memray_bin = worktree_dirs / f"script-memray-base-{timestamp}.bin"
    head_memray_bin = worktree_dirs / f"script-memray-head-{timestamp}.bin"
    result = ScriptCompareResult(base_ref=base_ref, head_ref=head_ref)
    from rich.console import Group
    from rich.live import Live
    from rich.panel import Panel
    from rich.text import Text

    base_short = base_ref[:12]
    head_short = head_ref[:12]
    step_labels = [
        "Creating worktrees",
        f"Running benchmark on base ({base_short})",
        f"Running benchmark on head ({head_short})",
    ]

    def build_steps(current_step: int) -> Group:
        """Render the step checklist: done steps get a check mark, the current one a circle, the rest are dimmed."""
        lines: list[Text] = []
        for i, label in enumerate(step_labels):
            if i < current_step:
                lines.append(Text.from_markup(f"[green]\u2714[/green] {label}"))
            elif i == current_step:
                lines.append(Text.from_markup(f"[cyan]\u25cb[/cyan] {label}..."))
            else:
                lines.append(Text.from_markup(f"[dim]\u2500 {label}[/dim]"))
        return Group(*lines)

    def build_panel(current_step: int) -> Panel:
        """Wrap the ref header, the script command, and the step checklist into the live progress panel."""
        return Panel(
            Group(
                Text.from_markup(
                    f"[bold cyan]{base_short}[/bold cyan] (base) vs [bold cyan]{head_short}[/bold cyan] (head)"
                ),
                "",
                Text.from_markup(f"[dim]Script:[/dim] {script_cmd}"),
                "",
                build_steps(current_step),
            ),
            title="[bold]Script Benchmark Compare[/bold]",
            border_style="cyan",
            expand=True,
            padding=(1, 2),
        )

    try:
        step = 0
        with Live(build_panel(step), console=console, refresh_per_second=1) as live:
            # Resolve symbolic refs to concrete SHAs so they cannot move mid-run.
            base_sha = repo.commit(base_ref).hexsha
            head_sha = repo.commit(head_ref).hexsha
            repo.git.worktree("add", str(base_worktree), base_sha)
            repo.git.worktree("add", str(head_worktree), head_sha)
            step += 1
            live.update(build_panel(step))
            # Run script on base
            result.base_results = _run_script_in_worktree(
                script_cmd, base_worktree, script_output, timeout, base_memray_bin if memory else None
            )
            step += 1
            live.update(build_panel(step))
            # Run script on head
            result.head_results = _run_script_in_worktree(
                script_cmd, head_worktree, script_output, timeout, head_memray_bin if memory else None
            )
            # Parse memory results (None for a ref whose memray capture is missing or unreadable)
            if memory:
                result.base_memory = _parse_memray_bin(base_memray_bin)
                result.head_memory = _parse_memray_bin(head_memray_bin)
        # Render after the Live context exits so the tables don't fight the live panel.
        render_script_comparison(result)
    except KeyboardInterrupt:
        console.print("\n[yellow]Interrupted — cleaning up...[/yellow]")
    finally:
        # Always remove worktrees and captures, even on interrupt or error.
        from codeflash.code_utils.git_worktree_utils import remove_worktree

        remove_worktree(base_worktree)
        remove_worktree(head_worktree)
        repo.git.worktree("prune")
        for f in [base_memray_bin, head_memray_bin]:
            if f.exists():
                f.unlink()
    return result
def _run_script_in_worktree(
script_cmd: str, worktree_dir: Path, script_output: str, timeout: int, memray_bin: Optional[Path]
) -> dict[str, float]:
import json
cmd = script_cmd
if memray_bin:
cmd = f"python -m memray run --trace-python-allocators -o {memray_bin} -- {cmd}"
try:
proc = subprocess.run( # noqa: S602
cmd, shell=True, cwd=worktree_dir, timeout=timeout, capture_output=True, text=True, check=False
)
if proc.returncode != 0:
logger.warning(f"Script exited with code {proc.returncode}")
if proc.stderr:
logger.debug(f"Script stderr:\n{proc.stderr[:2000]}")
except subprocess.TimeoutExpired:
logger.warning(f"Script timed out after {timeout}s")
return {}
output_path = worktree_dir / script_output
if not output_path.exists():
logger.warning(f"Script output not found at {output_path}")
return {}
try:
data = json.loads(output_path.read_text(encoding="utf-8"))
if not isinstance(data, dict):
logger.warning("Script output JSON is not a dict")
return {}
return {k: float(v) for k, v in data.items() if isinstance(v, (int, float))}
except (json.JSONDecodeError, ValueError) as e:
logger.warning(f"Failed to parse script output JSON: {e}")
return {}
def _parse_memray_bin(bin_path: Path) -> Optional[MemoryStats]:
if not bin_path.exists():
return None
try:
from memray import FileReader
from codeflash.benchmarking.plugin.plugin import MemoryStats
reader = FileReader(str(bin_path))
meta = reader.metadata
stats = MemoryStats(peak_memory_bytes=meta.peak_memory, total_allocations=meta.total_allocations)
reader.close()
return stats
except ImportError:
logger.warning("memray not installed — skipping memory results")
return None
except OSError as e:
logger.warning(f"Failed to read memray binary: {e}")
return None
def render_script_comparison(result: ScriptCompareResult) -> None:
    """Print the script-benchmark comparison (timing table and optional memory table) to the console."""
    show_timing = bool(result.base_results or result.head_results)
    show_memory = bool(result.base_memory or result.head_memory)
    if not (show_timing or show_memory):
        logger.warning("No benchmark results to compare")
        return
    base_short = result.base_ref[:12]
    head_short = result.head_ref[:12]
    console.print()
    console.rule(f"[bold]Script Benchmark: {base_short} vs {head_short}[/bold]")
    console.print()
    if show_timing:
        timing_table = Table(title="Benchmark Results", border_style="blue", show_lines=True, expand=False)
        timing_table.add_column("Key", style="cyan")
        timing_table.add_column("Base", justify="right", style="yellow")
        timing_table.add_column("Head", justify="right", style="yellow")
        timing_table.add_column("Delta", justify="right")
        timing_table.add_column("Speedup", justify="right")
        # Per-key rows exclude the "__total__" aggregate; it goes in its own section.
        regular_keys = sorted((set(result.base_results) | set(result.head_results)) - {"__total__"})
        for key in regular_keys:
            before = result.base_results.get(key)
            after = result.head_results.get(key)
            timing_table.add_row(
                key, _fmt_seconds(before), _fmt_seconds(after), _fmt_delta_s(before, after), _speedup_s(before, after)
            )
        if "__total__" in result.base_results or "__total__" in result.head_results:
            timing_table.add_section()
            before = result.base_results.get("__total__")
            after = result.head_results.get("__total__")
            timing_table.add_row(
                "[bold]TOTAL[/bold]",
                _fmt_seconds(before),
                _fmt_seconds(after),
                _fmt_delta_s(before, after),
                _speedup_s(before, after),
            )
        console.print(timing_table, justify="center")
    if show_memory:
        console.print()
        mem_table = Table(title="Memory (aggregate)", border_style="magenta", show_lines=True, expand=False)
        mem_table.add_column("Ref", style="bold cyan")
        mem_table.add_column("Peak Memory", justify="right")
        mem_table.add_column("Allocations", justify="right")
        mem_table.add_column("Delta", justify="right")
        base_mem = result.base_memory
        head_mem = result.head_memory
        if base_mem:
            mem_table.add_row(
                f"{base_short} (base)", fmt_bytes(base_mem.peak_memory_bytes), f"{base_mem.total_allocations:,}", ""
            )
        if head_mem:
            # Delta is computed against the base peak when a base capture exists.
            delta = fmt_memory_delta(base_mem.peak_memory_bytes if base_mem else None, head_mem.peak_memory_bytes)
            mem_table.add_row(
                f"{head_short} (head)", fmt_bytes(head_mem.peak_memory_bytes), f"{head_mem.total_allocations:,}", delta
            )
        console.print(mem_table, justify="center")
    console.print()

View file

@ -395,6 +395,13 @@ def _build_parser() -> ArgumentParser:
compare_parser.add_argument(
"--memory", action="store_true", help="Profile peak memory usage per benchmark (requires memray, Linux/macOS)"
)
compare_parser.add_argument("--script", type=str, help="Shell command to run as benchmark in each worktree")
compare_parser.add_argument(
"--script-output",
type=str,
dest="script_output",
help="Relative path to JSON results file produced by --script (required with --script)",
)
compare_parser.add_argument("--config-file", type=str, dest="config_file", help="Path to pyproject.toml")
trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.")

View file

@ -13,31 +13,10 @@ if TYPE_CHECKING:
from codeflash.models.function_types import FunctionToOptimize
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.config_parser import parse_config_file
def run_compare(args: Namespace) -> None:
"""Entry point for the compare subcommand."""
# Load project config
pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
module_root = Path(pyproject_config.get("module_root", ".")).resolve()
tests_root = Path(pyproject_config.get("tests_root", "tests")).resolve()
benchmarks_root_str = pyproject_config.get("benchmarks_root")
if not benchmarks_root_str:
logger.error("benchmarks-root must be configured in [tool.codeflash] to use compare")
sys.exit(1)
benchmarks_root = Path(benchmarks_root_str).resolve()
if not benchmarks_root.is_dir():
logger.error(f"benchmarks-root {benchmarks_root} is not a valid directory")
sys.exit(1)
from codeflash.cli_cmds.cli import project_root_from_module_root
project_root = project_root_from_module_root(module_root, pyproject_file_path)
# Resolve head_ref: explicit arg > --pr > current branch
head_ref = args.head_ref
if args.pr:
@ -58,6 +37,61 @@ def run_compare(args: Namespace) -> None:
sys.exit(1)
logger.info(f"Auto-detected base ref: {base_ref}")
# Script mode: run an arbitrary benchmark command on each worktree (no codeflash config needed)
script_cmd = getattr(args, "script", None)
if script_cmd:
script_output = getattr(args, "script_output", None)
if not script_output:
logger.error("--script-output is required when using --script")
sys.exit(1)
import git
project_root = Path(git.Repo(Path.cwd(), search_parent_directories=True).working_dir)
from codeflash.benchmarking.compare import compare_with_script
result = compare_with_script(
base_ref=base_ref,
head_ref=head_ref,
project_root=project_root,
script_cmd=script_cmd,
script_output=script_output,
timeout=args.timeout,
memory=getattr(args, "memory", False),
)
if not result.base_results and not result.head_results:
logger.warning("No benchmark data collected. Check that --script-output points to a valid JSON file.")
sys.exit(1)
if args.output:
md = result.format_markdown()
Path(args.output).write_text(md, encoding="utf-8")
logger.info(f"Markdown report written to {args.output}")
return
# Standard trace-benchmark mode: requires codeflash config
from codeflash.code_utils.config_parser import parse_config_file
pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
module_root = Path(pyproject_config.get("module_root", ".")).resolve()
from codeflash.cli_cmds.cli import project_root_from_module_root
project_root = project_root_from_module_root(module_root, pyproject_file_path)
tests_root = Path(pyproject_config.get("tests_root", "tests")).resolve()
benchmarks_root_str = pyproject_config.get("benchmarks_root")
if not benchmarks_root_str:
logger.error("benchmarks-root must be configured in [tool.codeflash] to use compare")
sys.exit(1)
benchmarks_root = Path(benchmarks_root_str).resolve()
if not benchmarks_root.is_dir():
logger.error(f"benchmarks-root {benchmarks_root} is not a valid directory")
sys.exit(1)
# Parse explicit functions if provided
functions = None
if args.functions:

View file

@ -1,6 +1,12 @@
from __future__ import annotations
from codeflash.benchmarking.compare import CompareResult, has_meaningful_memory_change, render_comparison
from codeflash.benchmarking.compare import (
CompareResult,
ScriptCompareResult,
has_meaningful_memory_change,
render_comparison,
render_script_comparison,
)
from codeflash.benchmarking.plugin.plugin import BenchmarkStats, MemoryStats
from codeflash.models.models import BenchmarkKey
@ -101,14 +107,8 @@ class TestFormatMarkdownMemoryOnly:
head_ref="def456",
base_stats={timing_key: _make_stats()},
head_stats={timing_key: _make_stats(median_ns=500.0)},
base_memory={
timing_key: _make_memory(peak=10_000_000),
memory_key: _make_memory(peak=8_000_000),
},
head_memory={
timing_key: _make_memory(peak=5_000_000),
memory_key: _make_memory(peak=6_000_000),
},
base_memory={timing_key: _make_memory(peak=10_000_000), memory_key: _make_memory(peak=8_000_000)},
head_memory={timing_key: _make_memory(peak=5_000_000), memory_key: _make_memory(peak=6_000_000)},
)
md = result.format_markdown()
@ -161,3 +161,80 @@ class TestHasMeaningfulMemoryChange:
base = _make_memory(peak=10_000_000, allocs=1000)
head = _make_memory(peak=10_000_000, allocs=800)
assert has_meaningful_memory_change(base, head)
class TestScriptCompareResult:
    """Tests for ScriptCompareResult markdown formatting and console rendering."""

    def test_format_markdown_basic(self) -> None:
        """Both refs report the same keys: every key and the column headers appear in the output."""
        result = ScriptCompareResult(
            base_ref="abc123",
            head_ref="def456",
            base_results={"file1.pdf": 12.34, "file2.docx": 1.23},
            head_results={"file1.pdf": 10.21, "file2.docx": 1.45},
        )
        md = result.format_markdown()
        assert "file1.pdf" in md
        assert "file2.docx" in md
        assert "Base" in md
        assert "Head" in md

    def test_format_markdown_empty(self) -> None:
        """No results at all renders the short placeholder string."""
        result = ScriptCompareResult(base_ref="abc123", head_ref="def456")
        md = result.format_markdown()
        assert md == "_No benchmark results to compare._"

    def test_format_markdown_total_row(self) -> None:
        """The "__total__" pseudo-key renders as a bold TOTAL row, not a per-key row."""
        result = ScriptCompareResult(
            base_ref="abc123",
            head_ref="def456",
            base_results={"test1": 1.0, "__total__": 5.0},
            head_results={"test1": 0.8, "__total__": 4.0},
        )
        md = result.format_markdown()
        assert "**TOTAL**" in md
        # __total__ should not appear as a regular key row
        assert md.count("__total__") == 0

    def test_format_markdown_missing_keys(self) -> None:
        """Keys present on only one side still get a row (the missing side shows '-')."""
        result = ScriptCompareResult(
            base_ref="abc123", head_ref="def456", base_results={"only_base": 2.0}, head_results={"only_head": 3.0}
        )
        md = result.format_markdown()
        assert "only_base" in md
        assert "only_head" in md

    def test_format_markdown_with_memory(self) -> None:
        """Memory stats on both refs add the memory table with its headers."""
        result = ScriptCompareResult(
            base_ref="abc123",
            head_ref="def456",
            base_results={"test1": 1.0},
            head_results={"test1": 0.5},
            base_memory=_make_memory(peak=10_000_000, allocs=500),
            head_memory=_make_memory(peak=7_000_000, allocs=400),
        )
        md = result.format_markdown()
        assert "Peak Memory" in md
        assert "Allocations" in md

    def test_render_no_crash(self) -> None:
        """Console rendering of timing results (including a TOTAL row) must not raise."""
        result = ScriptCompareResult(
            base_ref="abc123",
            head_ref="def456",
            base_results={"a": 1.0, "b": 2.0, "__total__": 3.0},
            head_results={"a": 0.5, "b": 1.5, "__total__": 2.0},
        )
        render_script_comparison(result)

    def test_render_empty_no_crash(self) -> None:
        """Rendering an empty result only logs a warning and must not raise."""
        result = ScriptCompareResult(base_ref="abc123", head_ref="def456")
        render_script_comparison(result)

    def test_render_with_memory_no_crash(self) -> None:
        """Rendering with memory stats on both refs must not raise."""
        result = ScriptCompareResult(
            base_ref="abc123",
            head_ref="def456",
            base_results={"test1": 5.0},
            head_results={"test1": 4.0},
            base_memory=_make_memory(peak=10_000_000, allocs=1000),
            head_memory=_make_memory(peak=8_000_000, allocs=900),
        )
        render_script_comparison(result)