diff --git a/codeflash/benchmarking/compare.py b/codeflash/benchmarking/compare.py
index 11cebea3e..9ce4db01b 100644
--- a/codeflash/benchmarking/compare.py
+++ b/codeflash/benchmarking/compare.py
@@ -145,6 +145,63 @@ class CompareResult:
         return "\n\n".join(sections)
 
 
+@dataclass
+class ScriptCompareResult:
+    base_ref: str
+    head_ref: str
+    base_results: dict[str, float] = field(default_factory=dict)
+    head_results: dict[str, float] = field(default_factory=dict)
+    base_memory: Optional[MemoryStats] = None
+    head_memory: Optional[MemoryStats] = None
+
+    def format_markdown(self) -> str:
+        if not self.base_results and not self.head_results and not self.base_memory and not self.head_memory:
+            return "_No benchmark results to compare._"
+
+        base_short = self.base_ref[:12]
+        head_short = self.head_ref[:12]
+        lines: list[str] = [f"## Benchmark: `{base_short}` vs `{head_short}`"]
+
+        all_keys = sorted((set(self.base_results) | set(self.head_results)) - {"__total__"})
+        has_total = "__total__" in self.base_results or "__total__" in self.head_results
+
+        lines.extend(["", "| Key | Base | Head | Delta | Speedup |", "|:---|---:|---:|:---|---:|"])
+        for key in all_keys:
+            b = self.base_results.get(key)
+            h = self.head_results.get(key)
+            lines.append(
+                f"| `{key}` | {_fmt_seconds(b)} | {_fmt_seconds(h)} | {_md_delta_s(b, h)} | {md_speedup(b, h)} |"
+            )
+
+        if has_total:
+            b = self.base_results.get("__total__")
+            h = self.head_results.get("__total__")
+            lines.append(
+                f"| **TOTAL** | **{_fmt_seconds(b)}** | **{_fmt_seconds(h)}** | {_md_delta_s(b, h)} | {md_speedup(b, h)} |"
+            )
+
+        if self.base_memory or self.head_memory:
+            lines.extend(
+                ["", "#### Memory", "", "| Ref | Peak Memory | Allocations | Delta |", "|:---|---:|---:|:---|"]
+            )
+            if self.base_memory:
+                lines.append(
+                    f"| `{base_short}` (base) | {md_bytes(self.base_memory.peak_memory_bytes)}"
+                    f" | {self.base_memory.total_allocations:,} | |"
+                )
+            if self.head_memory:
+                delta = md_memory_delta(
+                    self.base_memory.peak_memory_bytes if self.base_memory else None, self.head_memory.peak_memory_bytes
+                )
+                lines.append(
+                    f"| `{head_short}` (head) | {md_bytes(self.head_memory.peak_memory_bytes)}"
+                    f" | {self.head_memory.total_allocations:,} | {delta} |"
+                )
+
+        lines.extend(["", "---", "*Generated by codeflash optimization agent*"])
+        return "\n".join(lines)
+
+
 def compare_branches(
     base_ref: str,
     head_ref: str,
@@ -837,3 +894,289 @@ def has_meaningful_memory_change(
     if alloc_pct > threshold_pct:
         return True
     return False
+
+
+# --- Script-mode comparison ---
+
+
+def _fmt_seconds(s: Optional[float]) -> str:
+    if s is None:
+        return "-"
+    if s >= 60:
+        return f"{s / 60:,.1f}m"
+    return f"{s:,.2f}s"
+
+
+def _fmt_delta_s(before: Optional[float], after: Optional[float]) -> str:
+    if before is None or after is None or before == 0:
+        return "-"
+    pct = ((after - before) / before) * 100
+    if pct <= 0:
+        return _GREEN_TPL % pct
+    return _RED_TPL % pct
+
+
+def _md_delta_s(before: Optional[float], after: Optional[float]) -> str:
+    if before is None or after is None or before == 0:
+        return "-"
+    pct = ((after - before) / before) * 100
+    emoji = "\U0001f7e2" if pct <= 0 else "\U0001f534"
+    return f"{emoji} {pct:+.1f}%"
+
+
+def _speedup_s(before: Optional[float], after: Optional[float]) -> str:
+    if before is None or after is None or after == 0:
+        return "-"
+    ratio = before / after
+    if ratio >= 1:
+        return f"[green]{ratio:.2f}x[/green]"
+    return f"[red]{ratio:.2f}x[/red]"
+
+
+def compare_with_script(
+    base_ref: str,
+    head_ref: str,
+    
project_root: Path, + script_cmd: str, + script_output: str, + timeout: int = 600, + memory: bool = False, +) -> ScriptCompareResult: + """Compare benchmark performance between two git refs using a custom script. + + The script is run in each worktree with CWD set to the worktree root. + It must produce a JSON file at script_output (relative to worktree root) + mapping keys to seconds, e.g. {"test1": 1.23, "__total__": 4.56}. + """ + import sys + + if memory and sys.platform == "win32": + logger.error("--memory requires memray which is not available on Windows") + return ScriptCompareResult(base_ref=base_ref, head_ref=head_ref) + + repo = git.Repo(project_root, search_parent_directories=True) + + from codeflash.code_utils.git_worktree_utils import worktree_dirs + + worktree_dirs.mkdir(parents=True, exist_ok=True) + timestamp = time.strftime("%Y%m%d-%H%M%S") + + base_worktree = worktree_dirs / f"compare-base-{timestamp}" + head_worktree = worktree_dirs / f"compare-head-{timestamp}" + base_memray_bin = worktree_dirs / f"script-memray-base-{timestamp}.bin" + head_memray_bin = worktree_dirs / f"script-memray-head-{timestamp}.bin" + + result = ScriptCompareResult(base_ref=base_ref, head_ref=head_ref) + + from rich.console import Group + from rich.live import Live + from rich.panel import Panel + from rich.text import Text + + base_short = base_ref[:12] + head_short = head_ref[:12] + + step_labels = [ + "Creating worktrees", + f"Running benchmark on base ({base_short})", + f"Running benchmark on head ({head_short})", + ] + + def build_steps(current_step: int) -> Group: + lines: list[Text] = [] + for i, label in enumerate(step_labels): + if i < current_step: + lines.append(Text.from_markup(f"[green]\u2714[/green] {label}")) + elif i == current_step: + lines.append(Text.from_markup(f"[cyan]\u25cb[/cyan] {label}...")) + else: + lines.append(Text.from_markup(f"[dim]\u2500 {label}[/dim]")) + return Group(*lines) + + def build_panel(current_step: int) -> Panel: + return Panel( + Group( + Text.from_markup( + f"[bold cyan]{base_short}[/bold cyan] (base) vs [bold cyan]{head_short}[/bold cyan] (head)" + ), + "", + Text.from_markup(f"[dim]Script:[/dim] {script_cmd}"), + "", + build_steps(current_step), + ), + title="[bold]Script Benchmark Compare[/bold]", + border_style="cyan", + expand=True, + padding=(1, 2), + ) + + try: + step = 0 + with Live(build_panel(step), console=console, refresh_per_second=1) as live: + base_sha = repo.commit(base_ref).hexsha + head_sha = repo.commit(head_ref).hexsha + repo.git.worktree("add", str(base_worktree), base_sha) + repo.git.worktree("add", str(head_worktree), head_sha) + step += 1 + live.update(build_panel(step)) + + # Run script on base + result.base_results = _run_script_in_worktree( + script_cmd, base_worktree, script_output, timeout, base_memray_bin if memory else None + ) + step += 1 + live.update(build_panel(step)) + + # Run script on head + result.head_results = _run_script_in_worktree( + script_cmd, head_worktree, script_output, timeout, head_memray_bin if memory else None + ) + + # Parse memory results + if memory: + result.base_memory = _parse_memray_bin(base_memray_bin) + result.head_memory = _parse_memray_bin(head_memray_bin) + + render_script_comparison(result) + + except KeyboardInterrupt: + console.print("\n[yellow]Interrupted — cleaning up...[/yellow]") + + finally: + from codeflash.code_utils.git_worktree_utils import remove_worktree + + remove_worktree(base_worktree) + remove_worktree(head_worktree) + repo.git.worktree("prune") + for f in 
[base_memray_bin, head_memray_bin]: + if f.exists(): + f.unlink() + + return result + + +def _run_script_in_worktree( + script_cmd: str, worktree_dir: Path, script_output: str, timeout: int, memray_bin: Optional[Path] +) -> dict[str, float]: + import json + + cmd = script_cmd + if memray_bin: + cmd = f"python -m memray run --trace-python-allocators -o {memray_bin} -- {cmd}" + + try: + proc = subprocess.run( # noqa: S602 + cmd, shell=True, cwd=worktree_dir, timeout=timeout, capture_output=True, text=True, check=False + ) + if proc.returncode != 0: + logger.warning(f"Script exited with code {proc.returncode}") + if proc.stderr: + logger.debug(f"Script stderr:\n{proc.stderr[:2000]}") + except subprocess.TimeoutExpired: + logger.warning(f"Script timed out after {timeout}s") + return {} + + output_path = worktree_dir / script_output + if not output_path.exists(): + logger.warning(f"Script output not found at {output_path}") + return {} + + try: + data = json.loads(output_path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + logger.warning("Script output JSON is not a dict") + return {} + return {k: float(v) for k, v in data.items() if isinstance(v, (int, float))} + except (json.JSONDecodeError, ValueError) as e: + logger.warning(f"Failed to parse script output JSON: {e}") + return {} + + +def _parse_memray_bin(bin_path: Path) -> Optional[MemoryStats]: + if not bin_path.exists(): + return None + try: + from memray import FileReader + + from codeflash.benchmarking.plugin.plugin import MemoryStats + + reader = FileReader(str(bin_path)) + meta = reader.metadata + stats = MemoryStats(peak_memory_bytes=meta.peak_memory, total_allocations=meta.total_allocations) + reader.close() + return stats + except ImportError: + logger.warning("memray not installed — skipping memory results") + return None + except OSError as e: + logger.warning(f"Failed to read memray binary: {e}") + return None + + +def render_script_comparison(result: ScriptCompareResult) -> None: + has_timing = result.base_results or result.head_results + has_memory = result.base_memory or result.head_memory + if not has_timing and not has_memory: + logger.warning("No benchmark results to compare") + return + + base_short = result.base_ref[:12] + head_short = result.head_ref[:12] + + console.print() + console.rule(f"[bold]Script Benchmark: {base_short} vs {head_short}[/bold]") + console.print() + + if has_timing: + all_keys = sorted((set(result.base_results) | set(result.head_results)) - {"__total__"}) + has_total = "__total__" in result.base_results or "__total__" in result.head_results + + t = Table(title="Benchmark Results", border_style="blue", show_lines=True, expand=False) + t.add_column("Key", style="cyan") + t.add_column("Base", justify="right", style="yellow") + t.add_column("Head", justify="right", style="yellow") + t.add_column("Delta", justify="right") + t.add_column("Speedup", justify="right") + + for key in all_keys: + b = result.base_results.get(key) + h = result.head_results.get(key) + t.add_row(key, _fmt_seconds(b), _fmt_seconds(h), _fmt_delta_s(b, h), _speedup_s(b, h)) + + if has_total: + t.add_section() + b = result.base_results.get("__total__") + h = result.head_results.get("__total__") + t.add_row("[bold]TOTAL[/bold]", _fmt_seconds(b), _fmt_seconds(h), _fmt_delta_s(b, h), _speedup_s(b, h)) + + console.print(t, justify="center") + + if has_memory: + console.print() + t_mem = Table(title="Memory (aggregate)", border_style="magenta", show_lines=True, expand=False) + t_mem.add_column("Ref", style="bold cyan") + 
t_mem.add_column("Peak Memory", justify="right") + t_mem.add_column("Allocations", justify="right") + t_mem.add_column("Delta", justify="right") + + if result.base_memory: + t_mem.add_row( + f"{base_short} (base)", + fmt_bytes(result.base_memory.peak_memory_bytes), + f"{result.base_memory.total_allocations:,}", + "", + ) + if result.head_memory: + delta = fmt_memory_delta( + result.base_memory.peak_memory_bytes if result.base_memory else None, + result.head_memory.peak_memory_bytes, + ) + t_mem.add_row( + f"{head_short} (head)", + fmt_bytes(result.head_memory.peak_memory_bytes), + f"{result.head_memory.total_allocations:,}", + delta, + ) + console.print(t_mem, justify="center") + + console.print() diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index beb78d1b5..cf5ca7bdd 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -395,6 +395,13 @@ def _build_parser() -> ArgumentParser: compare_parser.add_argument( "--memory", action="store_true", help="Profile peak memory usage per benchmark (requires memray, Linux/macOS)" ) + compare_parser.add_argument("--script", type=str, help="Shell command to run as benchmark in each worktree") + compare_parser.add_argument( + "--script-output", + type=str, + dest="script_output", + help="Relative path to JSON results file produced by --script (required with --script)", + ) compare_parser.add_argument("--config-file", type=str, dest="config_file", help="Path to pyproject.toml") trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.") diff --git a/codeflash/cli_cmds/cmd_compare.py b/codeflash/cli_cmds/cmd_compare.py index 5b4f98378..898af5679 100644 --- a/codeflash/cli_cmds/cmd_compare.py +++ b/codeflash/cli_cmds/cmd_compare.py @@ -13,31 +13,10 @@ if TYPE_CHECKING: from codeflash.models.function_types import FunctionToOptimize from codeflash.cli_cmds.console import logger -from codeflash.code_utils.config_parser import parse_config_file def run_compare(args: Namespace) -> None: """Entry point for the compare subcommand.""" - # Load project config - pyproject_config, pyproject_file_path = parse_config_file(args.config_file) - - module_root = Path(pyproject_config.get("module_root", ".")).resolve() - tests_root = Path(pyproject_config.get("tests_root", "tests")).resolve() - benchmarks_root_str = pyproject_config.get("benchmarks_root") - - if not benchmarks_root_str: - logger.error("benchmarks-root must be configured in [tool.codeflash] to use compare") - sys.exit(1) - - benchmarks_root = Path(benchmarks_root_str).resolve() - if not benchmarks_root.is_dir(): - logger.error(f"benchmarks-root {benchmarks_root} is not a valid directory") - sys.exit(1) - - from codeflash.cli_cmds.cli import project_root_from_module_root - - project_root = project_root_from_module_root(module_root, pyproject_file_path) - # Resolve head_ref: explicit arg > --pr > current branch head_ref = args.head_ref if args.pr: @@ -58,6 +37,61 @@ def run_compare(args: Namespace) -> None: sys.exit(1) logger.info(f"Auto-detected base ref: {base_ref}") + # Script mode: run an arbitrary benchmark command on each worktree (no codeflash config needed) + script_cmd = getattr(args, "script", None) + if script_cmd: + script_output = getattr(args, "script_output", None) + if not script_output: + logger.error("--script-output is required when using --script") + sys.exit(1) + + import git + + project_root = Path(git.Repo(Path.cwd(), search_parent_directories=True).working_dir) + + from codeflash.benchmarking.compare import 
compare_with_script
+
+        result = compare_with_script(
+            base_ref=base_ref,
+            head_ref=head_ref,
+            project_root=project_root,
+            script_cmd=script_cmd,
+            script_output=script_output,
+            timeout=args.timeout,
+            memory=getattr(args, "memory", False),
+        )
+
+        if not result.base_results and not result.head_results:
+            logger.warning("No benchmark data collected. Check that --script-output points to a valid JSON file.")
+            sys.exit(1)
+
+        if args.output:
+            md = result.format_markdown()
+            Path(args.output).write_text(md, encoding="utf-8")
+            logger.info(f"Markdown report written to {args.output}")
+        return
+
+    # Standard trace-benchmark mode: requires codeflash config
+    from codeflash.code_utils.config_parser import parse_config_file
+
+    pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
+    module_root = Path(pyproject_config.get("module_root", ".")).resolve()
+
+    from codeflash.cli_cmds.cli import project_root_from_module_root
+
+    project_root = project_root_from_module_root(module_root, pyproject_file_path)
+    tests_root = Path(pyproject_config.get("tests_root", "tests")).resolve()
+    benchmarks_root_str = pyproject_config.get("benchmarks_root")
+
+    if not benchmarks_root_str:
+        logger.error("benchmarks-root must be configured in [tool.codeflash] to use compare")
+        sys.exit(1)
+
+    benchmarks_root = Path(benchmarks_root_str).resolve()
+    if not benchmarks_root.is_dir():
+        logger.error(f"benchmarks-root {benchmarks_root} is not a valid directory")
+        sys.exit(1)
+
     # Parse explicit functions if provided
     functions = None
     if args.functions:
diff --git a/tests/test_compare.py b/tests/test_compare.py
index 8bd1ad400..c51b959d9 100644
--- a/tests/test_compare.py
+++ b/tests/test_compare.py
@@ -1,6 +1,12 @@
 from __future__ import annotations
 
-from codeflash.benchmarking.compare import CompareResult, has_meaningful_memory_change, render_comparison
+from codeflash.benchmarking.compare import (
+    CompareResult,
+    ScriptCompareResult,
+    has_meaningful_memory_change,
+    render_comparison,
+    render_script_comparison,
+)
 from codeflash.benchmarking.plugin.plugin import BenchmarkStats, MemoryStats
 from codeflash.models.models import BenchmarkKey
 
@@ -101,14 +107,8 @@ class TestFormatMarkdownMemoryOnly:
             head_ref="def456",
             base_stats={timing_key: _make_stats()},
             head_stats={timing_key: _make_stats(median_ns=500.0)},
-            base_memory={
-                timing_key: _make_memory(peak=10_000_000),
-                memory_key: _make_memory(peak=8_000_000),
-            },
-            head_memory={
-                timing_key: _make_memory(peak=5_000_000),
-                memory_key: _make_memory(peak=6_000_000),
-            },
+            base_memory={timing_key: _make_memory(peak=10_000_000), memory_key: _make_memory(peak=8_000_000)},
+            head_memory={timing_key: _make_memory(peak=5_000_000), memory_key: _make_memory(peak=6_000_000)},
         )
         md = result.format_markdown()
 
@@ -161,3 +161,80 @@ class TestHasMeaningfulMemoryChange:
         base = _make_memory(peak=10_000_000, allocs=1000)
         head = _make_memory(peak=10_000_000, allocs=800)
         assert has_meaningful_memory_change(base, head)
+
+
+class TestScriptCompareResult:
+    def test_format_markdown_basic(self) -> None:
+        result = ScriptCompareResult(
+            base_ref="abc123",
+            head_ref="def456",
+            base_results={"file1.pdf": 12.34, "file2.docx": 1.23},
+            head_results={"file1.pdf": 10.21, "file2.docx": 1.45},
+        )
+        md = result.format_markdown()
+        assert "file1.pdf" in md
+        assert "file2.docx" in md
+        assert "Base" in md
+        assert "Head" in md
+
+    def test_format_markdown_empty(self) -> None:
+        result = ScriptCompareResult(base_ref="abc123", head_ref="def456")
+        md = result.format_markdown()
+        assert md == "_No benchmark results to compare._"
+
+    def test_format_markdown_total_row(self) -> None:
+        result = ScriptCompareResult(
+            base_ref="abc123",
+            head_ref="def456",
+            base_results={"test1": 1.0, "__total__": 5.0},
+            head_results={"test1": 0.8, "__total__": 4.0},
+        )
+        md = result.format_markdown()
+        assert "**TOTAL**" in md
+        # __total__ should not appear as a regular key row
+        assert md.count("__total__") == 0
+
+    def test_format_markdown_missing_keys(self) -> None:
+        result = ScriptCompareResult(
+            base_ref="abc123", head_ref="def456", base_results={"only_base": 2.0}, head_results={"only_head": 3.0}
+        )
+        md = result.format_markdown()
+        assert "only_base" in md
+        assert "only_head" in md
+
+    def test_format_markdown_with_memory(self) -> None:
+        result = ScriptCompareResult(
+            base_ref="abc123",
+            head_ref="def456",
+            base_results={"test1": 1.0},
+            head_results={"test1": 0.5},
+            base_memory=_make_memory(peak=10_000_000, allocs=500),
+            head_memory=_make_memory(peak=7_000_000, allocs=400),
+        )
+        md = result.format_markdown()
+        assert "Peak Memory" in md
+        assert "Allocations" in md
+
+    def test_render_no_crash(self) -> None:
+        result = ScriptCompareResult(
+            base_ref="abc123",
+            head_ref="def456",
+            base_results={"a": 1.0, "b": 2.0, "__total__": 3.0},
+            head_results={"a": 0.5, "b": 1.5, "__total__": 2.0},
+        )
+        render_script_comparison(result)
+
+    def test_render_empty_no_crash(self) -> None:
+        result = ScriptCompareResult(base_ref="abc123", head_ref="def456")
+        render_script_comparison(result)
+
+    def test_render_with_memory_no_crash(self) -> None:
+        result = ScriptCompareResult(
+            base_ref="abc123",
+            head_ref="def456",
+            base_results={"test1": 5.0},
+            head_results={"test1": 4.0},
+            base_memory=_make_memory(peak=10_000_000, allocs=1000),
+            head_memory=_make_memory(peak=8_000_000, allocs=900),
+        )
+        render_script_comparison(result)
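---

Reviewer notes (not part of the patch):

For anyone wanting to try script mode locally, here is a minimal sketch of a benchmark script that satisfies the JSON contract `compare_with_script` documents: a flat object mapping keys to seconds, with an optional `__total__` aggregate that becomes the bold TOTAL row. The file name `bench.py`, the workload names, and the `bench_results.json` path are illustrative, not part of this diff.

```python
# bench.py -- illustrative benchmark script, assumed names throughout.
# Writes the JSON shape compare_with_script reads: {key: seconds, "__total__": seconds}.
import json
import time


def workload_a() -> None:
    time.sleep(0.2)  # stand-in for a real workload


def workload_b() -> None:
    time.sleep(0.1)


def main() -> None:
    results: dict[str, float] = {}
    for key, fn in {"workload_a": workload_a, "workload_b": workload_b}.items():
        start = time.perf_counter()
        fn()
        results[key] = time.perf_counter() - start
    # "__total__" is optional; format_markdown and render_script_comparison
    # exclude it from the per-key rows and render it as the TOTAL row.
    results["__total__"] = sum(results.values())
    # Written relative to the worktree root, matching --script-output.
    with open("bench_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f)


if __name__ == "__main__":
    main()
```

Assuming the subcommand is invoked as `codeflash compare`, this would be driven by something like `codeflash compare --script "python bench.py" --script-output bench_results.json`. Note that `_run_script_in_worktree` runs the command with each worktree as its CWD, so the script and its output path must exist at both refs, and non-numeric JSON values are silently dropped by the `isinstance(v, (int, float))` filter.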
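And a small sketch of driving the new result type directly, for example from a CI wrapper that posts the report as a PR comment. The ref strings and timings below are made up; the fields and functions are exactly the ones this diff adds.

```python
from codeflash.benchmarking.compare import ScriptCompareResult, render_script_comparison

result = ScriptCompareResult(
    base_ref="abc123def4567890",  # refs are truncated to 12 chars in output
    head_ref="def456abc1230000",
    base_results={"workload_a": 12.34, "__total__": 12.34},
    head_results={"workload_a": 10.21, "__total__": 10.21},
)

render_script_comparison(result)  # Rich tables on the console
markdown = result.format_markdown()  # same data as a GitHub-flavored table
```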