chore: merge main into fixes-for-core-unstructured-experimental

This commit is contained in:
Kevin Turcios 2026-02-21 00:57:33 -05:00
commit c6fbdfa535
137 changed files with 4438 additions and 4106 deletions

4
.codex/config.toml Normal file
View file

@ -0,0 +1,4 @@
[mcp_servers.tessl]
type = "stdio"
command = "tessl"
args = [ "mcp", "start" ]

View file

@ -1,2 +0,0 @@
# Managed by Tessl
tessl:*

12
.gemini/settings.json Normal file
View file

@ -0,0 +1,12 @@
{
"mcpServers": {
"tessl": {
"type": "stdio",
"command": "tessl",
"args": [
"mcp",
"start"
]
}
}
}

View file

@ -1,2 +0,0 @@
# Managed by Tessl
tessl:*

View file

@ -6,20 +6,48 @@ on:
- main
paths:
- 'codeflash/version.py'
- 'codeflash-benchmark/codeflash_benchmark/version.py'
jobs:
publish:
detect-changes:
runs-on: ubuntu-latest
outputs:
codeflash: ${{ steps.filter.outputs.codeflash }}
benchmark: ${{ steps.filter.outputs.benchmark }}
steps:
- name: Checkout
uses: actions/checkout@v5
with:
fetch-depth: 2
- name: Detect which packages changed
id: filter
run: |
if git diff --name-only HEAD~1 HEAD | grep -q '^codeflash/version.py$'; then
echo "codeflash=true" >> $GITHUB_OUTPUT
else
echo "codeflash=false" >> $GITHUB_OUTPUT
fi
if git diff --name-only HEAD~1 HEAD | grep -q '^codeflash-benchmark/codeflash_benchmark/version.py$'; then
echo "benchmark=true" >> $GITHUB_OUTPUT
else
echo "benchmark=false" >> $GITHUB_OUTPUT
fi
publish-codeflash:
needs: detect-changes
if: needs.detect-changes.outputs.codeflash == 'true'
runs-on: ubuntu-latest
environment:
name: pypi
permissions:
id-token: write
contents: write # Changed from 'read' to 'write' to allow tag creation
contents: write
steps:
- name: Checkout
uses: actions/checkout@v5
with:
fetch-depth: 0 # Fetch all history for proper versioning
fetch-depth: 0
- name: Extract version from version.py
id: extract_version
@ -76,4 +104,76 @@ jobs:
prerelease: false
generate_release_notes: true
files: |
dist/*
dist/*
publish-benchmark:
needs: detect-changes
if: needs.detect-changes.outputs.benchmark == 'true'
runs-on: ubuntu-latest
environment:
name: pypi
permissions:
id-token: write
contents: write
steps:
- name: Checkout
uses: actions/checkout@v5
with:
fetch-depth: 0
- name: Extract version from version.py
id: extract_version
run: |
VERSION=$(grep -oP '__version__ = "\K[^"]+' codeflash-benchmark/codeflash_benchmark/version.py)
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "tag=benchmark-v$VERSION" >> $GITHUB_OUTPUT
echo "Extracted version: $VERSION"
- name: Check if tag already exists
id: check_tag
run: |
if git rev-parse "${{ steps.extract_version.outputs.tag }}" >/dev/null 2>&1; then
echo "exists=true" >> $GITHUB_OUTPUT
echo "Tag ${{ steps.extract_version.outputs.tag }} already exists, skipping release"
else
echo "exists=false" >> $GITHUB_OUTPUT
echo "Tag ${{ steps.extract_version.outputs.tag }} does not exist, proceeding with release"
fi
- name: Create and push git tag
if: steps.check_tag.outputs.exists == 'false'
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
git tag -a "${{ steps.extract_version.outputs.tag }}" -m "Release ${{ steps.extract_version.outputs.tag }}"
git push origin "${{ steps.extract_version.outputs.tag }}"
- name: Install uv
if: steps.check_tag.outputs.exists == 'false'
uses: astral-sh/setup-uv@v6
- name: Build
if: steps.check_tag.outputs.exists == 'false'
run: uv build --package codeflash-benchmark
- name: Publish to PyPI
if: steps.check_tag.outputs.exists == 'false'
run: uv publish
- name: Create GitHub Release
if: steps.check_tag.outputs.exists == 'false'
uses: softprops/action-gh-release@v2
with:
tag_name: ${{ steps.extract_version.outputs.tag }}
name: codeflash-benchmark ${{ steps.extract_version.outputs.tag }}
body: |
## What's Changed
Release ${{ steps.extract_version.outputs.version }} of codeflash-benchmark.
**Full Changelog**: https://github.com/${{ github.repository }}/commits/${{ steps.extract_version.outputs.tag }}
draft: false
prerelease: false
generate_release_notes: true
files: |
dist/*

View file

@ -15,9 +15,25 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]
include:
- os: ubuntu-latest
python-version: "3.9"
- os: ubuntu-latest
python-version: "3.10"
- os: ubuntu-latest
python-version: "3.11"
- os: ubuntu-latest
python-version: "3.12"
- os: ubuntu-latest
python-version: "3.13"
- os: ubuntu-latest
python-version: "3.14"
- os: windows-latest
python-version: "3.13"
continue-on-error: true
runs-on: ubuntu-latest
runs-on: ${{ matrix.os }}
env:
PYTHONIOENCODING: utf-8
steps:
- uses: actions/checkout@v4
with:

View file

@ -1,34 +0,0 @@
name: windows-unit-tests
on:
push:
branches: [main]
pull_request:
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref_name }}
cancel-in-progress: true
jobs:
windows-unit-tests:
continue-on-error: true
runs-on: windows-latest
env:
PYTHONIOENCODING: utf-8
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
python-version: "3.13"
- name: install dependencies
run: uv sync
- name: Unit tests
run: uv run pytest tests/

View file

@ -28,6 +28,10 @@ Discovery → Ranking → Context Extraction → Test Gen + Optimization → Bas
- **Tracer**: Profiling system that records function call trees and timings (`tracing/`, `tracer.py`)
- **Worktree mode**: Git worktree-based parallel optimization (`--worktree` flag)
## PR Reviews
- GitHub PR comments and review feedback can be stale — they may reference issues already fixed by a later commit. Before acting on review feedback, verify it still applies to the current code. If the issue no longer exists, resolve the conversation in the GitHub UI.
<!-- Section below is auto-generated by `tessl install` - do not edit manually -->
# Agent Rules <!-- tessl-managed -->

File diff suppressed because it is too large Load diff

View file

@ -18,6 +18,6 @@
"@types/node": "^20.0.0",
"codeflash": "file:../../../packages/codeflash",
"typescript": "^5.0.0",
"vitest": "^2.0.0"
"vitest": "^4.0.18"
}
}

View file

@ -1,3 +1,3 @@
"""CodeFlash Benchmark - Pytest benchmarking plugin for codeflash.ai."""
__version__ = "0.1.0"
from codeflash_benchmark.version import __version__ as __version__

View file

@ -0,0 +1,2 @@
# These version placeholders will be replaced by uv-dynamic-versioning during build.
__version__ = "0.3.0"

View file

@ -1,6 +1,6 @@
[project]
name = "codeflash-benchmark"
version = "0.2.0"
dynamic = ["version"]
description = "Pytest benchmarking plugin for codeflash.ai - automatic code performance optimization"
authors = [{ name = "CodeFlash Inc.", email = "contact@codeflash.ai" }]
requires-python = ">=3.9"
@ -25,8 +25,19 @@ Repository = "https://github.com/codeflash-ai/codeflash-benchmark"
codeflash-benchmark = "codeflash_benchmark.plugin"
[build-system]
requires = ["setuptools>=45", "wheel"]
build-backend = "setuptools.build_meta"
requires = ["hatchling", "uv-dynamic-versioning"]
build-backend = "hatchling.build"
[tool.setuptools]
packages = ["codeflash_benchmark"]
[tool.hatch.version]
source = "uv-dynamic-versioning"
[tool.uv-dynamic-versioning]
enable = true
style = "pep440"
vcs = "git"
[tool.hatch.build.hooks.version]
path = "codeflash_benchmark/version.py"
template = """# These version placeholders will be replaced by uv-dynamic-versioning during build.
__version__ = "{version}"
"""

View file

@ -130,9 +130,18 @@ def parse_args() -> Namespace:
"--reset-config", action="store_true", help="Remove codeflash configuration from project config file."
)
parser.add_argument("-y", "--yes", action="store_true", help="Skip confirmation prompts (useful for CI/scripts).")
parser.add_argument(
"--subagent",
action="store_true",
help="Subagent mode: skip all interactive prompts with sensible defaults. Designed for AI agent integrations.",
)
args, unknown_args = parser.parse_known_args()
sys.argv[:] = [sys.argv[0], *unknown_args]
if args.subagent:
args.yes = True
args.no_pr = True
args.worktree = True
return process_and_validate_cmd_args(args)
@ -352,32 +361,52 @@ def _handle_show_config() -> None:
from codeflash.setup.detector import detect_project, has_existing_config
project_root = Path.cwd()
detected = detect_project(project_root)
config_exists, _ = has_existing_config(project_root)
# Check if config exists or is auto-detected
config_exists, config_file = has_existing_config(project_root)
status = "Saved config" if config_exists else "Auto-detected (not saved)"
if config_exists:
from codeflash.code_utils.config_parser import parse_config_file
console.print()
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
if config_exists and config_file:
console.print(f"[dim]Config file: {project_root / config_file}[/dim]")
console.print()
config, config_file_path = parse_config_file()
status = "Saved config"
table = Table(show_header=True, header_style="bold cyan")
table.add_column("Setting", style="dim")
table.add_column("Value")
console.print()
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
console.print(f"[dim]Config file: {config_file_path}[/dim]")
console.print()
table.add_row("Language", detected.language)
table.add_row("Project root", str(detected.project_root))
table.add_row("Module root", str(detected.module_root))
table.add_row("Tests root", str(detected.tests_root) if detected.tests_root else "(not detected)")
table.add_row("Test runner", detected.test_runner or "(not detected)")
table.add_row("Formatter", ", ".join(detected.formatter_cmds) if detected.formatter_cmds else "(not detected)")
table.add_row(
"Ignore paths", ", ".join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else "(none)"
)
table.add_row("Confidence", f"{detected.confidence:.0%}")
table = Table(show_header=True, header_style="bold cyan")
table.add_column("Setting", style="dim")
table.add_column("Value")
table.add_row("Project root", str(project_root))
table.add_row("Module root", config.get("module_root", "(not set)"))
table.add_row("Tests root", config.get("tests_root", "(not set)"))
table.add_row("Test runner", config.get("test_framework", config.get("pytest_cmd", "(not set)")))
table.add_row("Formatter", ", ".join(config["formatter_cmds"]) if config.get("formatter_cmds") else "(not set)")
ignore_paths = config.get("ignore_paths", [])
table.add_row("Ignore paths", ", ".join(str(p) for p in ignore_paths) if ignore_paths else "(none)")
else:
detected = detect_project(project_root)
status = "Auto-detected (not saved)"
console.print()
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
console.print()
table = Table(show_header=True, header_style="bold cyan")
table.add_column("Setting", style="dim")
table.add_column("Value")
table.add_row("Language", detected.language)
table.add_row("Project root", str(detected.project_root))
table.add_row("Module root", str(detected.module_root))
table.add_row("Tests root", str(detected.tests_root) if detected.tests_root else "(not detected)")
table.add_row("Test runner", detected.test_runner or "(not detected)")
table.add_row("Formatter", ", ".join(detected.formatter_cmds) if detected.formatter_cmds else "(not detected)")
table.add_row(
"Ignore paths", ", ".join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else "(none)"
)
table.add_row("Confidence", f"{detected.confidence:.0%}")
console.print(table)
console.print()

View file

@ -1,12 +1,14 @@
from __future__ import annotations
import logging
from collections import deque
from contextlib import contextmanager
from itertools import cycle
from typing import TYPE_CHECKING, Optional
from rich.console import Console
from rich.logging import RichHandler
from rich.panel import Panel
from rich.progress import (
BarColumn,
MofNCompleteColumn,
@ -19,43 +21,73 @@ from rich.progress import (
from codeflash.cli_cmds.console_constants import SPINNER_TYPES
from codeflash.cli_cmds.logging_config import BARE_LOGGING_FORMAT
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.lsp.helpers import is_LSP_enabled, is_subagent_mode
from codeflash.lsp.lsp_logger import enhanced_log
from codeflash.lsp.lsp_message import LspCodeMessage, LspTextMessage
if TYPE_CHECKING:
from collections.abc import Generator
from collections.abc import Callable, Generator
from pathlib import Path
from rich.progress import TaskID
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import DependencyResolver, IndexResult
from codeflash.lsp.lsp_message import LspMessage
from codeflash.models.models import TestResults
DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG
console = Console()
if is_LSP_enabled():
if is_LSP_enabled() or is_subagent_mode():
console.quiet = True
logging.basicConfig(
level=logging.INFO,
handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)],
format=BARE_LOGGING_FORMAT,
)
if is_subagent_mode():
import re
import sys
_lsp_prefix_re = re.compile(r"^(?:!?lsp,?|h[2-4]|loading)\|")
_subagent_drop_patterns = (
"Test log -",
"Test failed to load",
"Examining file ",
"Generated ",
"Add custom marker",
"Disabling all autouse",
"Reverting code and helpers",
)
class _AgentLogFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
record.msg = _lsp_prefix_re.sub("", str(record.msg))
msg = record.getMessage()
return not any(msg.startswith(p) for p in _subagent_drop_patterns)
_agent_handler = logging.StreamHandler(sys.stderr)
_agent_handler.addFilter(_AgentLogFilter())
logging.basicConfig(level=logging.INFO, handlers=[_agent_handler], format="%(levelname)s: %(message)s")
else:
logging.basicConfig(
level=logging.INFO,
handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)],
format=BARE_LOGGING_FORMAT,
)
logger = logging.getLogger("rich")
logging.getLogger("parso").setLevel(logging.WARNING)
# override the logger to reformat the messages for the lsp
for level in ("info", "debug", "warning", "error"):
real_fn = getattr(logger, level)
setattr(
logger,
level,
lambda msg, *args, _real_fn=real_fn, _level=level, **kwargs: enhanced_log(
msg, _real_fn, _level, *args, **kwargs
),
)
if not is_subagent_mode():
for level in ("info", "debug", "warning", "error"):
real_fn = getattr(logger, level)
setattr(
logger,
level,
lambda msg, *args, _real_fn=real_fn, _level=level, **kwargs: enhanced_log(
msg, _real_fn, _level, *args, **kwargs
),
)
class DummyTask:
@ -82,6 +114,8 @@ def paneled_text(
text: str, panel_args: dict[str, str | bool] | None = None, text_args: dict[str, str] | None = None
) -> None:
"""Print text in a panel."""
if is_subagent_mode():
return
from rich.panel import Panel
from rich.text import Text
@ -110,6 +144,8 @@ def code_print(
language: Programming language for syntax highlighting ('python', 'javascript', 'typescript')
"""
if is_subagent_mode():
return
if is_LSP_enabled():
lsp_log(
LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name, message_id=lsp_message_id)
@ -147,6 +183,10 @@ def progress_bar(
"""
global _progress_bar_active
if is_subagent_mode():
yield DummyTask().id
return
if is_LSP_enabled():
lsp_log(LspTextMessage(text=message, takes_time=True))
yield
@ -178,6 +218,10 @@ def progress_bar(
@contextmanager
def test_files_progress_bar(total: int, description: str) -> Generator[tuple[Progress, TaskID], None, None]:
"""Progress bar for test files."""
if is_subagent_mode():
yield DummyProgress(), DummyTask().id
return
if is_LSP_enabled():
lsp_log(LspTextMessage(text=description, takes_time=True))
dummy_progress = DummyProgress()
@ -196,3 +240,242 @@ def test_files_progress_bar(total: int, description: str) -> Generator[tuple[Pro
) as progress:
task_id = progress.add_task(description, total=total)
yield progress, task_id
MAX_TREE_ENTRIES = 8
@contextmanager
def call_graph_live_display(
total: int, project_root: Path | None = None
) -> Generator[Callable[[IndexResult], None], None, None]:
from rich.console import Group
from rich.live import Live
from rich.panel import Panel
from rich.text import Text
from rich.tree import Tree
if is_subagent_mode():
yield lambda _: None
return
if is_LSP_enabled():
lsp_log(LspTextMessage(text="Building call graph", takes_time=True))
yield lambda _: None
return
progress = Progress(
SpinnerColumn(next(spinners)),
TextColumn("[progress.description]{task.description}"),
BarColumn(complete_style="cyan", finished_style="green", pulse_style="yellow"),
MofNCompleteColumn(),
TimeElapsedColumn(),
TimeRemainingColumn(),
auto_refresh=False,
)
task_id = progress.add_task("Analyzing files", total=total)
results: deque[IndexResult] = deque(maxlen=MAX_TREE_ENTRIES)
stats = {"indexed": 0, "cached": 0, "edges": 0, "external": 0, "errors": 0}
tree = Tree("[bold]Recent Files[/bold]")
stats_text = Text("0 calls found", style="dim")
panel = Panel(
Group(progress, Text(""), tree, Text(""), stats_text), title="Building Call Graph", border_style="cyan"
)
def create_tree_node(result: IndexResult) -> Tree:
if project_root:
try:
name = str(result.file_path.resolve().relative_to(project_root.resolve()))
except ValueError:
name = f"{result.file_path.parent.name}/{result.file_path.name}"
else:
name = f"{result.file_path.parent.name}/{result.file_path.name}"
if result.error:
return Tree(f"[red]{name} (error)[/red]")
if result.cached:
return Tree(f"[dim]{name} (cached)[/dim]")
local_edges = result.num_edges - result.cross_file_edges
edge_info = []
if local_edges:
edge_info.append(f"{local_edges} calls in same file")
if result.cross_file_edges:
edge_info.append(f"{result.cross_file_edges} calls from other modules")
label = ", ".join(edge_info) if edge_info else "no calls"
return Tree(f"[cyan]{name}[/cyan] [dim]{label}[/dim]")
def refresh_display() -> None:
tree.children = [create_tree_node(r) for r in results]
tree.children.extend([Tree(" ")] * (MAX_TREE_ENTRIES - len(results)))
# Update stats
stat_parts = []
if stats["indexed"]:
stat_parts.append(f"{stats['indexed']} files analyzed")
if stats["cached"]:
stat_parts.append(f"{stats['cached']} cached")
if stats["errors"]:
stat_parts.append(f"{stats['errors']} errors")
stat_parts.append(f"{stats['edges']} calls found")
if stats["external"]:
stat_parts.append(f"{stats['external']} cross-file calls")
stats_text.truncate(0)
stats_text.append(" · ".join(stat_parts), style="dim")
batch: list[IndexResult] = []
def process_batch() -> None:
for result in batch:
results.append(result)
if result.error:
stats["errors"] += 1
elif result.cached:
stats["cached"] += 1
else:
stats["indexed"] += 1
stats["edges"] += result.num_edges
stats["external"] += result.cross_file_edges
progress.advance(task_id)
batch.clear()
refresh_display()
live.refresh()
def update(result: IndexResult) -> None:
batch.append(result)
if len(batch) >= 8:
process_batch()
with Live(panel, console=console, transient=False, auto_refresh=False) as live:
yield update
if batch:
process_batch()
def call_graph_summary(call_graph: DependencyResolver, file_to_funcs: dict[Path, list[FunctionToOptimize]]) -> None:
total_functions = sum(map(len, file_to_funcs.values()))
if not total_functions:
return
if is_subagent_mode():
return
# Build the mapping expected by the dependency resolver
file_items = file_to_funcs.items()
mapping = {file_path: {func.qualified_name for func in funcs} for file_path, funcs in file_items}
callee_counts = call_graph.count_callees_per_function(mapping)
# Use built-in sum for C-level loops to reduce Python overhead
total_callees = sum(callee_counts.values())
with_context = sum(1 for count in callee_counts.values() if count > 0)
leaf_functions = total_functions - with_context
avg_callees = total_callees / total_functions
function_label = "function" if total_functions == 1 else "functions"
summary = (
f"{total_functions} {function_label} ready for optimization\n"
f"Uses other functions: {with_context} · "
f"Standalone: {leaf_functions}"
)
if is_LSP_enabled():
lsp_log(LspTextMessage(text=summary))
return
console.print(Panel(summary, title="Call Graph Summary", border_style="cyan"))
def subagent_log_optimization_result(
function_name: str,
file_path: Path,
perf_improvement_line: str,
original_runtime_ns: int,
best_runtime_ns: int,
raw_explanation: str,
original_code: dict[Path, str],
new_code: dict[Path, str],
review: str,
test_results: TestResults,
) -> None:
import sys
from xml.sax.saxutils import escape
from codeflash.code_utils.code_utils import unified_diff_strings
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.models.test_type import TestType
diff_parts = []
for path in original_code:
old = original_code.get(path, "")
new = new_code.get(path, "")
if old != new:
diff = unified_diff_strings(old, new, fromfile=str(path), tofile=str(path))
if diff:
diff_parts.append(diff)
diff_str = "\n".join(diff_parts)
original_runtime = humanize_runtime(original_runtime_ns)
optimized_runtime = humanize_runtime(best_runtime_ns)
report = test_results.get_test_pass_fail_report_by_type()
verification_rows = []
for test_type in TestType:
if test_type is TestType.INIT_STATE_TEST:
continue
name = test_type.to_name()
if not name:
continue
passed = report[test_type]["passed"]
failed = report[test_type]["failed"]
if passed == 0 and failed == 0:
status = "None Found"
elif failed > 0:
status = f"{failed} Failed, {passed} Passed"
else:
status = f"{passed} Passed"
verification_rows.append(f' <test type="{escape(name)}" status="{escape(status)}"/>')
xml = [
"<codeflash-optimization>",
f" <function>{escape(function_name)}</function>",
f" <file>{escape(str(file_path))}</file>",
f" <performance>{escape(perf_improvement_line)}</performance>",
f" <original-runtime>{escape(original_runtime)}</original-runtime>",
f" <optimized-runtime>{escape(optimized_runtime)}</optimized-runtime>",
]
if review:
xml.append(f" <reviewer-confidence>{escape(review)}</reviewer-confidence>")
xml.append(f" <explanation>{escape(raw_explanation.strip())}</explanation>")
xml.append(" <verification>")
xml.extend(verification_rows)
xml.append(" </verification>")
if diff_str:
xml.append(f" <diff>{escape(diff_str)}</diff>")
for path in new_code:
if new_code[path] != original_code.get(path, ""):
xml.append(f' <optimized-code file="{escape(str(path))}">{escape(new_code[path])}</optimized-code>')
xml.append(" <action>")
xml.append(" Present this optimization to the user using AskUserQuestion.")
xml.append(" Show the function name, performance improvement, and explanation as the question text.")
xml.append(" Provide two options: 'Apply' (write the optimized code to the file) and 'Reject' (do nothing).")
xml.append(
" Put the full diff in the 'Apply' option's markdown preview so the user can review the exact changes."
)
xml.append(" If the user chooses 'Apply', write the content from optimized-code to the corresponding file.")
xml.append(" </action>")
xml.append("</codeflash-optimization>")
sys.stdout.write("\n".join(xml) + "\n")

View file

@ -128,7 +128,7 @@ def determine_js_package_manager(project_root: Path) -> JsPackageManager:
"""
# Search from project_root up to filesystem root for lock files
# This supports monorepo setups where lock file is at workspace root
current_dir = project_root.resolve()
current_dir = project_root
while current_dir != current_dir.parent:
if (current_dir / "bun.lockb").exists() or (current_dir / "bun.lock").exists():
return JsPackageManager.BUN
@ -161,7 +161,7 @@ def find_node_modules_with_package(project_root: Path, package_name: str) -> Pat
Path to the node_modules directory containing the package, or None if not found.
"""
current_dir = project_root.resolve()
current_dir = project_root
while current_dir != current_dir.parent:
node_modules = current_dir / "node_modules"
if node_modules.exists():

View file

@ -5,8 +5,18 @@ BARE_LOGGING_FORMAT = "%(message)s"
def set_level(level: int, *, echo_setting: bool = True) -> None:
import logging
import sys
import time
from codeflash.lsp.helpers import is_subagent_mode
if is_subagent_mode():
logging.basicConfig(
level=level, handlers=[logging.StreamHandler(sys.stderr)], format="%(levelname)s: %(message)s", force=True
)
logging.getLogger().setLevel(level)
return
from rich.logging import RichHandler
from codeflash.cli_cmds.console import console

View file

@ -141,12 +141,18 @@ def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dic
def ask_should_use_checkpoint_get_functions(args: argparse.Namespace) -> Optional[dict[str, dict[str, str]]]:
previous_checkpoint_functions = None
if getattr(args, "subagent", False):
console.rule()
return None
if args.all and codeflash_temp_dir.is_dir():
previous_checkpoint_functions = get_all_historical_functions(args.module_root, codeflash_temp_dir)
if previous_checkpoint_functions and Confirm.ask(
"Previous Checkpoint detected from an incomplete optimization run, shall I continue the optimization from that point?",
default=True,
console=console,
if previous_checkpoint_functions and (
getattr(args, "yes", False)
or Confirm.ask(
"Previous Checkpoint detected from an incomplete optimization run, shall I continue the optimization from that point?",
default=True,
console=console,
)
):
console.rule()
else:

View file

@ -2,46 +2,16 @@ import os
import sys
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING
from platformdirs import user_config_dir
if TYPE_CHECKING:
codeflash_temp_dir: Path
codeflash_cache_dir: Path
codeflash_cache_db: Path
LF: str = os.linesep
IS_POSIX: bool = os.name != "nt"
SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix()
codeflash_cache_dir: Path = Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True))
class Compat:
# os-independent newline
LF: str = os.linesep
codeflash_temp_dir: Path = Path(tempfile.gettempdir()) / "codeflash"
codeflash_temp_dir.mkdir(parents=True, exist_ok=True)
SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix()
IS_POSIX: bool = os.name != "nt"
@property
def codeflash_cache_dir(self) -> Path:
return Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True))
@property
def codeflash_temp_dir(self) -> Path:
temp_dir = Path(tempfile.gettempdir()) / "codeflash"
if not temp_dir.exists():
temp_dir.mkdir(parents=True, exist_ok=True)
return temp_dir
@property
def codeflash_cache_db(self) -> Path:
return self.codeflash_cache_dir / "codeflash_cache.db"
_compat = Compat()
codeflash_temp_dir = _compat.codeflash_temp_dir
codeflash_cache_dir = _compat.codeflash_cache_dir
codeflash_cache_db = _compat.codeflash_cache_db
LF = _compat.LF
SAFE_SYS_EXECUTABLE = _compat.SAFE_SYS_EXECUTABLE
IS_POSIX = _compat.IS_POSIX
codeflash_cache_db: Path = codeflash_cache_dir / "codeflash_cache.db"

View file

@ -4,8 +4,8 @@ from enum import Enum
from typing import Any, Union
MAX_TEST_RUN_ITERATIONS = 5
OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 48000
TESTGEN_CONTEXT_TOKEN_LIMIT = 48000
OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 64000
TESTGEN_CONTEXT_TOKEN_LIMIT = 64000
INDIVIDUAL_TESTCASE_TIMEOUT = 15
MAX_FUNCTION_TEST_SECONDS = 60
MIN_IMPROVEMENT_THRESHOLD = 0.05

View file

@ -709,6 +709,7 @@ def inject_profiling_into_existing_test(
tests_project_root: Path,
mode: TestingMode = TestingMode.BEHAVIOR,
) -> tuple[bool, str | None]:
tests_project_root = tests_project_root.resolve()
if function_to_optimize.is_async:
return inject_async_profiling_into_existing_test(
test_path, call_positions, function_to_optimize, tests_project_root, mode

View file

@ -1,5 +1,7 @@
from __future__ import annotations
from codeflash.result.critic import performance_gain
def humanize_runtime(time_in_ns: int) -> str:
runtime_human: str = str(time_in_ns)
@ -89,3 +91,13 @@ def format_perf(percentage: float) -> str:
if abs_perc >= 1:
return f"{percentage:.2f}"
return f"{percentage:.3f}"
def format_runtime_comment(original_time_ns: int, optimized_time_ns: int, comment_prefix: str = "#") -> str:
perf_gain = format_perf(
abs(performance_gain(original_runtime_ns=original_time_ns, optimized_runtime_ns=optimized_time_ns) * 100)
)
status = "slower" if optimized_time_ns > original_time_ns else "faster"
return (
f"{comment_prefix} {format_time(original_time_ns)} -> {format_time(optimized_time_ns)} ({perf_gain}% {status})"
)

View file

@ -69,8 +69,8 @@ FUNCTION_NAME_REGEX = re.compile(r"([^.]+)\.([a-zA-Z0-9_]+)$")
class TestsCache:
SCHEMA_VERSION = 1 # Increment this when schema changes
def __init__(self, project_root_path: str | Path) -> None:
self.project_root_path = Path(project_root_path).resolve().as_posix()
def __init__(self, project_root_path: Path) -> None:
self.project_root_path = project_root_path.resolve().as_posix()
self.connection = sqlite3.connect(codeflash_cache_db)
self.cur = self.connection.cursor()
@ -728,6 +728,10 @@ def discover_tests_pytest(
logger.debug(f"Pytest collection exit code: {exitcode}")
if pytest_rootdir is not None:
cfg.tests_project_rootdir = Path(pytest_rootdir)
if discover_only_these_tests:
resolved_discover_only = {p.resolve() for p in discover_only_these_tests}
else:
resolved_discover_only = None
file_to_test_map: dict[Path, list[FunctionCalledInTest]] = defaultdict(list)
for test in tests:
if "__replay_test" in test["test_file"]:
@ -737,13 +741,14 @@ def discover_tests_pytest(
else:
test_type = TestType.EXISTING_UNIT_TEST
test_file_path = Path(test["test_file"]).resolve()
test_obj = TestsInFile(
test_file=Path(test["test_file"]),
test_file=test_file_path,
test_class=test["test_class"],
test_function=test["test_function"],
test_type=test_type,
)
if discover_only_these_tests and test_obj.test_file not in discover_only_these_tests:
if resolved_discover_only and test_obj.test_file not in resolved_discover_only:
continue
file_to_test_map[test_obj.test_file].append(test_obj)
# Within these test files, find the project functions they are referring to and return their names/locations

View file

@ -114,38 +114,57 @@ class FunctionVisitor(cst.CSTVisitor):
)
class FunctionWithReturnStatement(ast.NodeVisitor):
def __init__(self, file_path: Path) -> None:
self.functions: list[FunctionToOptimize] = []
self.ast_path: list[FunctionParent] = []
self.file_path: Path = file_path
def visit_FunctionDef(self, node: FunctionDef) -> None:
if function_has_return_statement(node) and not function_is_a_property(node):
self.functions.append(
FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
)
def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
if function_has_return_statement(node) and not function_is_a_property(node):
self.functions.append(
FunctionToOptimize(
function_name=node.name, file_path=self.file_path, parents=self.ast_path[:], is_async=True
def find_functions_with_return_statement(ast_module: ast.Module, file_path: Path) -> list[FunctionToOptimize]:
results: list[FunctionToOptimize] = []
# (node, parent_path) — iterative DFS avoids RecursionError on deeply nested ASTs
stack: list[tuple[ast.AST, list[FunctionParent]]] = [(ast_module, [])]
while stack:
node, ast_path = stack.pop()
if isinstance(node, (FunctionDef, AsyncFunctionDef)):
if function_has_return_statement(node) and not function_is_a_property(node):
results.append(
FunctionToOptimize(
function_name=node.name,
file_path=file_path,
parents=ast_path[:],
is_async=isinstance(node, AsyncFunctionDef),
)
)
)
def generic_visit(self, node: ast.AST) -> None:
if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):
self.ast_path.append(FunctionParent(node.name, node.__class__.__name__))
super().generic_visit(node)
if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):
self.ast_path.pop()
# Don't recurse into function bodies (matches original visitor behaviour)
continue
child_path = (
[*ast_path, FunctionParent(node.name, node.__class__.__name__)] if isinstance(node, ClassDef) else ast_path
)
for child in reversed(list(ast.iter_child_nodes(node))):
stack.append((child, child_path))
return results
# =============================================================================
# Multi-language support helpers
# =============================================================================
_VCS_EXCLUDES = frozenset({".git", ".hg", ".svn"})
def parse_dir_excludes(patterns: frozenset[str]) -> tuple[frozenset[str], tuple[str, ...], tuple[str, ...]]:
    """Split glob patterns into exact names, prefixes, and suffixes.

    Patterns ending with ``*`` become prefix matches, patterns starting with
    ``*`` become suffix matches, and plain strings become exact matches.
    A pattern both starting and ending with ``*`` is treated as a prefix
    (the trailing-star check wins).
    """
    prefix_matches = tuple(p[:-1] for p in patterns if p.endswith("*"))
    suffix_matches = tuple(p[1:] for p in patterns if p.startswith("*") and not p.endswith("*"))
    exact_matches = frozenset(p for p in patterns if not p.endswith("*") and not p.startswith("*"))
    return exact_matches, prefix_matches, suffix_matches
def get_files_for_language(
    module_root_path: Path, ignore_paths: list[Path] | None = None, language: Language | None = None
) -> list[Path]:
    """Recursively collect source files under ``module_root_path``.

    Directory pruning combines VCS metadata dirs with per-language
    ``dir_excludes`` patterns (exact / prefix / suffix via parse_dir_excludes),
    plus explicit ``ignore_paths`` (directories are pruned, files filtered).

    Args:
        module_root_path: Root directory to scan.
        ignore_paths: Files or directories to exclude.
        language: Restrict to one language's extensions; None means all
            supported languages.

    Returns:
        Paths of all matching source files.
    """
    if ignore_paths is None:
        ignore_paths = []
    all_patterns: frozenset[str]
    if language is not None:
        support = get_language_support(language)
        extensions = support.file_extensions
        all_patterns = support.dir_excludes | _VCS_EXCLUDES
    else:
        extensions = tuple(get_supported_extensions())
        all_patterns = _VCS_EXCLUDES
        for lang in Language:
            if is_language_supported(lang):
                all_patterns = all_patterns | get_language_support(lang).dir_excludes
    dir_excludes, prefixes, suffixes = parse_dir_excludes(all_patterns)
    # Split ignore paths: directories prune the walk, files filter results.
    ignore_dirs: set[str] = set()
    ignore_files: set[Path] = set()
    for p in ignore_paths:
        p = Path(p) if not isinstance(p, Path) else p
        if p.is_file():
            ignore_files.add(p)
        else:
            ignore_dirs.add(str(p))
    files: list[Path] = []
    for dirpath, dirnames, filenames in os.walk(module_root_path):
        # In-place mutation of dirnames prunes os.walk's descent.
        dirnames[:] = [
            d
            for d in dirnames
            if d not in dir_excludes
            and not (prefixes and d.startswith(prefixes))
            and not (suffixes and d.endswith(suffixes))
            and str(Path(dirpath) / d) not in ignore_dirs
        ]
        for fname in filenames:
            # str.endswith accepts a tuple of extensions.
            if fname.endswith(extensions):
                fpath = Path(dirpath, fname)
                if fpath not in ignore_files:
                    files.append(fpath)
    return files
@ -237,9 +263,7 @@ def _find_all_functions_in_python_file(file_path: Path) -> dict[Path, list[Funct
if DEBUG_MODE:
logger.exception(e)
return functions
function_name_visitor = FunctionWithReturnStatement(file_path)
function_name_visitor.visit(ast_module)
functions[file_path] = function_name_visitor.functions
functions[file_path] = find_functions_with_return_statement(ast_module, file_path)
return functions
@ -808,6 +832,7 @@ def filter_functions(
*,
disable_logs: bool = False,
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
resolved_project_root = project_root.resolve()
filtered_modified_functions: dict[str, list[FunctionToOptimize]] = {}
blocklist_funcs = get_blocklisted_functions()
logger.debug(f"Blocklisted functions: {blocklist_funcs}")
@ -884,7 +909,7 @@ def filter_functions(
lang_support = get_language_support(Path(file_path))
if lang_support.language == Language.PYTHON:
try:
ast.parse(f"import {module_name_from_file_path(Path(file_path), project_root)}")
ast.parse(f"import {module_name_from_file_path(Path(file_path), resolved_project_root)}")
except SyntaxError:
malformed_paths_count += 1
continue
@ -906,7 +931,10 @@ def filter_functions(
if previous_checkpoint_functions:
functions_tmp = []
for function in _functions:
if function.qualified_name_with_modules_from_root(project_root) in previous_checkpoint_functions:
if (
function.qualified_name_with_modules_from_root(resolved_project_root)
in previous_checkpoint_functions
):
previous_checkpoint_functions_removed_count += 1
continue
functions_tmp.append(function)
@ -960,12 +988,21 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
# Custom DFS, return True as soon as a Return node is found
stack: list[ast.AST] = [function_node]
stack: list[ast.AST] = list(function_node.body)
while stack:
node = stack.pop()
if isinstance(node, ast.Return):
return True
stack.extend(ast.iter_child_nodes(node))
# Only push child nodes that are statements; Return nodes are statements,
# so this preserves correctness while avoiding unnecessary traversal into expr/Name/etc.
for field in getattr(node, "_fields", ()):
child = getattr(node, field, None)
if isinstance(child, list):
for item in child:
if isinstance(item, ast.stmt):
stack.append(item)
elif isinstance(child, ast.stmt):
stack.append(child)
return False

View file

@ -19,7 +19,9 @@ Usage:
from codeflash.languages.base import (
CodeContext,
DependencyResolver,
HelperFunction,
IndexResult,
Language,
LanguageSupport,
ParentInfo,
@ -82,8 +84,10 @@ def __getattr__(name: str):
__all__ = [
"CodeContext",
"DependencyResolver",
"FunctionInfo",
"HelperFunction",
"IndexResult",
"Language",
"LanguageSupport",
"ParentInfo",

View file

@ -11,10 +11,11 @@ from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Callable, Iterable, Sequence
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId
from codeflash.languages.language_enum import Language
from codeflash.models.function_types import FunctionParent
@ -34,6 +35,16 @@ def __getattr__(name: str) -> Any:
raise AttributeError(msg)
@dataclass(frozen=True)
class IndexResult:
    """Per-file result reported while building a dependency/call-graph index."""

    file_path: Path  # file that was indexed
    cached: bool  # presumably True when served from a cache instead of re-analyzed — confirm with implementations
    num_edges: int  # number of call-graph edges recorded for this file
    edges: tuple[tuple[str, str, bool], ...]  # (caller_qn, callee_name, is_cross_file)
    cross_file_edges: int  # count of edges whose callee lives in a different file
    error: bool  # True when indexing this file failed
@dataclass
class HelperFunction:
"""A helper function that is a dependency of the target function.
@ -192,6 +203,35 @@ class ReferenceInfo:
caller_function: str | None = None
@runtime_checkable
class DependencyResolver(Protocol):
    """Protocol for language-specific dependency resolution.

    Implementations analyze source files to discover call-graph edges
    between functions so the optimizer can extract richer context.
    Obtain instances via LanguageSupport.create_dependency_resolver; call
    close() when finished to release any held resources.
    """

    def build_index(self, file_paths: Iterable[Path], on_progress: Callable[[IndexResult], None] | None = None) -> None:
        """Pre-index a batch of files.

        Args:
            file_paths: Source files to analyze.
            on_progress: Optional callback receiving an IndexResult — presumably
                once per indexed file; confirm with implementations.
        """
        ...

    def get_callees(
        self, file_path_to_qualified_names: dict[Path, set[str]]
    ) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]:
        """Return callees for the given functions.

        Args:
            file_path_to_qualified_names: Qualified function names of interest,
                grouped by source file.
        """
        ...

    def count_callees_per_function(
        self, file_path_to_qualified_names: dict[Path, set[str]]
    ) -> dict[tuple[Path, str], int]:
        """Return the number of callees for each (file_path, qualified_name) pair."""
        ...

    def close(self) -> None:
        """Release resources (e.g. database connections)."""
        ...
@runtime_checkable
class LanguageSupport(Protocol):
"""Protocol defining what a language implementation must provide.
@ -254,6 +294,14 @@ class LanguageSupport(Protocol):
"""Like # or //."""
...
@property
def dir_excludes(self) -> frozenset[str]:
"""Directory name patterns to skip during file discovery.
Supports glob wildcards: "name" for exact, "prefix*" for startswith, "*suffix" for endswith.
"""
...
# === Discovery ===
def discover_functions(
@ -490,6 +538,87 @@ class LanguageSupport(Protocol):
"""
...
def postprocess_generated_tests(
self, generated_tests: GeneratedTestsList, test_framework: str, project_root: Path, source_file_path: Path
) -> GeneratedTestsList:
"""Apply language-specific postprocessing to generated tests.
Args:
generated_tests: Generated tests to update.
test_framework: Test framework used for the project.
project_root: Project root directory.
source_file_path: Path to the source file under optimization.
Returns:
Updated generated tests.
"""
...
def remove_test_functions_from_generated_tests(
self, generated_tests: GeneratedTestsList, functions_to_remove: list[str]
) -> GeneratedTestsList:
"""Remove specific test functions from generated tests.
Args:
generated_tests: Generated tests to update.
functions_to_remove: List of function names to remove.
Returns:
Updated generated tests.
"""
...
def add_runtime_comments_to_generated_tests(
self,
generated_tests: GeneratedTestsList,
original_runtimes: dict[InvocationId, list[int]],
optimized_runtimes: dict[InvocationId, list[int]],
tests_project_rootdir: Path | None = None,
) -> GeneratedTestsList:
"""Add runtime comments to generated tests.
Args:
generated_tests: Generated tests to update.
original_runtimes: Mapping of invocation IDs to original runtimes.
optimized_runtimes: Mapping of invocation IDs to optimized runtimes.
tests_project_rootdir: Root directory for tests (if applicable).
Returns:
Updated generated tests.
"""
...
def add_global_declarations(self, optimized_code: str, original_source: str, module_abspath: Path) -> str:
"""Add new global declarations from optimized code to original source.
Args:
optimized_code: The optimized code that may contain new declarations.
original_source: The original source code.
module_abspath: Path to the module file (for parser selection).
Returns:
Original source with new declarations added.
"""
...
def extract_calling_function_source(self, source_code: str, function_name: str, ref_line: int) -> str | None:
"""Extract the source code of a calling function.
Args:
source_code: Full source code of the file.
function_name: Name of the function to extract.
ref_line: Line number where the reference is.
Returns:
Source code of the function, or None if not found.
"""
...
# === Test Result Comparison ===
def compare_test_results(
@ -556,6 +685,15 @@ class LanguageSupport(Protocol):
# Default implementation: just copy runtime files
return False
def create_dependency_resolver(self, project_root: Path) -> DependencyResolver | None:
"""Create a language-specific dependency resolver, if available.
Returns:
A DependencyResolver instance, or None if not supported.
"""
return None
def instrument_existing_test(
self,
test_path: Path,

View file

@ -0,0 +1,217 @@
"""JavaScript/TypeScript code replacement helpers."""
from __future__ import annotations
from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import logger
if TYPE_CHECKING:
from pathlib import Path
from codeflash.languages.base import Language
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer
# Author: ali <mohammed18200118@gmail.com>
def _add_global_declarations_for_language(
    optimized_code: str, original_source: str, module_abspath: Path, language: Language
) -> str:
    """Add new global declarations from optimized code to original source.

    Finds module-level declarations (const, let, var, class, type, interface, enum)
    in the optimized code that don't exist in the original source and adds them.
    New declarations are inserted after any existing declarations they depend on.

    For example, if optimized code has `const _has = FOO.bar.bind(FOO)`, and `FOO`
    is already declared in the original source, `_has` will be inserted after `FOO`.

    Args:
        optimized_code: The optimized code that may contain new declarations.
        original_source: The original source code.
        module_abspath: Path to the module file (for parser selection).
        language: The language of the code.

    Returns:
        Original source with new declarations added in dependency order.
        The original source is returned unchanged for non-JS/TS languages,
        when there is nothing new to add, or on any analysis error.
    """
    from codeflash.languages.base import Language

    # Only JS/TS is supported by the tree-sitter based analysis below.
    if language not in (Language.JAVASCRIPT, Language.TYPESCRIPT):
        return original_source
    try:
        from codeflash.languages.javascript.treesitter import get_analyzer_for_file

        analyzer = get_analyzer_for_file(module_abspath)
        original_declarations = analyzer.find_module_level_declarations(original_source)
        optimized_declarations = analyzer.find_module_level_declarations(optimized_code)
        if not optimized_declarations:
            return original_source
        existing_names = _get_existing_names(original_declarations, analyzer, original_source)
        new_declarations = _filter_new_declarations(optimized_declarations, existing_names)
        if not new_declarations:
            return original_source
        # Build a map of existing declaration names to their end lines (1-indexed)
        existing_decl_end_lines = {decl.name: decl.end_line for decl in original_declarations}
        # Insert each new declaration after its dependencies
        result = original_source
        for decl in new_declarations:
            result = _insert_declaration_after_dependencies(
                result, decl, existing_decl_end_lines, analyzer, module_abspath
            )
            # Update the map with the newly inserted declaration for subsequent insertions
            # Re-parse to get accurate line numbers after insertion
            updated_declarations = analyzer.find_module_level_declarations(result)
            existing_decl_end_lines = {d.name: d.end_line for d in updated_declarations}
        return result
    except Exception as e:
        # Best-effort: any analysis failure falls back to the untouched original source.
        logger.debug(f"Error adding global declarations: {e}")
        return original_source
# Author: ali <mohammed18200118@gmail.com>
def _get_existing_names(original_declarations: list, analyzer: TreeSitterAnalyzer, original_source: str) -> set[str]:
    """Collect every name already bound at module level in the original source.

    Combines module-level declaration names with all imported bindings:
    default imports, named imports (preferring the local alias), and
    namespace imports.
    """
    names = {decl.name for decl in original_declarations}
    for imp in analyzer.find_imports(original_source):
        if imp.default_import:
            names.add(imp.default_import)
        names.update(alias or name for name, alias in imp.named_imports)
        if imp.namespace_import:
            names.add(imp.namespace_import)
    return names
# Author: ali <mohammed18200118@gmail.com>
def _filter_new_declarations(optimized_declarations: list, existing_names: set[str]) -> list:
"""Filter declarations to only those that don't exist in the original source."""
new_declarations = []
seen_sources: set[str] = set()
# Sort by line number to maintain order from optimized code
sorted_declarations = sorted(optimized_declarations, key=lambda d: d.start_line)
for decl in sorted_declarations:
if decl.name not in existing_names and decl.source_code not in seen_sources:
new_declarations.append(decl)
seen_sources.add(decl.source_code)
return new_declarations
# Author: ali <mohammed18200118@gmail.com>
def _insert_declaration_after_dependencies(
    source: str,
    declaration,
    existing_decl_end_lines: dict[str, int],
    analyzer: TreeSitterAnalyzer,
    module_abspath: Path,
) -> str:
    """Insert a declaration after the last existing declaration it depends on.

    Args:
        source: Current source code.
        declaration: The declaration to insert.
        existing_decl_end_lines: Map of existing declaration names to their end lines.
        analyzer: TreeSitter analyzer.
        module_abspath: Path to the module file.
            NOTE(review): currently unused in this helper — confirm intent.

    Returns:
        Source code with the declaration inserted at the correct position.
    """
    dependency_names = analyzer.find_referenced_identifiers(declaration.source_code)
    insert_at = _find_insertion_line_for_declaration(source, dependency_names, existing_decl_end_lines, analyzer)
    lines = source.splitlines(keepends=True)
    snippet = declaration.source_code
    if not snippet.endswith("\n"):
        snippet += "\n"
    # Separate from preceding non-blank content with a blank line.
    if insert_at > 0 and lines[insert_at - 1].strip():
        snippet = "\n" + snippet
    return "".join([*lines[:insert_at], snippet, *lines[insert_at:]])
# Author: ali <mohammed18200118@gmail.com>
def _find_insertion_line_for_declaration(
    source: str, referenced_names: set[str], existing_decl_end_lines: dict[str, int], analyzer: TreeSitterAnalyzer
) -> int:
    """Find the line where a declaration should be inserted based on its dependencies.

    Args:
        source: Source code.
        referenced_names: Names referenced by the declaration.
        existing_decl_end_lines: Map of declaration names to their end lines (1-indexed).
        analyzer: TreeSitter analyzer.

    Returns:
        Line index (0-based) where the declaration should be inserted.
    """
    latest_dependency_end = max(
        (existing_decl_end_lines[name] for name in referenced_names if name in existing_decl_end_lines), default=0
    )
    if latest_dependency_end > 0:
        # end_line is 1-indexed; using it directly as a 0-based index places
        # the insertion immediately after the last dependency's final line.
        return latest_dependency_end
    # No known dependencies — fall back to inserting right after the imports.
    return _find_line_after_imports(source.splitlines(keepends=True), analyzer, source)
# Author: ali <mohammed18200118@gmail.com>
def _find_line_after_imports(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int:
    """Find the line index after all imports.

    Args:
        lines: Source lines.
        analyzer: TreeSitter analyzer.
        source: Full source code.

    Returns:
        Line index (0-based) for insertion after imports; falls back to the
        first substantive line (skipping shebangs and // comments), else 0.
    """
    try:
        import_end_lines = [imp.end_line for imp in analyzer.find_imports(source)]
        if import_end_lines:
            return max(import_end_lines)
    except Exception as exc:
        logger.debug(f"Exception in _find_line_after_imports: {exc}")
    # Default: insert at beginning (after shebang/directive comments)
    for index, raw_line in enumerate(lines):
        text = raw_line.strip()
        if text and not text.startswith(("//", "#!")):
            return index
    return 0

View file

@ -6,29 +6,13 @@ including adding runtime comments and removing test functions.
from __future__ import annotations
import os
import re
from pathlib import Path
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.time_utils import format_perf, format_time
from codeflash.result.critic import performance_gain
def format_runtime_comment(original_time: int, optimized_time: int) -> str:
"""Format a runtime comparison comment for JavaScript.
Args:
original_time: Original runtime in nanoseconds.
optimized_time: Optimized runtime in nanoseconds.
Returns:
Formatted comment string with // prefix.
"""
perf_gain = format_perf(
abs(performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time) * 100)
)
status = "slower" if optimized_time > original_time else "faster"
return f"// {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
from codeflash.code_utils.time_utils import format_runtime_comment
from codeflash.models.models import GeneratedTests, GeneratedTestsList
def add_runtime_comments(source: str, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]) -> str:
@ -117,7 +101,7 @@ def add_runtime_comments(source: str, original_runtimes: dict[str, int], optimiz
# Only add comment if line has a function call and doesn't already have a comment
if func_call_pattern.search(line) and "//" not in line and "expect(" in line:
orig_time, opt_time = timing_by_full_name[current_matched_full_name]
comment = format_runtime_comment(orig_time, opt_time)
comment = format_runtime_comment(orig_time, opt_time, comment_prefix="//")
logger.debug(f"[js-annotations] Adding comment to test '{current_test_name}': {comment}")
# Add comment at end of line
line = f"{line.rstrip()} {comment}"
@ -130,6 +114,165 @@ def add_runtime_comments(source: str, original_runtimes: dict[str, int], optimiz
return "\n".join(modified_lines)
# Order matters: multi-part test extensions (.test.ts, .spec.js, ...) come before
# the plain extensions because resolve_js_test_module_path stops at the first match.
JS_TEST_EXTENSIONS = (
    ".test.ts",
    ".test.js",
    ".test.tsx",
    ".test.jsx",
    ".spec.ts",
    ".spec.js",
    ".spec.tsx",
    ".spec.jsx",
    ".ts",
    ".js",
    ".tsx",
    ".jsx",
    ".mjs",
    ".mts",
)


# TODO:{self} Needs cleanup for jest logic in else block
# Author: Sarthak Agarwal <sarthak.saga@gmail.com>
def is_js_test_module_path(test_module_path: str) -> bool:
    """Return True when the module path looks like a JS/TS test path.

    str.endswith accepts a tuple of suffixes, which is equivalent to checking
    each extension in JS_TEST_EXTENSIONS individually.
    """
    return test_module_path.endswith(JS_TEST_EXTENSIONS)
# Author: Sarthak Agarwal <sarthak.saga@gmail.com>
def resolve_js_test_module_path(test_module_path: str, tests_project_rootdir: Path) -> Path:
    """Resolve a JS/TS test module path to a concrete file path.

    Handles two input shapes: a real relative path (contains a separator),
    which is simply anchored at the tests root; or a dotted module-style
    path such as ``tests.foo.test.ts``, whose dots (before the matched
    extension) are converted to path separators.

    Args:
        test_module_path: Path or dotted module path of the test file.
        tests_project_rootdir: Root directory test paths are relative to.

    Returns:
        Resolved path under the tests root (or its parent, when the module
        path already starts with the tests directory name).
    """
    # Already a filesystem path — just anchor it at the tests root.
    if "/" in test_module_path or "\\" in test_module_path:
        return tests_project_rootdir / Path(test_module_path)
    matched_ext = None
    # JS_TEST_EXTENSIONS lists multi-part extensions (.test.ts) before plain
    # ones (.ts), so the first match keeps the full test suffix intact.
    for ext in JS_TEST_EXTENSIONS:
        if test_module_path.endswith(ext):
            matched_ext = ext
            break
    if matched_ext:
        base_path = test_module_path[: -len(matched_ext)]
        # Dotted module path: remaining dots become path separators.
        file_path = base_path.replace(".", os.sep) + matched_ext
        tests_dir_name = tests_project_rootdir.name
        # Avoid duplicating the tests dir when the module path already names it:
        # anchor at the tests root's parent instead.
        if file_path.startswith((tests_dir_name + os.sep, tests_dir_name + "/")):
            return tests_project_rootdir.parent / Path(file_path)
        return tests_project_rootdir / Path(file_path)
    return tests_project_rootdir / Path(test_module_path)
# Patterns for normalizing codeflash imports (legacy local helper -> npm package).
# Author: Sarthak Agarwal <sarthak.saga@gmail.com>
_CODEFLASH_REQUIRE_PATTERN = re.compile(
    r"(const|let|var)\s+(\w+)\s*=\s*require\s*\(\s*['\"]\.?/?codeflash-jest-helper['\"]\s*\)"
)
_CODEFLASH_IMPORT_PATTERN = re.compile(r"import\s+(?:\*\s+as\s+)?(\w+)\s+from\s+['\"]\.?/?codeflash-jest-helper['\"]")


# Author: Sarthak Agarwal <sarthak.saga@gmail.com>
def normalize_codeflash_imports(source: str) -> str:
    """Rewrite legacy local codeflash-jest-helper imports to the npm package.

    ``const codeflash = require('./codeflash-jest-helper')`` becomes
    ``const codeflash = require('codeflash')``; the ES-module form
    ``import codeflash from './codeflash-jest-helper'`` becomes
    ``import codeflash from 'codeflash'``. The bound local name is preserved.

    Args:
        source: JavaScript/TypeScript source code.

    Returns:
        Source code with normalized imports.
    """
    commonjs_normalized = _CODEFLASH_REQUIRE_PATTERN.sub(r"\1 \2 = require('codeflash')", source)
    return _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", commonjs_normalized)
# Author: ali <mohammed18200118@gmail.com>
def inject_test_globals(generated_tests: GeneratedTestsList, test_framework: str = "jest") -> GeneratedTestsList:
    # TODO: inside the prompt tell the llm if it should import jest functions or it's already injected in the global window
    """Prepend framework-global imports to every generated test source.

    Intended for ESM modules only, where test globals are not ambient. Vitest
    projects get vitest imports; every other framework defaults to
    @jest/globals imports.

    Args:
        generated_tests: List of generated tests (mutated in place).
        test_framework: The test framework being used ("jest", "vitest", or "mocha").

    Returns:
        The same GeneratedTestsList with globals injected.
    """
    if test_framework == "vitest":
        header = "import { vi, describe, it, expect, beforeEach, afterEach, beforeAll, test } from 'vitest'\n"
    else:
        # Default to jest imports for jest and other frameworks
        header = (
            "import { jest, describe, it, expect, beforeEach, afterEach, beforeAll, test } from '@jest/globals'\n"
        )
    for generated in generated_tests.generated_tests:
        generated.generated_original_test_source = header + generated.generated_original_test_source
        generated.instrumented_behavior_test_source = header + generated.instrumented_behavior_test_source
        generated.instrumented_perf_test_source = header + generated.instrumented_perf_test_source
    return generated_tests
# Author: ali <mohammed18200118@gmail.com>
def disable_ts_check(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
    """Prepend ``// @ts-nocheck`` to every generated test source.

    This disables TypeScript type checking for the generated files.

    Args:
        generated_tests: List of generated tests (mutated in place).

    Returns:
        The same GeneratedTestsList with the directive prepended.
    """
    directive = "// @ts-nocheck\n"
    for generated in generated_tests.generated_tests:
        generated.generated_original_test_source = directive + generated.generated_original_test_source
        generated.instrumented_behavior_test_source = directive + generated.instrumented_behavior_test_source
        generated.instrumented_perf_test_source = directive + generated.instrumented_perf_test_source
    return generated_tests
# Author: Sarthak Agarwal <sarthak.saga@gmail.com>
def normalize_generated_tests_imports(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
    """Normalize codeflash imports in all generated JS/TS tests.

    Tests whose behavior file is not a JS/TS source are passed through
    unchanged.
    NOTE(review): ``.cjs``/``.cts`` are absent from the suffix list — confirm
    that is intentional.

    Args:
        generated_tests: List of generated tests.

    Returns:
        A new GeneratedTestsList with normalized imports.
    """
    js_ts_suffixes = (".js", ".ts", ".jsx", ".tsx", ".mjs", ".mts")
    normalized: list = []
    for test in generated_tests.generated_tests:
        if test.behavior_file_path.suffix not in js_ts_suffixes:
            normalized.append(test)
            continue
        normalized.append(
            GeneratedTests(
                generated_original_test_source=normalize_codeflash_imports(test.generated_original_test_source),
                instrumented_behavior_test_source=normalize_codeflash_imports(test.instrumented_behavior_test_source),
                instrumented_perf_test_source=normalize_codeflash_imports(test.instrumented_perf_test_source),
                behavior_file_path=test.behavior_file_path,
                perf_file_path=test.perf_file_path,
            )
        )
    return GeneratedTestsList(generated_tests=normalized)
def remove_test_functions(source: str, functions_to_remove: list[str]) -> str:
"""Remove specific test functions from JavaScript test source code.

View file

@ -44,8 +44,7 @@ class ImportResolver:
project_root: Root directory of the project.
"""
# Resolve to real path to handle macOS symlinks like /var -> /private/var
self.project_root = project_root.resolve()
self.project_root = project_root
self._resolution_cache: dict[tuple[Path, str], Path | None] = {}
def resolve_import(self, import_info: ImportInfo, source_file: Path) -> ResolvedImport | None:

View file

@ -23,6 +23,7 @@ if TYPE_CHECKING:
from codeflash.languages.base import ReferenceInfo
from codeflash.languages.javascript.treesitter import TypeDefinition
from codeflash.models.models import GeneratedTestsList, InvocationId
logger = logging.getLogger(__name__)
@ -63,6 +64,10 @@ class JavaScriptSupport:
def comment_prefix(self) -> str:
return "//"
@property
def dir_excludes(self) -> frozenset[str]:
    """Directory names skipped during JS/TS file discovery (deps, build output, caches)."""
    return frozenset({"node_modules", "dist", "build", ".next", ".nuxt", "coverage", ".cache", ".turbo", ".vercel"})
# === Discovery ===
def discover_functions(
@ -1774,6 +1779,116 @@ class JavaScriptSupport:
return remove_test_functions(test_source, functions_to_remove)
def postprocess_generated_tests(
    self, generated_tests: GeneratedTestsList, test_framework: str, project_root: Path, source_file_path: Path
) -> GeneratedTestsList:
    """Apply language-specific postprocessing to generated tests.

    For ESM projects, framework globals (jest/vitest) are injected explicitly;
    for TypeScript, type checking is disabled in the generated tests; finally,
    legacy codeflash helper imports are rewritten to the npm package.
    """
    from codeflash.languages.javascript.edit_tests import (
        disable_ts_check,
        inject_test_globals,
        normalize_generated_tests_imports,
    )
    from codeflash.languages.javascript.module_system import detect_module_system

    # The module system decides whether test globals must be imported explicitly (ESM only).
    module_system = detect_module_system(project_root, source_file_path)
    if module_system == "esm":
        generated_tests = inject_test_globals(generated_tests, test_framework)
    if self.language == Language.TYPESCRIPT:
        generated_tests = disable_ts_check(generated_tests)
    return normalize_generated_tests_imports(generated_tests)
def remove_test_functions_from_generated_tests(
    self, generated_tests: GeneratedTestsList, functions_to_remove: list[str]
) -> GeneratedTestsList:
    """Strip the named test functions from each generated test's original source.

    Only the un-instrumented original source is rewritten; the instrumented
    behavior/perf sources and file paths are carried over unchanged.
    """
    from codeflash.models.models import GeneratedTests, GeneratedTestsList

    updated = [
        GeneratedTests(
            generated_original_test_source=self.remove_test_functions(
                test.generated_original_test_source, functions_to_remove
            ),
            instrumented_behavior_test_source=test.instrumented_behavior_test_source,
            instrumented_perf_test_source=test.instrumented_perf_test_source,
            behavior_file_path=test.behavior_file_path,
            perf_file_path=test.perf_file_path,
        )
        for test in generated_tests.generated_tests
    ]
    return GeneratedTestsList(generated_tests=updated)
def add_runtime_comments_to_generated_tests(
    self,
    generated_tests: GeneratedTestsList,
    original_runtimes: dict[InvocationId, list[int]],
    optimized_runtimes: dict[InvocationId, list[int]],
    tests_project_rootdir: Path | None = None,
) -> GeneratedTestsList:
    """Annotate each generated test's original source with runtime-comparison comments."""
    from codeflash.models.models import GeneratedTests, GeneratedTestsList

    tests_root = tests_project_rootdir if tests_project_rootdir is not None else Path()
    baseline_map = self._build_runtime_map(original_runtimes, tests_root)
    optimized_map = self._build_runtime_map(optimized_runtimes, tests_root)
    annotated = [
        GeneratedTests(
            generated_original_test_source=self.add_runtime_comments(
                test.generated_original_test_source, baseline_map, optimized_map
            ),
            instrumented_behavior_test_source=test.instrumented_behavior_test_source,
            instrumented_perf_test_source=test.instrumented_perf_test_source,
            behavior_file_path=test.behavior_file_path,
            perf_file_path=test.perf_file_path,
        )
        for test in generated_tests.generated_tests
    ]
    return GeneratedTestsList(generated_tests=annotated)
def add_global_declarations(self, optimized_code: str, original_source: str, module_abspath: Path) -> str:
    """Merge new module-level declarations from optimized code into the original source.

    Delegates to the JS/TS tree-sitter based implementation, which returns
    the original source unchanged on unsupported languages or analysis errors.
    """
    from codeflash.languages.javascript.code_replacer import _add_global_declarations_for_language

    return _add_global_declarations_for_language(optimized_code, original_source, module_abspath, self.language)
def extract_calling_function_source(self, source_code: str, function_name: str, ref_line: int) -> str | None:
    """Extract the source of the named function that contains `ref_line`, or None if not found."""
    from codeflash.languages.javascript.treesitter import extract_calling_function_source

    return extract_calling_function_source(source_code, function_name, ref_line)
def _build_runtime_map(
    self, inv_id_runtimes: dict[InvocationId, list[int]], tests_project_rootdir: Path
) -> dict[str, int]:
    """Aggregate per-invocation runtimes, keyed by test name, file, and invocation prefix.

    Key shape: ``<qualified_test_name>#<abs_path_without_suffix>#<invocation_prefix>``.
    Only codeflash-generated test files (path contains "__unit_test_" or
    "__perf_test_") are included. Each key accumulates the minimum runtime
    of every matching invocation.
    """
    from codeflash.languages.javascript.edit_tests import resolve_js_test_module_path

    unique_inv_ids: dict[str, int] = {}
    for inv_id, runtimes in inv_id_runtimes.items():
        if not runtimes:
            # No measurements recorded for this invocation; skip instead of crashing on min().
            continue
        test_qualified_name = (
            f"{inv_id.test_class_name}.{inv_id.test_function_name}"
            if inv_id.test_class_name
            else inv_id.test_function_name
        )
        if not test_qualified_name:
            continue
        abs_path = resolve_js_test_module_path(inv_id.test_module_path, tests_project_rootdir)
        abs_path_str = str(abs_path.resolve().with_suffix(""))
        if "__unit_test_" not in abs_path_str and "__perf_test_" not in abs_path_str:
            continue
        # Invocation prefix: whole first segment for short ids, otherwise
        # everything except the trailing segment.
        segments = inv_id.iteration_id.split("_")  # type: ignore[union-attr]
        cur_invid = segments[0] if len(segments) < 3 else "_".join(segments[:-1])
        match_key = f"{test_qualified_name}#{abs_path_str}#{cur_invid}"
        unique_inv_ids[match_key] = unique_inv_ids.get(match_key, 0) + min(runtimes)
    return unique_inv_ids
# === Test Result Comparison ===
def compare_test_results(
@ -2000,6 +2115,9 @@ class JavaScriptSupport:
logger.error("Could not install codeflash. Please run: npm install --save-dev codeflash")
return False
def create_dependency_resolver(self, project_root: Path) -> None:
    """JS/TS provides no dependency resolver yet; always returns None (the protocol default).

    NOTE(review): the LanguageSupport protocol annotates this as
    ``DependencyResolver | None``; ``-> None`` here is compatible but narrower.
    """
    return None
def instrument_existing_test(
self,
test_path: Path,

View file

@ -1788,3 +1788,39 @@ def get_analyzer_for_file(file_path: Path) -> TreeSitterAnalyzer:
return TreeSitterAnalyzer(TreeSitterLanguage.TSX)
# Default to JavaScript for .js, .jsx, .mjs, .cjs
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
# Author: Saurabh Misra <misra.saurabh1@gmail.com>
def extract_calling_function_source(source_code: str, function_name: str, ref_line: int) -> str | None:
    """Extract the source code of a calling function in JavaScript/TypeScript.

    Args:
        source_code: Full source code of the file.
        function_name: Name of the function to extract.
        ref_line: Line number where the reference is (helps identify the right function).

    Returns:
        Source code of the function, or None if not found.
    """
    try:
        from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage

        # Try TypeScript first, fall back to JavaScript
        for lang in [TreeSitterLanguage.TYPESCRIPT, TreeSitterLanguage.TSX, TreeSitterLanguage.JAVASCRIPT]:
            try:
                analyzer = TreeSitterAnalyzer(lang)
                functions = analyzer.find_functions(source_code, include_methods=True)
                for func in functions:
                    if func.name == function_name:
                        # Check if the reference line is within this function
                        if func.start_line <= ref_line <= func.end_line:
                            return func.source_text
                # NOTE(review): indentation reconstructed — `break` is assumed to end
                # the grammar loop once a grammar parses successfully; confirm intent.
                break
            except Exception:
                # This grammar failed to parse the file — try the next one.
                continue
        return None
    except Exception:
        # Best-effort extraction: any unexpected failure yields None.
        return None

View file

@ -5,6 +5,7 @@ Python-specific implementations (LibCST, Jedi, pytest, etc.) to conform
to the LanguageSupport protocol.
"""
from codeflash.languages.python.reference_graph import ReferenceGraph
from codeflash.languages.python.support import PythonSupport
__all__ = ["PythonSupport"]
__all__ = ["PythonSupport", "ReferenceGraph"]

View file

@ -10,7 +10,6 @@ from typing import TYPE_CHECKING
import libcst as cst
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages
from codeflash.code_utils.config_consts import OPTIMIZATION_CONTEXT_TOKEN_LIMIT, TESTGEN_CONTEXT_TOKEN_LIMIT
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
@ -24,6 +23,10 @@ from codeflash.languages.python.context.unused_definition_remover import (
recurse_sections,
remove_unused_definitions_by_function_names,
)
from codeflash.languages.python.static_analysis.code_extractor import (
add_needed_imports_from_module,
find_preexisting_objects,
)
from codeflash.models.models import (
CodeContextType,
CodeOptimizationContext,
@ -38,7 +41,7 @@ if TYPE_CHECKING:
from jedi.api.classes import Name
from codeflash.languages.base import HelperFunction
from codeflash.languages.base import DependencyResolver, HelperFunction
from codeflash.languages.python.context.unused_definition_remover import UsageInfo
# Error message constants
@ -87,6 +90,7 @@ def get_code_optimization_context(
project_root_path: Path,
optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT,
call_graph: DependencyResolver | None = None,
) -> CodeOptimizationContext:
# Route to language-specific implementation for non-Python languages
if not is_python():
@ -95,9 +99,11 @@ def get_code_optimization_context(
)
# Get FunctionSource representation of helpers of FTO
helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi(
{function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path
)
fto_input = {function_to_optimize.file_path: {function_to_optimize.qualified_name}}
if call_graph is not None:
helpers_of_fto_dict, helpers_of_fto_list = call_graph.get_callees(fto_input)
else:
helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi(fto_input, project_root_path)
# Add function to optimize into helpers of FTO dict, as they'll be processed together
fto_as_function_source = get_function_to_optimize_as_function_source(function_to_optimize, project_root_path)
@ -113,8 +119,7 @@ def get_code_optimization_context(
for qualified_names in helpers_of_fto_qualified_names_dict.values():
qualified_names.update({f"{qn.rsplit('.', 1)[0]}.__init__" for qn in qualified_names if "." in qn})
# Get FunctionSource representation of helpers of helpers of FTO
helpers_of_helpers_dict, _helpers_of_helpers_list = get_function_sources_from_jedi(
helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi(
helpers_of_fto_qualified_names_dict, project_root_path
)
@ -198,6 +203,8 @@ def get_code_optimization_context(
code_hash_context = hashing_code_context.markdown
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
all_helper_fqns = list({fs.fully_qualified_name for fs in helpers_of_fto_list + helpers_of_helpers_list})
return CodeOptimizationContext(
testgen_context=testgen_context,
read_writable_code=final_read_writable_code,
@ -205,6 +212,7 @@ def get_code_optimization_context(
hashing_code_context=code_hash_context,
hashing_code_context_hash=code_hash,
helper_functions=helpers_of_fto_list,
testgen_helper_fqns=all_helper_fqns,
preexisting_objects=preexisting_objects,
)
@ -266,7 +274,6 @@ def get_code_optimization_context_for_language(
fully_qualified_name=helper.qualified_name,
only_function_name=helper.name,
source_code=helper.source_code,
jedi_definition=None,
)
)
@ -335,13 +342,12 @@ def get_code_optimization_context_for_language(
return CodeOptimizationContext(
testgen_context=testgen_context,
read_writable_code=read_writable_code,
# Pass type definitions and globals as read-only context for the AI
# This way the AI sees them as context but doesn't include them in optimized output
read_only_context_code=code_context.read_only_context,
hashing_code_context=read_writable_code.flat,
hashing_code_context_hash=code_hash,
helper_functions=helper_function_sources,
preexisting_objects=set(), # Not implemented for non-Python yet
testgen_helper_fqns=[fs.fully_qualified_name for fs in helper_function_sources],
preexisting_objects=set(),
)
@ -496,7 +502,6 @@ def get_function_to_optimize_as_function_source(
fully_qualified_name=name.full_name,
only_function_name=name.name,
source_code=name.get_line_code(),
jedi_definition=name,
)
except Exception as e:
logger.exception(f"Error while getting function source: {e}")
@ -533,6 +538,10 @@ def get_function_sources_from_jedi(
# TODO: there can be multiple definitions, see how to handle such cases
definition = definitions[0]
definition_path = definition.module_path
if definition_path is not None:
rel = safe_relative_to(definition_path, project_root_path)
if not rel.is_absolute():
definition_path = project_root_path / rel
# The definition is part of this project and not defined within the original function
is_valid_definition = (
@ -543,15 +552,16 @@ def get_function_sources_from_jedi(
and not belongs_to_function_qualified(definition, qualified_function_name)
and definition.full_name.startswith(definition.module_name)
)
if is_valid_definition and definition.type in ("function", "class"):
if is_valid_definition and definition.type in ("function", "class", "statement"):
if definition.type == "function":
fqn = definition.full_name
func_name = definition.name
else:
# When a class is instantiated (e.g., MyClass()), track its __init__ as a helper
# This ensures the class definition with constructor is included in testgen context
elif definition.type == "class":
fqn = f"{definition.full_name}.__init__"
func_name = "__init__"
else:
fqn = definition.full_name
func_name = definition.name
qualified_name = get_qualified_name(definition.module_name, fqn)
# Avoid nested functions or classes. Only class.function is allowed
if len(qualified_name.split(".")) <= 2:
@ -561,7 +571,6 @@ def get_function_sources_from_jedi(
fully_qualified_name=fqn,
only_function_name=func_name,
source_code=definition.get_line_code(),
jedi_definition=definition,
)
file_path_to_function_source[definition_path].add(function_source)
function_source_list.append(function_source)
@ -1006,6 +1015,255 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
return CodeStringsMarkdown(code_strings=code_strings)
def resolve_classes_from_modules(candidates: set[tuple[str, str]]) -> list[tuple[type, str]]:
    """Import modules and resolve candidate (class_name, module_name) pairs to class objects."""
    import importlib
    import inspect

    resolved: list[tuple[type, str]] = []
    loaded_modules: dict[str, object] = {}
    for class_name, module_name in candidates:
        try:
            # Import each module at most once, caching the result.
            if module_name in loaded_modules:
                module = loaded_modules[module_name]
            else:
                module = importlib.import_module(module_name)
                loaded_modules[module_name] = module
            candidate = getattr(module, class_name, None)
            if candidate is not None and inspect.isclass(candidate):
                resolved.append((candidate, class_name))
        except (ImportError, ModuleNotFoundError, AttributeError):
            logger.debug(f"Failed to import {module_name}.{class_name}")
    return resolved
MAX_TRANSITIVE_DEPTH = 5
def extract_classes_from_type_hint(hint: object) -> list[type]:
    """Recursively extract concrete class objects from a type annotation.

    Unwraps Optional, Union, List, Dict, Callable, Annotated, etc.
    Filters out builtins and typing module types.
    """
    import typing

    found: list[type] = []
    origin = getattr(hint, "__origin__", None)
    type_args = getattr(hint, "__args__", None)

    if origin is not None and type_args:
        # Parameterized generic: recurse into every type argument.
        for type_arg in type_args:
            found.extend(extract_classes_from_type_hint(type_arg))
    elif isinstance(hint, type):
        # Concrete class: keep it unless it comes from a builtin/typing module.
        if getattr(hint, "__module__", "") not in ("builtins", "typing", "typing_extensions", "types"):
            found.append(hint)

    # Handle typing.Annotated on older Pythons where __origin__ may not be set
    if origin is None and type_args is None and hasattr(typing, "get_args"):
        try:
            for inner_arg in typing.get_args(hint):
                found.extend(extract_classes_from_type_hint(inner_arg))
        except Exception:
            pass
    return found
def resolve_transitive_type_deps(cls: type) -> list[type]:
    """Find external classes referenced in cls.__init__ type annotations.

    Returns classes from site-packages that have a custom __init__.
    """
    import inspect
    import typing

    try:
        annotations = typing.get_type_hints(getattr(cls, "__init__"))
    except Exception:
        return []

    dependencies: list[type] = []
    for name, annotation in annotations.items():
        if name == "return":
            continue
        for candidate in extract_classes_from_type_hint(annotation):
            if candidate is cls:
                continue
            # Skip classes without a custom constructor.
            candidate_init = getattr(candidate, "__init__", None)
            if candidate_init is None or candidate_init is object.__init__:
                continue
            try:
                source_file = Path(inspect.getfile(candidate))
            except (OSError, TypeError):
                continue
            # Only third-party (site-packages) classes count as transitive deps.
            if path_belongs_to_site_packages(source_file):
                dependencies.append(candidate)
    return dependencies
def extract_init_stub(cls: type, class_name: str, require_site_packages: bool = True) -> CodeString | None:
    """Extract a stub containing the class definition with only its __init__ method.

    Args:
        cls: The class object to extract __init__ from
        class_name: Name to use for the class in the stub
        require_site_packages: If True, only extract from site-packages. If False, include stdlib too.
    """
    import inspect
    import textwrap

    ctor = getattr(cls, "__init__", None)
    # Classes relying on the default object.__init__ yield no useful stub.
    if ctor is None or ctor is object.__init__:
        return None

    try:
        source_file = Path(inspect.getfile(cls))
    except (OSError, TypeError):
        return None
    if require_site_packages and not path_belongs_to_site_packages(source_file):
        return None

    try:
        ctor_source = textwrap.dedent(inspect.getsource(ctor))
    except (OSError, TypeError):
        return None

    # Trim the recorded path so it is relative to site-packages when applicable.
    path_parts = source_file.parts
    if "site-packages" in path_parts:
        site_idx = path_parts.index("site-packages")
        source_file = Path(*path_parts[site_idx + 1 :])

    stub = f"class {class_name}:\n" + textwrap.indent(ctor_source, "    ")
    return CodeString(code=stub, file_path=source_file)
def _is_project_module_cached(module_name: str, project_root_path: Path, cache: dict[str, bool]) -> bool:
cached = cache.get(module_name)
if cached is not None:
return cached
is_project = _is_project_module(module_name, project_root_path)
cache[module_name] = is_project
return is_project
def is_project_path(module_path: Path | None, project_root_path: Path) -> bool:
    """Return True when module_path lives under the project root and is not an installed package."""
    if module_path is None:
        return False
    # site-packages must be checked first because .venv/site-packages is under project root
    if path_belongs_to_site_packages(module_path):
        return False
    try:
        module_path.resolve().relative_to(project_root_path.resolve())
    except ValueError:
        return False
    return True
def _is_project_module(module_name: str, project_root_path: Path) -> bool:
"""Check if a module is part of the project (not external/stdlib)."""
import importlib.util
try:
spec = importlib.util.find_spec(module_name)
except (ImportError, ModuleNotFoundError, ValueError):
return False
else:
if spec is None or spec.origin is None:
return False
return is_project_path(Path(spec.origin), project_root_path)
def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
    """Extract import statements needed for a class definition.

    This extracts imports for base classes, decorators, and type annotations.

    Args:
        module_tree: Parsed AST of the full module.
        class_node: The class whose required imports are being collected.
        module_source: Source text the AST was parsed from (used to recover import text).

    Returns:
        The matching import statements joined by newlines (possibly empty).
    """
    needed_names: set[str] = set()

    # Get base class names
    for base in class_node.bases:
        if isinstance(base, ast.Name):
            needed_names.add(base.id)
        elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name):
            # For things like abc.ABC, we need the module name
            needed_names.add(base.value.id)

    # Get decorator names (e.g., dataclass, field)
    for decorator in class_node.decorator_list:
        if isinstance(decorator, ast.Name):
            needed_names.add(decorator.id)
        elif isinstance(decorator, ast.Call):
            if isinstance(decorator.func, ast.Name):
                needed_names.add(decorator.func.id)
            elif isinstance(decorator.func, ast.Attribute) and isinstance(decorator.func.value, ast.Name):
                needed_names.add(decorator.func.value.id)

    # Get type annotation names from class body (for dataclass fields)
    for item in class_node.body:
        if isinstance(item, ast.AnnAssign) and item.annotation:
            collect_names_from_annotation(item.annotation, needed_names)
        # Also check for field() calls which are common in dataclasses
        elif isinstance(item, ast.Assign) and isinstance(item.value, ast.Call):
            if isinstance(item.value.func, ast.Name):
                needed_names.add(item.value.func.id)

    # Find imports that provide these names
    import_lines: list[str] = []
    source_lines = module_source.split("\n")
    added_imports: set[int] = set()  # Track line numbers to avoid duplicates

    def _import_text(node: ast.stmt) -> str:
        # Use the full source segment so parenthesized multi-line imports are
        # captured whole (the old single-line slice produced broken snippets);
        # fall back to the first physical line if positions are unavailable.
        segment = ast.get_source_segment(module_source, node)
        if segment is not None:
            return segment
        return source_lines[node.lineno - 1]

    for node in module_tree.body:
        if isinstance(node, ast.Import):
            for alias in node.names:
                name = alias.asname if alias.asname else alias.name.split(".")[0]
                if name in needed_names and node.lineno not in added_imports:
                    import_lines.append(_import_text(node))
                    added_imports.add(node.lineno)
                    break
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                name = alias.asname if alias.asname else alias.name
                if name in needed_names and node.lineno not in added_imports:
                    import_lines.append(_import_text(node))
                    added_imports.add(node.lineno)
                    break

    return "\n".join(import_lines)
def collect_names_from_annotation(node: ast.expr, names: set[str]) -> None:
    """Recursively collect type annotation names from an AST node.

    Args:
        node: Annotation expression to walk.
        names: Output set; referenced type names are added in place.
    """
    if isinstance(node, ast.Name):
        names.add(node.id)
    elif isinstance(node, ast.Subscript):
        collect_names_from_annotation(node.value, names)
        collect_names_from_annotation(node.slice, names)
    elif isinstance(node, (ast.Tuple, ast.List)):
        # ast.List appears in annotations like Callable[[Arg], Ret], where the
        # parameter list is a literal list node; without it, Arg was dropped.
        for elt in node.elts:
            collect_names_from_annotation(elt, names)
    elif isinstance(node, ast.BinOp):  # For Union types with | syntax
        collect_names_from_annotation(node.left, names)
        collect_names_from_annotation(node.right, names)
    elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
        names.add(node.value.id)
def is_dunder_method(name: str) -> bool:
    """Return True for ASCII dunder names such as ``__init__`` (longer than 4 chars, wrapped in double underscores)."""
    return name.startswith("__") and name.endswith("__") and name.isascii() and len(name) > 4
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
"""Removes the docstring from an indented block if it exists."""
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):

View file

@ -10,8 +10,8 @@ from typing import TYPE_CHECKING, Optional, Union
import libcst as cst
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.languages import is_javascript
from codeflash.languages import is_python
from codeflash.languages.python.static_analysis.code_replacer import replace_function_definitions_in_module
from codeflash.models.models import CodeString, CodeStringsMarkdown
if TYPE_CHECKING:
@ -587,17 +587,20 @@ def revert_unused_helper_functions(
logger.debug(f"Reverting {len(unused_helpers)} unused helper function(s) to original definitions")
# Resolve all path keys for consistent comparison (Windows 8.3 short names may differ from Jedi-resolved paths)
resolved_original_helper_code = {p.resolve(): code for p, code in original_helper_code.items()}
# Group unused helpers by file path
unused_helpers_by_file = defaultdict(list)
for helper in unused_helpers:
unused_helpers_by_file[helper.file_path].append(helper)
unused_helpers_by_file[helper.file_path.resolve()].append(helper)
# For each file, revert the unused helper functions to their original definitions
for file_path, helpers_in_file in unused_helpers_by_file.items():
if file_path in original_helper_code:
if file_path in resolved_original_helper_code:
try:
# Get original code for this file
original_code = original_helper_code[file_path]
original_code = resolved_original_helper_code[file_path]
# Use the code replacer to selectively revert only the unused helper functions
helper_names = [helper.qualified_name for helper in helpers_in_file]
@ -640,15 +643,31 @@ def _analyze_imports_in_optimized_code(
helpers_by_file_and_func = defaultdict(dict)
helpers_by_file = defaultdict(list) # preserved for "import module"
for helper in code_context.helper_functions:
jedi_type = helper.jedi_definition.type if helper.jedi_definition else None
if jedi_type != "class": # Include when jedi_definition is None (non-Python)
jedi_type = helper.definition_type
if jedi_type != "class": # Include when definition_type is None (non-Python)
func_name = helper.only_function_name
module_name = helper.file_path.stem
# Cache function lookup for this (module, func)
helpers_by_file_and_func[module_name].setdefault(func_name, []).append(helper)
helpers_by_file[module_name].append(helper)
for node in ast.walk(optimized_ast):
# Collect only import nodes to avoid per-node isinstance checks across the whole AST
class _ImportCollector(ast.NodeVisitor):
def __init__(self) -> None:
self.nodes: list[ast.AST] = []
def visit_Import(self, node: ast.Import) -> None:
self.nodes.append(node)
# No need to recurse further for import nodes
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
self.nodes.append(node)
# No need to recurse further for import-from nodes
collector = _ImportCollector()
collector.visit(optimized_ast)
for node in collector.nodes:
if isinstance(node, ast.ImportFrom):
# Handle "from module import function" statements
module_name = node.module
@ -728,8 +747,8 @@ def detect_unused_helper_functions(
"""
# Skip this analysis for non-Python languages since we use Python's ast module
if is_javascript():
logger.debug("Skipping unused helper function detection for JavaScript/TypeScript")
if not is_python():
logger.debug("Skipping unused helper function detection for non-Python languages")
return []
if isinstance(optimized_code, CodeStringsMarkdown) and len(optimized_code.code_strings) > 0:
@ -799,8 +818,8 @@ def detect_unused_helper_functions(
unused_helpers = []
entrypoint_file_path = function_to_optimize.file_path
for helper_function in code_context.helper_functions:
jedi_type = helper_function.jedi_definition.type if helper_function.jedi_definition else None
if jedi_type != "class": # Include when jedi_definition is None (non-Python)
jedi_type = helper_function.definition_type
if jedi_type != "class": # Include when definition_type is None (non-Python)
# Check if the helper function is called using multiple name variants
helper_qualified_name = helper_function.qualified_name
helper_simple_name = helper_function.only_function_name

View file

@ -0,0 +1,544 @@
from __future__ import annotations
import hashlib
import os
import sqlite3
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
from codeflash.languages.base import IndexResult
from codeflash.models.models import FunctionSource
if TYPE_CHECKING:
from collections.abc import Callable, Iterable
from jedi.api.classes import Name
# ---------------------------------------------------------------------------
# Module-level helpers (must be top-level for ProcessPoolExecutor pickling)
# ---------------------------------------------------------------------------
# TODO: create call graph.
# Minimum number of uncached files before ReferenceGraph.build_index switches
# from serial indexing to a ProcessPoolExecutor.
_PARALLEL_THRESHOLD = 8

# Per-worker state, initialised by _init_index_worker in child processes
_worker_jedi_project: object | None = None
_worker_project_root_str: str | None = None
def _init_index_worker(project_root: str) -> None:
    """ProcessPoolExecutor initializer: build one jedi.Project per worker process.

    Stores the project and its root path in module-level globals so that
    _index_file_worker, running in the same child process, can reuse them.
    """
    import jedi

    global _worker_jedi_project, _worker_project_root_str
    _worker_jedi_project = jedi.Project(path=project_root)
    _worker_project_root_str = project_root
def _resolve_definitions(ref: Name) -> list[Name]:
try:
inferred = ref.infer()
valid = [d for d in inferred if d.type in ("function", "class")]
if valid:
return valid
except Exception:
pass
try:
result: list[Name] = ref.goto(follow_imports=True, follow_builtin_imports=False)
return result
except Exception:
return []
def _is_valid_definition(definition: Name, caller_qualified_name: str, project_root_str: str) -> bool:
definition_path = definition.module_path
if definition_path is None:
return False
if not str(definition_path).startswith(project_root_str + os.sep):
return False
if path_belongs_to_site_packages(definition_path):
return False
if not definition.full_name or not definition.full_name.startswith(definition.module_name):
return False
if definition.type not in ("function", "class"):
return False
try:
def_qn = get_qualified_name(definition.module_name, definition.full_name)
if def_qn == caller_qualified_name:
return False
except ValueError:
return False
try:
from codeflash.optimization.function_context import belongs_to_function_qualified
if belongs_to_function_qualified(definition, caller_qualified_name):
return False
except Exception:
pass
return True
def _get_enclosing_function_qn(ref: Name) -> str | None:
try:
parent = ref.parent()
if parent is None or parent.type != "function":
return None
if not parent.full_name or not parent.full_name.startswith(parent.module_name):
return None
return get_qualified_name(parent.module_name, parent.full_name)
except (ValueError, AttributeError):
return None
def _analyze_file(file_path: Path, jedi_project: object, project_root_str: str) -> tuple[set[tuple[str, ...]], bool]:
    """Pure Jedi analysis — no DB access. Returns (edges, had_error).

    Each edge is an 8-tuple:
        (caller_file, caller_qualified_name, callee_file, callee_qualified_name,
         callee_fully_qualified_name, callee_only_function_name,
         callee_definition_type, callee_source_line)
    which matches the non-scope columns of the call_edges table.
    """
    import jedi

    resolved = str(file_path.resolve())
    try:
        script = jedi.Script(path=file_path, project=jedi_project)
        refs = script.get_names(all_scopes=True, definitions=False, references=True)
    except Exception:
        # Parse/analysis failure: report had_error=True with no edges.
        return set(), True
    edges: set[tuple[str, ...]] = set()
    for ref in refs:
        try:
            # Only references made from inside some function body form edges.
            caller_qn = _get_enclosing_function_qn(ref)
            if caller_qn is None:
                continue
            definitions = _resolve_definitions(ref)
            if not definitions:
                continue
            definition = definitions[0]
            definition_path = definition.module_path
            if definition_path is None:
                continue
            if not _is_valid_definition(definition, caller_qn, project_root_str):
                continue
            edge_base = (resolved, caller_qn, str(definition_path))
            if definition.type == "function":
                callee_qn = get_qualified_name(definition.module_name, definition.full_name)
                # Skip nested definitions: only module.func / Class.func depth.
                if len(callee_qn.split(".")) > 2:
                    continue
                edges.add(
                    (
                        *edge_base,
                        callee_qn,
                        definition.full_name,
                        definition.name,
                        definition.type,
                        definition.get_line_code(),
                    )
                )
            elif definition.type == "class":
                # A class reference is recorded as a call to its __init__.
                init_qn = get_qualified_name(definition.module_name, f"{definition.full_name}.__init__")
                if len(init_qn.split(".")) > 2:
                    continue
                edges.add(
                    (
                        *edge_base,
                        init_qn,
                        f"{definition.full_name}.__init__",
                        "__init__",
                        definition.type,
                        definition.get_line_code(),
                    )
                )
        except Exception:
            # Defensive: one bad reference must not abort the whole file.
            continue
    return edges, False
def _index_file_worker(args: tuple[str, str]) -> tuple[str, str, set[tuple[str, ...]], bool]:
    """Worker entry point for ProcessPoolExecutor.

    Args:
        args: (file_path, file_hash) pair; the hash is passed through untouched
            so the parent process can persist it alongside the edges.

    Returns:
        (file_path, file_hash, edges, had_error), with edges as produced by
        _analyze_file.
    """
    file_path_str, file_hash = args
    # Worker globals are populated by _init_index_worker before any task runs.
    assert _worker_project_root_str is not None
    edges, had_error = _analyze_file(Path(file_path_str), _worker_jedi_project, _worker_project_root_str)
    return file_path_str, file_hash, edges, had_error
# ---------------------------------------------------------------------------
class ReferenceGraph:
SCHEMA_VERSION = 2
def __init__(self, project_root: Path, language: str = "python", db_path: Path | None = None) -> None:
    """Open (or create) the on-disk call-edge cache for one project/language.

    Args:
        project_root: Root of the project to index; resolved once and used to
            scope all DB rows.
        language: Language tag stored with every row.
        db_path: Optional SQLite file; falls back to the shared codeflash cache DB.
    """
    import jedi

    self.project_root = project_root.resolve()
    self.project_root_str = str(self.project_root)
    self.language = language
    self.jedi_project = jedi.Project(path=self.project_root)
    if db_path is None:
        from codeflash.code_utils.compat import codeflash_cache_db

        db_path = codeflash_cache_db
    self.conn = sqlite3.connect(str(db_path))
    # WAL lets concurrent readers coexist with the indexing writer.
    self.conn.execute("PRAGMA journal_mode=WAL")
    # In-memory mirror of indexed_files rows, filled as cache hits occur.
    self.indexed_file_hashes: dict[str, str] = {}
    self._init_schema()
def _init_schema(self) -> None:
cur = self.conn.cursor()
cur.execute("CREATE TABLE IF NOT EXISTS cg_schema_version (version INTEGER PRIMARY KEY)")
row = cur.execute("SELECT version FROM cg_schema_version LIMIT 1").fetchone()
if row is None:
cur.execute("INSERT INTO cg_schema_version (version) VALUES (?)", (self.SCHEMA_VERSION,))
elif row[0] != self.SCHEMA_VERSION:
for table in [
"cg_call_edges",
"cg_indexed_files",
"cg_languages",
"cg_projects",
"cg_project_meta",
"indexed_files",
"call_edges",
]:
cur.execute(f"DROP TABLE IF EXISTS {table}")
cur.execute("DELETE FROM cg_schema_version")
cur.execute("INSERT INTO cg_schema_version (version) VALUES (?)", (self.SCHEMA_VERSION,))
cur.execute(
"""
CREATE TABLE IF NOT EXISTS indexed_files (
project_root TEXT NOT NULL,
language TEXT NOT NULL,
file_path TEXT NOT NULL,
file_hash TEXT NOT NULL,
PRIMARY KEY (project_root, language, file_path)
)
"""
)
cur.execute(
"""
CREATE TABLE IF NOT EXISTS call_edges (
project_root TEXT NOT NULL,
language TEXT NOT NULL,
caller_file TEXT NOT NULL,
caller_qualified_name TEXT NOT NULL,
callee_file TEXT NOT NULL,
callee_qualified_name TEXT NOT NULL,
callee_fully_qualified_name TEXT NOT NULL,
callee_only_function_name TEXT NOT NULL,
callee_definition_type TEXT NOT NULL,
callee_source_line TEXT NOT NULL,
PRIMARY KEY (project_root, language, caller_file, caller_qualified_name,
callee_file, callee_qualified_name)
)
"""
)
cur.execute(
"""
CREATE INDEX IF NOT EXISTS idx_call_edges_caller
ON call_edges (project_root, language, caller_file, caller_qualified_name)
"""
)
self.conn.commit()
def get_callees(
self, file_path_to_qualified_names: dict[Path, set[str]]
) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]:
file_path_to_function_source: dict[Path, set[FunctionSource]] = defaultdict(set)
function_source_list: list[FunctionSource] = []
all_caller_keys: list[tuple[str, str]] = []
for file_path, qualified_names in file_path_to_qualified_names.items():
resolved = str(file_path.resolve())
self.ensure_file_indexed(file_path, resolved)
all_caller_keys.extend((resolved, qn) for qn in qualified_names)
if not all_caller_keys:
return file_path_to_function_source, function_source_list
cur = self.conn.cursor()
cur.execute("CREATE TEMP TABLE IF NOT EXISTS _caller_keys (caller_file TEXT, caller_qualified_name TEXT)")
cur.execute("DELETE FROM _caller_keys")
cur.executemany("INSERT INTO _caller_keys VALUES (?, ?)", all_caller_keys)
rows = cur.execute(
"""
SELECT ce.callee_file, ce.callee_qualified_name, ce.callee_fully_qualified_name,
ce.callee_only_function_name, ce.callee_definition_type, ce.callee_source_line
FROM call_edges ce
INNER JOIN _caller_keys ck
ON ce.caller_file = ck.caller_file AND ce.caller_qualified_name = ck.caller_qualified_name
WHERE ce.project_root = ? AND ce.language = ?
""",
(self.project_root_str, self.language),
).fetchall()
for callee_file, callee_qn, callee_fqn, callee_name, callee_type, callee_src in rows:
callee_path = Path(callee_file)
fs = FunctionSource(
file_path=callee_path,
qualified_name=callee_qn,
fully_qualified_name=callee_fqn,
only_function_name=callee_name,
source_code=callee_src,
definition_type=callee_type,
)
file_path_to_function_source[callee_path].add(fs)
function_source_list.append(fs)
return file_path_to_function_source, function_source_list
def count_callees_per_function(
self, file_path_to_qualified_names: dict[Path, set[str]]
) -> dict[tuple[Path, str], int]:
all_caller_keys: list[tuple[Path, str, str]] = []
for file_path, qualified_names in file_path_to_qualified_names.items():
resolved = str(file_path.resolve())
self.ensure_file_indexed(file_path, resolved)
all_caller_keys.extend((file_path, resolved, qn) for qn in qualified_names)
if not all_caller_keys:
return {}
cur = self.conn.cursor()
cur.execute("CREATE TEMP TABLE IF NOT EXISTS _count_keys (caller_file TEXT, caller_qualified_name TEXT)")
cur.execute("DELETE FROM _count_keys")
cur.executemany(
"INSERT INTO _count_keys VALUES (?, ?)", [(resolved, qn) for _, resolved, qn in all_caller_keys]
)
rows = cur.execute(
"""
SELECT ck.caller_file, ck.caller_qualified_name, COUNT(ce.rowid)
FROM _count_keys ck
LEFT JOIN call_edges ce
ON ce.caller_file = ck.caller_file AND ce.caller_qualified_name = ck.caller_qualified_name
AND ce.project_root = ? AND ce.language = ?
GROUP BY ck.caller_file, ck.caller_qualified_name
""",
(self.project_root_str, self.language),
).fetchall()
resolved_to_path: dict[str, Path] = {resolved: fp for fp, resolved, _ in all_caller_keys}
counts: dict[tuple[Path, str], int] = {}
for caller_file, caller_qn, cnt in rows:
counts[(resolved_to_path[caller_file], caller_qn)] = cnt
return counts
def ensure_file_indexed(self, file_path: Path, resolved: str | None = None) -> IndexResult:
    """Index file_path unless its current content hash is already cached.

    The file is always re-read and re-hashed before consulting the cache so
    that on-disk edits are detected even when a cache entry exists.
    """
    resolved_path = resolved if resolved is not None else str(file_path.resolve())
    try:
        text = file_path.read_text(encoding="utf-8")
    except Exception:
        # Unreadable file: surface an error result without touching the DB.
        return IndexResult(file_path=file_path, cached=False, num_edges=0, edges=(), cross_file_edges=0, error=True)
    digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
    if self._is_file_cached(resolved_path, digest):
        return IndexResult(file_path=file_path, cached=True, num_edges=0, edges=(), cross_file_edges=0, error=False)
    return self.index_file(file_path, digest, resolved_path)
def index_file(self, file_path: Path, file_hash: str, resolved: str | None = None) -> IndexResult:
    """Re-analyze file_path with jedi and persist its call edges under file_hash.

    Unlike ensure_file_indexed, this always re-runs the analysis; the caller
    supplies the content hash to record.
    """
    if resolved is None:
        resolved = str(file_path.resolve())
    edges, had_error = _analyze_file(file_path, self.jedi_project, self.project_root_str)
    if had_error:
        logger.debug(f"ReferenceGraph: failed to parse {file_path}")
    # Persist even on error so the file is recorded as indexed for this hash.
    return self._persist_edges(file_path, resolved, file_hash, edges, had_error)
def _persist_edges(
    self, file_path: Path, resolved: str, file_hash: str, edges: set[tuple[str, ...]], had_error: bool
) -> IndexResult:
    """Replace this file's rows in call_edges/indexed_files and build an IndexResult.

    Args:
        file_path: Original (possibly relative) path, echoed into the result.
        resolved: Absolute path string used as the DB key.
        file_hash: Content hash recorded so future runs can skip unchanged files.
        edges: 8-tuples of (caller_file, caller_qn, callee_file, callee_qn,
            callee_fqn, callee_name, callee_type, callee_source_line).
        had_error: When True the file is still marked indexed, but no edges
            are written.
    """
    cur = self.conn.cursor()
    scope = (self.project_root_str, self.language)
    # Clear existing data for this file
    cur.execute(
        "DELETE FROM call_edges WHERE project_root = ? AND language = ? AND caller_file = ?", (*scope, resolved)
    )
    cur.execute(
        "DELETE FROM indexed_files WHERE project_root = ? AND language = ? AND file_path = ?", (*scope, resolved)
    )
    # Insert new edges if parsing succeeded
    if not had_error and edges:
        cur.executemany(
            """
            INSERT OR REPLACE INTO call_edges
            (project_root, language, caller_file, caller_qualified_name,
             callee_file, callee_qualified_name, callee_fully_qualified_name,
             callee_only_function_name, callee_definition_type, callee_source_line)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            [(*scope, *edge) for edge in edges],
        )
    # Record that this file has been indexed
    cur.execute(
        "INSERT OR REPLACE INTO indexed_files (project_root, language, file_path, file_hash) VALUES (?, ?, ?, ?)",
        (*scope, resolved, file_hash),
    )
    self.conn.commit()
    self.indexed_file_hashes[resolved] = file_hash
    # Build summary for return value
    # (unpack positions 0, 1, 2, 5 of each 8-tuple: caller file/qn, callee file/name)
    edges_summary = tuple(
        (caller_qn, callee_name, caller_file != callee_file)
        for (caller_file, caller_qn, callee_file, _, _, callee_name, _, _) in edges
    )
    cross_file_count = sum(is_cross_file for _, _, is_cross_file in edges_summary)
    return IndexResult(
        file_path=file_path,
        cached=False,
        num_edges=len(edges),
        edges=edges_summary,
        cross_file_edges=cross_file_count,
        error=had_error,
    )
def build_index(self, file_paths: Iterable[Path], on_progress: Callable[[IndexResult], None] | None = None) -> None:
"""Pre-index a batch of files, using multiprocessing for large uncached batches."""
to_index: list[tuple[Path, str, str]] = []
for file_path in file_paths:
resolved = str(file_path.resolve())
try:
content = file_path.read_text(encoding="utf-8")
except Exception:
self._report_progress(
on_progress,
IndexResult(
file_path=file_path, cached=False, num_edges=0, edges=(), cross_file_edges=0, error=True
),
)
continue
file_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()
# Check if already cached (in-memory or DB)
if self._is_file_cached(resolved, file_hash):
self._report_progress(
on_progress,
IndexResult(
file_path=file_path, cached=True, num_edges=0, edges=(), cross_file_edges=0, error=False
),
)
continue
to_index.append((file_path, resolved, file_hash))
if not to_index:
return
# Index uncached files
if len(to_index) >= _PARALLEL_THRESHOLD:
self._build_index_parallel(to_index, on_progress)
else:
for file_path, resolved, file_hash in to_index:
result = self.index_file(file_path, file_hash, resolved)
self._report_progress(on_progress, result)
def _is_file_cached(self, resolved: str, file_hash: str) -> bool:
"""Check if file is cached in memory or DB."""
if self.indexed_file_hashes.get(resolved) == file_hash:
return True
row = self.conn.execute(
"SELECT file_hash FROM indexed_files WHERE project_root = ? AND language = ? AND file_path = ?",
(self.project_root_str, self.language, resolved),
).fetchone()
if row and row[0] == file_hash:
self.indexed_file_hashes[resolved] = file_hash
return True
return False
def _report_progress(self, on_progress: Callable[[IndexResult], None] | None, result: IndexResult) -> None:
"""Report progress if callback provided."""
if on_progress is not None:
on_progress(result)
    def _build_index_parallel(
        self, to_index: list[tuple[Path, str, str]], on_progress: Callable[[IndexResult], None] | None
    ) -> None:
        """Index ``to_index`` with a process pool, persisting results in the parent.

        Workers only analyze (``_index_file_worker``); edges are persisted here so
        all DB writes go through this process's connection. A crashed worker is
        recorded as an errored, empty entry for its file; if the pool itself
        fails, files not yet persisted are retried sequentially.
        """
        from concurrent.futures import ProcessPoolExecutor, as_completed

        # Cap workers at CPU count, batch size, and a hard limit of 8.
        max_workers = min(os.cpu_count() or 1, len(to_index), 8)
        # Map resolved path -> (Path, hash) so future results can be matched back.
        path_info: dict[str, tuple[Path, str]] = {resolved: (fp, fh) for fp, resolved, fh in to_index}
        worker_args = [(resolved, fh) for _fp, resolved, fh in to_index]
        logger.debug(f"ReferenceGraph: indexing {len(to_index)} files across {max_workers} workers")
        try:
            with ProcessPoolExecutor(
                max_workers=max_workers, initializer=_init_index_worker, initargs=(self.project_root_str,)
            ) as executor:
                futures = {executor.submit(_index_file_worker, args): args[0] for args in worker_args}
                for future in as_completed(futures):
                    resolved = futures[future]
                    file_path, file_hash = path_info[resolved]
                    try:
                        _, _, edges, had_error = future.result()
                    except Exception:
                        # Worker crashed: persist an errored record so the
                        # sequential fallback will not re-attempt this file.
                        logger.debug(f"ReferenceGraph: worker failed for {file_path}")
                        self._persist_edges(file_path, resolved, file_hash, set(), had_error=True)
                        self._report_progress(
                            on_progress,
                            IndexResult(
                                file_path=file_path, cached=False, num_edges=0, edges=(), cross_file_edges=0, error=True
                            ),
                        )
                        continue
                    if had_error:
                        logger.debug(f"ReferenceGraph: failed to parse {file_path}")
                    result = self._persist_edges(file_path, resolved, file_hash, edges, had_error)
                    self._report_progress(on_progress, result)
        except Exception:
            # Pool-level failure (e.g. spawn error): redo the whole batch serially;
            # the fallback skips files already persisted above.
            logger.debug("ReferenceGraph: parallel indexing failed, falling back to sequential")
            self._fallback_sequential_index(to_index, on_progress)
def _fallback_sequential_index(
self, to_index: list[tuple[Path, str, str]], on_progress: Callable[[IndexResult], None] | None
) -> None:
"""Fallback to sequential indexing when parallel processing fails."""
for file_path, resolved, file_hash in to_index:
# Skip files already persisted before the failure
if resolved in self.indexed_file_hashes:
continue
result = self.index_file(file_path, file_hash, resolved)
self._report_progress(on_progress, result)
    def close(self) -> None:
        """Close the underlying database connection used by the index."""
        self.conn.close()

View file

@ -1659,6 +1659,13 @@ def _format_references_as_markdown(references: list, file_path: Path, project_ro
refs_by_file[ref.file_path] = []
refs_by_file[ref.file_path].append(ref)
from codeflash.languages.registry import get_language_support
try:
lang_support = get_language_support(language)
except Exception:
lang_support = None
fn_call_context = ""
context_len = 0
@ -1700,7 +1707,11 @@ def _format_references_as_markdown(references: list, file_path: Path, project_ro
# Extract context around the reference
if ref.caller_function:
# Try to extract the full calling function
func_code = _extract_calling_function(file_content, ref.caller_function, ref.line, language)
func_code = None
if lang_support is not None:
func_code = lang_support.extract_calling_function_source(
file_content, ref.caller_function, ref.line
)
if func_code:
caller_contexts.append(func_code)
context_len += len(func_code)
@ -1718,77 +1729,3 @@ def _format_references_as_markdown(references: list, file_path: Path, project_ro
fn_call_context += "\n```\n"
return fn_call_context
def _extract_calling_function(source_code: str, function_name: str, ref_line: int, language: Language) -> str | None:
    """Extract the source of a calling function, dispatching on language.

    Args:
        source_code: Full source code of the file.
        function_name: Name of the function to extract.
        ref_line: Line number where the reference is.
        language: The programming language.

    Returns:
        Source code of the function, or None if not found.
    """
    # Python gets the ast-based extractor; everything else goes through tree-sitter.
    extractor = _extract_calling_function_python if language == Language.PYTHON else _extract_calling_function_js
    return extractor(source_code, function_name, ref_line)
def _extract_calling_function_python(source_code: str, function_name: str, ref_line: int) -> str | None:
"""Extract the source code of a calling function in Python."""
try:
import ast
tree = ast.parse(source_code)
lines = source_code.splitlines()
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
if node.name == function_name:
# Check if the reference line is within this function
start_line = node.lineno
end_line = node.end_lineno or start_line
if start_line <= ref_line <= end_line:
return "\n".join(lines[start_line - 1 : end_line])
return None
except Exception:
return None
def _extract_calling_function_js(source_code: str, function_name: str, ref_line: int) -> str | None:
    """Extract the source code of a calling function in JavaScript/TypeScript.

    Args:
        source_code: Full source code of the file.
        function_name: Name of the function to extract.
        ref_line: Line number where the reference is (helps identify the right function).

    Returns:
        Source code of the function, or None if not found.
    """
    try:
        from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage

        # Try TypeScript first, fall back to JavaScript
        for lang in [TreeSitterLanguage.TYPESCRIPT, TreeSitterLanguage.TSX, TreeSitterLanguage.JAVASCRIPT]:
            try:
                analyzer = TreeSitterAnalyzer(lang)
                functions = analyzer.find_functions(source_code, include_methods=True)
                for func in functions:
                    if func.name == function_name:
                        # Check if the reference line is within this function
                        if func.start_line <= ref_line <= func.end_line:
                            return func.source_text
                # NOTE(review): this break stops after the first grammar that parses,
                # even when no matching function was found; later grammars are only
                # tried when parsing raises — confirm this is intended.
                break
            except Exception:
                continue
        return None
    except Exception:
        # Tree-sitter unavailable or analysis failed entirely — best effort only.
        return None

View file

@ -10,23 +10,22 @@ import libcst as cst
from libcst.metadata import PositionProvider
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import (
from codeflash.code_utils.config_parser import find_conftest_files
from codeflash.code_utils.formatter import sort_imports
from codeflash.languages import is_python
from codeflash.languages.python.static_analysis.code_extractor import (
add_global_assignments,
add_needed_imports_from_module,
find_insertion_index_after_imports,
)
from codeflash.code_utils.config_parser import find_conftest_files
from codeflash.code_utils.formatter import sort_imports
from codeflash.code_utils.line_profile_utils import ImportAdder
from codeflash.languages import is_python
from codeflash.languages.python.static_analysis.line_profile_utils import ImportAdder
from codeflash.models.models import FunctionParent
if TYPE_CHECKING:
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language, LanguageSupport
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer
from codeflash.languages.base import LanguageSupport
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, OptimizedCandidate, ValidCode
ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST)
@ -240,149 +239,6 @@ def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None:
test_path.write_text(modified_module.code, encoding="utf-8")
class OptimFunctionCollector(cst.CSTVisitor):
    """CST visitor that partitions optimized code into modified vs. new definitions.

    Walking an optimized module populates:
      - ``modified_functions``: (class_name, function_name) targets listed in
        ``function_names``.
      - ``modified_init_functions``: ``__init__`` methods of visited classes.
      - ``new_functions`` / ``new_classes`` / ``new_class_functions``:
        definitions absent from ``preexisting_objects``.
    """

    METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)

    def __init__(
        self,
        preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] | None = None,
        function_names: set[tuple[str | None, str]] | None = None,
    ) -> None:
        super().__init__()
        self.preexisting_objects = preexisting_objects if preexisting_objects is not None else set()
        self.function_names = function_names  # set of (class_name, function_name)
        self.modified_functions: dict[
            tuple[str | None, str], cst.FunctionDef
        ] = {}  # keys are (class_name, function_name)
        self.new_functions: list[cst.FunctionDef] = []
        self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list)
        self.new_classes: list[cst.ClassDef] = []
        # Name of the class currently being visited, or None at module level.
        self.current_class = None
        self.modified_init_functions: dict[str, cst.FunctionDef] = {}

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
        # Classify the function; never recurse into its body (return False).
        if (self.current_class, node.name.value) in self.function_names:
            self.modified_functions[(self.current_class, node.name.value)] = node
        elif self.current_class and node.name.value == "__init__":
            self.modified_init_functions[self.current_class] = node
        elif (
            self.preexisting_objects
            and (node.name.value, ()) not in self.preexisting_objects
            and self.current_class is None
        ):
            # Top-level function that did not exist in the original code.
            self.new_functions.append(node)
        return False

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        if self.current_class:
            return False  # If already in a class, do not recurse deeper
        self.current_class = node.name.value
        parents = (FunctionParent(name=node.name.value, type="ClassDef"),)
        if (node.name.value, ()) not in self.preexisting_objects:
            self.new_classes.append(node)
        for child_node in node.body.body:
            if (
                self.preexisting_objects
                and isinstance(child_node, cst.FunctionDef)
                and (child_node.name.value, parents) not in self.preexisting_objects
            ):
                # Method that did not exist on this class in the original code.
                self.new_class_functions[node.name.value].append(child_node)
        return True

    def leave_ClassDef(self, node: cst.ClassDef) -> None:
        if self.current_class:
            self.current_class = None
class OptimFunctionReplacer(cst.CSTTransformer):
    """CST transformer that splices collected optimized definitions into a module.

    Replaces bodies/decorators of functions in ``modified_functions`` (and
    ``__init__`` methods in ``modified_init_functions``), appends new methods to
    existing classes, and inserts new classes/functions at module level in
    ``leave_Module``.
    """

    def __init__(
        self,
        modified_functions: Optional[dict[tuple[str | None, str], cst.FunctionDef]] = None,
        new_classes: Optional[list[cst.ClassDef]] = None,
        new_functions: Optional[list[cst.FunctionDef]] = None,
        new_class_functions: Optional[dict[str, list[cst.FunctionDef]]] = None,
        modified_init_functions: Optional[dict[str, cst.FunctionDef]] = None,
    ) -> None:
        super().__init__()
        self.modified_functions = modified_functions if modified_functions is not None else {}
        self.new_functions = new_functions if new_functions is not None else []
        self.new_classes = new_classes if new_classes is not None else []
        self.new_class_functions = new_class_functions if new_class_functions is not None else defaultdict(list)
        self.modified_init_functions: dict[str, cst.FunctionDef] = (
            modified_init_functions if modified_init_functions is not None else {}
        )
        # Name of the class currently being traversed, or None at module level.
        self.current_class = None

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
        # Never descend into function bodies; replacement happens in leave_FunctionDef.
        return False

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
        if (self.current_class, original_node.name.value) in self.modified_functions:
            # Swap in the optimized body and decorators, keeping the original signature.
            node = self.modified_functions[(self.current_class, original_node.name.value)]
            return updated_node.with_changes(body=node.body, decorators=node.decorators)
        if original_node.name.value == "__init__" and self.current_class in self.modified_init_functions:
            return self.modified_init_functions[self.current_class]
        return updated_node

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        if self.current_class:
            return False  # If already in a class, do not recurse deeper
        self.current_class = node.name.value
        return True

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
        if self.current_class and self.current_class == original_node.name.value:
            self.current_class = None
        if original_node.name.value in self.new_class_functions:
            # Append newly introduced methods at the end of the class body.
            return updated_node.with_changes(
                body=updated_node.body.with_changes(
                    body=(list(updated_node.body.body) + list(self.new_class_functions[original_node.name.value]))
                )
            )
        return updated_node

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        node = updated_node
        # Track the last top-level function/class so insertions keep a natural order.
        max_function_index = None
        max_class_index = None
        for index, _node in enumerate(node.body):
            if isinstance(_node, cst.FunctionDef):
                max_function_index = index
            if isinstance(_node, cst.ClassDef):
                max_class_index = index
        if self.new_classes:
            # Only insert classes whose names are not already defined at module level.
            existing_class_names = {_node.name.value for _node in node.body if isinstance(_node, cst.ClassDef)}
            unique_classes = [
                new_class for new_class in self.new_classes if new_class.name.value not in existing_class_names
            ]
            if unique_classes:
                # Bug fix: the previous `max_class_index or ...` treated a class at
                # module index 0 as "no class" and mis-placed the insertion; use an
                # explicit None check (matching replace_functions_in_file).
                new_classes_insertion_idx = (
                    max_class_index if max_class_index is not None else find_insertion_index_after_imports(node)
                )
                new_body = list(
                    chain(node.body[:new_classes_insertion_idx], unique_classes, node.body[new_classes_insertion_idx:])
                )
                node = node.with_changes(body=new_body)
        # New top-level functions go after the last function, else after the last
        # class, else at the start of the module.
        if max_function_index is not None:
            node = node.with_changes(
                body=(*node.body[: max_function_index + 1], *self.new_functions, *node.body[max_function_index + 1 :])
            )
        elif max_class_index is not None:
            node = node.with_changes(
                body=(*node.body[: max_class_index + 1], *self.new_functions, *node.body[max_class_index + 1 :])
            )
        else:
            node = node.with_changes(body=(*self.new_functions, *node.body))
        return node
def replace_functions_in_file(
source_code: str,
original_function_names: list[str],
@ -401,23 +257,114 @@ def replace_functions_in_file(
return source_code
parsed_function_names.append((class_name, function_name))
# Collect functions we want to modify from the optimized code
optimized_module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
# Collect functions from optimized code without using MetadataWrapper
optimized_module = cst.parse_module(optimized_code)
modified_functions: dict[tuple[str | None, str], cst.FunctionDef] = {}
new_functions: list[cst.FunctionDef] = []
new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list)
new_classes: list[cst.ClassDef] = []
modified_init_functions: dict[str, cst.FunctionDef] = {}
function_names_set = set(parsed_function_names)
for node in optimized_module.body:
if isinstance(node, cst.FunctionDef):
key = (None, node.name.value)
if key in function_names_set:
modified_functions[key] = node
elif preexisting_objects and (node.name.value, ()) not in preexisting_objects:
new_functions.append(node)
elif isinstance(node, cst.ClassDef):
class_name = node.name.value
parents = (FunctionParent(name=class_name, type="ClassDef"),)
if (class_name, ()) not in preexisting_objects:
new_classes.append(node)
for child in node.body.body:
if isinstance(child, cst.FunctionDef):
method_key = (class_name, child.name.value)
if method_key in function_names_set:
modified_functions[method_key] = child
elif (
child.name.value == "__init__"
and preexisting_objects
and (class_name, ()) in preexisting_objects
):
modified_init_functions[class_name] = child
elif preexisting_objects and (child.name.value, parents) not in preexisting_objects:
new_class_functions[class_name].append(child)
original_module = cst.parse_module(source_code)
visitor = OptimFunctionCollector(preexisting_objects, set(parsed_function_names))
optimized_module.visit(visitor)
max_function_index = None
max_class_index = None
for index, _node in enumerate(original_module.body):
if isinstance(_node, cst.FunctionDef):
max_function_index = index
if isinstance(_node, cst.ClassDef):
max_class_index = index
# Replace these functions in the original code
transformer = OptimFunctionReplacer(
modified_functions=visitor.modified_functions,
new_classes=visitor.new_classes,
new_functions=visitor.new_functions,
new_class_functions=visitor.new_class_functions,
modified_init_functions=visitor.modified_init_functions,
)
modified_tree = original_module.visit(transformer)
return modified_tree.code
new_body: list[cst.CSTNode] = []
existing_class_names = set()
for node in original_module.body:
if isinstance(node, cst.FunctionDef):
key = (None, node.name.value)
if key in modified_functions:
modified_func = modified_functions[key]
new_body.append(node.with_changes(body=modified_func.body, decorators=modified_func.decorators))
else:
new_body.append(node)
elif isinstance(node, cst.ClassDef):
class_name = node.name.value
existing_class_names.add(class_name)
new_members: list[cst.CSTNode] = []
for child in node.body.body:
if isinstance(child, cst.FunctionDef):
key = (class_name, child.name.value)
if key in modified_functions:
modified_func = modified_functions[key]
new_members.append(
child.with_changes(body=modified_func.body, decorators=modified_func.decorators)
)
elif child.name.value == "__init__" and class_name in modified_init_functions:
new_members.append(modified_init_functions[class_name])
else:
new_members.append(child)
else:
new_members.append(child)
if class_name in new_class_functions:
new_members.extend(new_class_functions[class_name])
new_body.append(node.with_changes(body=node.body.with_changes(body=new_members)))
else:
new_body.append(node)
if new_classes:
unique_classes = [nc for nc in new_classes if nc.name.value not in existing_class_names]
if unique_classes:
new_classes_insertion_idx = (
max_class_index if max_class_index is not None else find_insertion_index_after_imports(original_module)
)
new_body = list(
chain(new_body[:new_classes_insertion_idx], unique_classes, new_body[new_classes_insertion_idx:])
)
if new_functions:
if max_function_index is not None:
new_body = [*new_body[: max_function_index + 1], *new_functions, *new_body[max_function_index + 1 :]]
elif max_class_index is not None:
new_body = [*new_body[: max_class_index + 1], *new_functions, *new_body[max_class_index + 1 :]]
else:
new_body = [*new_functions, *new_body]
updated_module = original_module.with_changes(body=new_body)
return updated_module.code
def replace_functions_and_add_imports(
@ -509,11 +456,8 @@ def replace_function_definitions_for_language(
lang_support = get_language_support(language)
# Add any new global declarations from the optimized code to the original source
original_source_code = _add_global_declarations_for_language(
optimized_code=code_to_apply,
original_source=original_source_code,
module_abspath=module_abspath,
language=language,
original_source_code = lang_support.add_global_declarations(
optimized_code=code_to_apply, original_source=original_source_code, module_abspath=module_abspath
)
# If we have function_to_optimize with line info and this is the main file, use it for precise replacement
@ -612,204 +556,6 @@ def _extract_function_from_code(
return None
def _add_global_declarations_for_language(
    optimized_code: str, original_source: str, module_abspath: Path, language: Language
) -> str:
    """Add new global declarations from optimized code to original source.

    Finds module-level declarations (const, let, var, class, type, interface, enum)
    in the optimized code that don't exist in the original source and adds them.

    New declarations are inserted after any existing declarations they depend on.
    For example, if optimized code has `const _has = FOO.bar.bind(FOO)`, and `FOO`
    is already declared in the original source, `_has` will be inserted after `FOO`.

    Args:
        optimized_code: The optimized code that may contain new declarations.
        original_source: The original source code.
        module_abspath: Path to the module file (for parser selection).
        language: The language of the code.

    Returns:
        Original source with new declarations added in dependency order.
    """
    from codeflash.languages.base import Language

    # Only JS/TS are handled; other languages pass through untouched.
    if language not in (Language.JAVASCRIPT, Language.TYPESCRIPT):
        return original_source
    try:
        from codeflash.languages.javascript.treesitter import get_analyzer_for_file

        analyzer = get_analyzer_for_file(module_abspath)
        original_declarations = analyzer.find_module_level_declarations(original_source)
        optimized_declarations = analyzer.find_module_level_declarations(optimized_code)
        if not optimized_declarations:
            return original_source
        existing_names = _get_existing_names(original_declarations, analyzer, original_source)
        new_declarations = _filter_new_declarations(optimized_declarations, existing_names)
        if not new_declarations:
            return original_source
        # Build a map of existing declaration names to their end lines (1-indexed)
        existing_decl_end_lines = {decl.name: decl.end_line for decl in original_declarations}
        # Insert each new declaration after its dependencies
        result = original_source
        for decl in new_declarations:
            result = _insert_declaration_after_dependencies(
                result, decl, existing_decl_end_lines, analyzer, module_abspath
            )
            # Update the map with the newly inserted declaration for subsequent insertions
            # Re-parse to get accurate line numbers after insertion
            updated_declarations = analyzer.find_module_level_declarations(result)
            existing_decl_end_lines = {d.name: d.end_line for d in updated_declarations}
        return result
    except Exception as e:
        # Best effort: on any analysis failure, return the source unchanged.
        logger.debug(f"Error adding global declarations: {e}")
        return original_source
def _get_existing_names(original_declarations: list, analyzer: TreeSitterAnalyzer, original_source: str) -> set[str]:
"""Get all names that already exist in the original source (declarations + imports)."""
existing_names = {decl.name for decl in original_declarations}
original_imports = analyzer.find_imports(original_source)
for imp in original_imports:
if imp.default_import:
existing_names.add(imp.default_import)
for name, alias in imp.named_imports:
existing_names.add(alias if alias else name)
if imp.namespace_import:
existing_names.add(imp.namespace_import)
return existing_names
def _filter_new_declarations(optimized_declarations: list, existing_names: set[str]) -> list:
"""Filter declarations to only those that don't exist in the original source."""
new_declarations = []
seen_sources: set[str] = set()
# Sort by line number to maintain order from optimized code
sorted_declarations = sorted(optimized_declarations, key=lambda d: d.start_line)
for decl in sorted_declarations:
if decl.name not in existing_names and decl.source_code not in seen_sources:
new_declarations.append(decl)
seen_sources.add(decl.source_code)
return new_declarations
def _insert_declaration_after_dependencies(
    source: str,
    declaration,
    existing_decl_end_lines: dict[str, int],
    analyzer: TreeSitterAnalyzer,
    module_abspath: Path,
) -> str:
    """Insert ``declaration`` into ``source`` just after the last declaration it references.

    Args:
        source: Current source code.
        declaration: The declaration to insert (provides ``source_code``).
        existing_decl_end_lines: Map of existing declaration names to their 1-indexed end lines.
        analyzer: TreeSitter analyzer used to find referenced identifiers.
        module_abspath: Path to the module file.

    Returns:
        Source code with the declaration spliced in at the computed position.
    """
    referenced_names = analyzer.find_referenced_identifiers(declaration.source_code)
    insertion_line = _find_insertion_line_for_declaration(source, referenced_names, existing_decl_end_lines, analyzer)
    lines = source.splitlines(keepends=True)
    snippet = declaration.source_code
    if not snippet.endswith("\n"):
        snippet += "\n"
    # Separate from preceding non-blank content with one blank line.
    if insertion_line > 0 and lines[insertion_line - 1].strip():
        snippet = "\n" + snippet
    return "".join([*lines[:insertion_line], snippet, *lines[insertion_line:]])
def _find_insertion_line_for_declaration(
    source: str, referenced_names: set[str], existing_decl_end_lines: dict[str, int], analyzer: TreeSitterAnalyzer
) -> int:
    """Pick the 0-based line index at which to insert a declaration.

    Args:
        source: Source code.
        referenced_names: Names referenced by the declaration.
        existing_decl_end_lines: Map of declaration names to their end lines (1-indexed).
        analyzer: TreeSitter analyzer.

    Returns:
        Line index (0-based) where the declaration should be inserted.
    """
    # Latest end line among the declarations this one depends on.
    best = 0
    for name in referenced_names:
        line = existing_decl_end_lines.get(name)
        if line is not None and line > best:
            best = line
    if best > 0:
        # end lines are 1-indexed, so this value is already "directly after" in 0-based terms.
        return best
    # No known dependency: fall back to the first line after the import block.
    return _find_line_after_imports(source.splitlines(keepends=True), analyzer, source)
def _find_line_after_imports(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int:
"""Find the line index after all imports.
Args:
lines: Source lines.
analyzer: TreeSitter analyzer.
source: Full source code.
Returns:
Line index (0-based) for insertion after imports.
"""
try:
imports = analyzer.find_imports(source)
if imports:
return max(imp.end_line for imp in imports)
except Exception as exc:
logger.debug(f"Exception in _find_line_after_imports: {exc}")
# Default: insert at beginning (after shebang/directive comments)
for i, line in enumerate(lines):
stripped = line.strip()
if stripped and not stripped.startswith("//") and not stripped.startswith("#!"):
return i
return 0
def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStringsMarkdown) -> str:
file_to_code_context = optimized_code.file_to_path()
module_optimized_code = file_to_code_context.get(str(relative_path))
@ -871,8 +617,7 @@ def replace_optimized_code(
[
callee.qualified_name
for callee in code_context.helper_functions
if callee.file_path == module_path
and (callee.jedi_definition is None or callee.jedi_definition.type != "class")
if callee.file_path == module_path and callee.definition_type != "class"
]
),
candidate.source_code,

View file

@ -12,7 +12,6 @@ from libcst.metadata import PositionProvider
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.time_utils import format_perf, format_time
from codeflash.languages.registry import get_language_support
from codeflash.models.models import GeneratedTests, GeneratedTestsList
from codeflash.result.critic import performance_gain
@ -155,7 +154,6 @@ def _is_python_file(file_path: Path) -> bool:
return file_path.suffix == ".py"
# TODO:{self} Needs cleanup for jest logic in else block
def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_rootdir: Path) -> dict[str, int]:
unique_inv_ids: dict[str, int] = {}
logger.debug(f"[unique_inv_id] Processing {len(inv_id_runtimes)} invocation IDs")
@ -166,53 +164,11 @@ def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_
else inv_id.test_function_name
)
# Detect if test_module_path is a file path (like in js tests) or a Python module name
# File paths contain slashes, module names use dots
test_module_path = inv_id.test_module_path
if "/" in test_module_path or "\\" in test_module_path:
# Already a file path - use directly
abs_path = tests_project_rootdir / Path(test_module_path)
else:
# Check for Jest test file extensions (e.g., tests.fibonacci.test.ts)
# These need special handling to avoid converting .test.ts -> /test/ts
jest_test_extensions = (
".test.ts",
".test.js",
".test.tsx",
".test.jsx",
".spec.ts",
".spec.js",
".spec.tsx",
".spec.jsx",
".ts",
".js",
".tsx",
".jsx",
".mjs",
".mts",
)
matched_ext = None
for ext in jest_test_extensions:
if test_module_path.endswith(ext):
matched_ext = ext
break
if matched_ext:
# JavaScript/TypeScript: convert module-style path to file path
# "tests.fibonacci__perfonlyinstrumented.test.ts" -> "tests/fibonacci__perfonlyinstrumented.test.ts"
base_path = test_module_path[: -len(matched_ext)]
file_path = base_path.replace(".", os.sep) + matched_ext
# Check if the module path includes the tests directory name
tests_dir_name = tests_project_rootdir.name
if file_path.startswith((tests_dir_name + os.sep, tests_dir_name + "/")):
# Module path includes "tests." - use parent directory
abs_path = tests_project_rootdir.parent / Path(file_path)
else:
# Module path doesn't include tests dir - use tests root directly
abs_path = tests_project_rootdir / Path(file_path)
else:
# Python module name - convert dots to path separators and add .py
abs_path = tests_project_rootdir / Path(test_module_path.replace(".", os.sep)).with_suffix(".py")
abs_path = tests_project_rootdir / Path(test_module_path.replace(".", os.sep)).with_suffix(".py")
abs_path_str = str(abs_path.resolve().with_suffix(""))
# Include both unit test and perf test paths for runtime annotations
@ -268,22 +224,7 @@ def add_runtime_comments_to_generated_tests(
logger.debug(f"Failed to add runtime comments to test: {e}")
modified_tests.append(test)
else:
try:
language_support = get_language_support(test.behavior_file_path)
modified_source = language_support.add_runtime_comments(
test.generated_original_test_source, original_runtimes_dict, optimized_runtimes_dict
)
modified_test = GeneratedTests(
generated_original_test_source=modified_source,
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
instrumented_perf_test_source=test.instrumented_perf_test_source,
behavior_file_path=test.behavior_file_path,
perf_file_path=test.perf_file_path,
)
modified_tests.append(modified_test)
except Exception as e:
logger.debug(f"Failed to add runtime comments to test: {e}")
modified_tests.append(test)
modified_tests.append(test)
return GeneratedTestsList(generated_tests=modified_tests)
@ -329,109 +270,3 @@ def _compile_function_patterns(test_functions_to_remove: list[str]) -> list[re.P
)
for func in test_functions_to_remove
]
# Patterns for normalizing codeflash imports (legacy -> npm package)
_CODEFLASH_REQUIRE_PATTERN = re.compile(
    r"(const|let|var)\s+(\w+)\s*=\s*require\s*\(\s*['\"]\.?/?codeflash-jest-helper['\"]\s*\)"
)
_CODEFLASH_IMPORT_PATTERN = re.compile(r"import\s+(?:\*\s+as\s+)?(\w+)\s+from\s+['\"]\.?/?codeflash-jest-helper['\"]")


def normalize_codeflash_imports(source: str) -> str:
    """Rewrite legacy local codeflash-jest-helper imports to the ``codeflash`` npm package.

    Both module systems are handled:
        const codeflash = require('./codeflash-jest-helper')  ->  const codeflash = require('codeflash')
        import codeflash from './codeflash-jest-helper'       ->  import codeflash from 'codeflash'

    Args:
        source: JavaScript/TypeScript source code.

    Returns:
        Source code with normalized imports.
    """
    # CommonJS requires first, then ES module imports.
    rewritten = _CODEFLASH_REQUIRE_PATTERN.sub(r"\1 \2 = require('codeflash')", source)
    return _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", rewritten)
def inject_test_globals(generated_tests: GeneratedTestsList, test_framework: str = "jest") -> GeneratedTestsList:
    """Prepend framework-global imports to every generated test source (ESM modules only).

    Args:
        generated_tests: List of generated tests; mutated in place.
        test_framework: The test framework being used ("jest", "vitest", or "mocha").

    Returns:
        The same GeneratedTestsList with globals injected into all three sources.
    """
    # TODO: inside the prompt tell the llm if it should import jest functions or it's already injected in the global window
    # Use vitest imports for vitest projects, jest imports for everything else.
    if test_framework == "vitest":
        global_import = "import { vi, describe, it, expect, beforeEach, afterEach, beforeAll, test } from 'vitest'\n"
    else:
        global_import = (
            "import { jest, describe, it, expect, beforeEach, afterEach, beforeAll, test } from '@jest/globals'\n"
        )
    source_attrs = (
        "generated_original_test_source",
        "instrumented_behavior_test_source",
        "instrumented_perf_test_source",
    )
    for test in generated_tests.generated_tests:
        for attr in source_attrs:
            setattr(test, attr, global_import + getattr(test, attr))
    return generated_tests
def disable_ts_check(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
    """Prefix every generated test source with a `// @ts-nocheck` directive.

    The directive suppresses TypeScript type checking for the generated
    files, which may reference symbols the type checker cannot resolve.

    Args:
        generated_tests: List of generated tests; each test is mutated in place.

    Returns:
        The same GeneratedTestsList with the directive prepended to every source.
    """
    directive = "// @ts-nocheck\n"
    for test in generated_tests.generated_tests:
        test.generated_original_test_source = directive + test.generated_original_test_source
        test.instrumented_behavior_test_source = directive + test.instrumented_behavior_test_source
        test.instrumented_perf_test_source = directive + test.instrumented_perf_test_source
    return generated_tests
def normalize_generated_tests_imports(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
    """Normalize codeflash helper imports across all generated tests.

    JS/TS tests have their legacy local helper imports rewritten to the npm
    package via normalize_codeflash_imports; tests in other languages pass
    through untouched.

    Args:
        generated_tests: List of generated tests.

    Returns:
        A new GeneratedTestsList containing the normalized tests.
    """
    js_ts_suffixes = (".js", ".ts", ".jsx", ".tsx", ".mjs", ".mts")
    result = []
    for test in generated_tests.generated_tests:
        if test.behavior_file_path.suffix not in js_ts_suffixes:
            # Non-JS/TS tests need no import rewriting.
            result.append(test)
            continue
        result.append(
            GeneratedTests(
                generated_original_test_source=normalize_codeflash_imports(test.generated_original_test_source),
                instrumented_behavior_test_source=normalize_codeflash_imports(test.instrumented_behavior_test_source),
                instrumented_perf_test_source=normalize_codeflash_imports(test.instrumented_perf_test_source),
                behavior_file_path=test.behavior_file_path,
                perf_file_path=test.perf_file_path,
            )
        )
    return GeneratedTestsList(generated_tests=result)

View file

@ -21,7 +21,8 @@ from codeflash.languages.registry import register_language
if TYPE_CHECKING:
from collections.abc import Sequence
from codeflash.models.models import FunctionSource
from codeflash.languages.base import DependencyResolver
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId
logger = logging.getLogger(__name__)
@ -75,6 +76,37 @@ class PythonSupport:
def comment_prefix(self) -> str:
return "#"
@property
def dir_excludes(self) -> frozenset[str]:
return frozenset(
{
"__pycache__",
".venv",
"venv",
".tox",
".nox",
".eggs",
".mypy_cache",
".ruff_cache",
".pytest_cache",
".hypothesis",
"htmlcov",
".pytype",
".pyre",
".pybuilder",
".ipynb_checkpoints",
".codeflash",
".cache",
".complexipy_cache",
"build",
"dist",
"sdist",
".coverage*",
".pyright*",
"*.egg-info",
}
)
# === Discovery ===
def discover_functions(
@ -347,7 +379,7 @@ class PythonSupport:
Modified source code with function replaced.
"""
from codeflash.code_utils.code_replacer import replace_functions_in_file
from codeflash.languages.python.static_analysis.code_replacer import replace_functions_in_file
try:
# Determine the function names to replace
@ -625,6 +657,59 @@ class PythonSupport:
except Exception:
return test_source
def postprocess_generated_tests(
self, generated_tests: GeneratedTestsList, test_framework: str, project_root: Path, source_file_path: Path
) -> GeneratedTestsList:
"""Apply language-specific postprocessing to generated tests."""
_ = test_framework, project_root, source_file_path
return generated_tests
def remove_test_functions_from_generated_tests(
self, generated_tests: GeneratedTestsList, functions_to_remove: list[str]
) -> GeneratedTestsList:
"""Remove specific test functions from generated tests."""
from codeflash.languages.python.static_analysis.edit_generated_tests import (
remove_functions_from_generated_tests,
)
return remove_functions_from_generated_tests(generated_tests, functions_to_remove)
def add_runtime_comments_to_generated_tests(
self,
generated_tests: GeneratedTestsList,
original_runtimes: dict[InvocationId, list[int]],
optimized_runtimes: dict[InvocationId, list[int]],
tests_project_rootdir: Path | None = None,
) -> GeneratedTestsList:
"""Add runtime comments to generated tests."""
from codeflash.languages.python.static_analysis.edit_generated_tests import (
add_runtime_comments_to_generated_tests,
)
return add_runtime_comments_to_generated_tests(
generated_tests, original_runtimes, optimized_runtimes, tests_project_rootdir
)
def add_global_declarations(self, optimized_code: str, original_source: str, module_abspath: Path) -> str:
_ = optimized_code, module_abspath
return original_source
def extract_calling_function_source(self, source_code: str, function_name: str, ref_line: int) -> str | None:
"""Extract the source code of a calling function in Python."""
try:
import ast
lines = source_code.splitlines()
tree = ast.parse(source_code)
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == function_name:
end_line = node.end_lineno or node.lineno
if node.lineno <= ref_line <= end_line:
return "\n".join(lines[node.lineno - 1 : end_line])
except Exception:
return None
return None
# === Test Result Comparison ===
def compare_test_results(
@ -720,6 +805,15 @@ class PythonSupport:
"""
return True
def create_dependency_resolver(self, project_root: Path) -> DependencyResolver | None:
from codeflash.languages.python.reference_graph import ReferenceGraph
try:
return ReferenceGraph(project_root, language=self.language.value)
except Exception:
logger.debug("Failed to initialize ReferenceGraph, falling back to per-function Jedi analysis")
return None
def instrument_existing_test(
self,
test_path: Path,

View file

@ -18,6 +18,11 @@ def is_LSP_enabled() -> bool:
return os.getenv("CODEFLASH_LSP", default="false").lower() == "true"
@lru_cache(maxsize=1)
def is_subagent_mode() -> bool:
return os.getenv("CODEFLASH_SUBAGENT_MODE", default="false").lower() == "true"
def tree_to_markdown(tree: Tree, level: int = 0) -> str:
"""Convert a rich Tree into a Markdown bullet list."""
indent = " " * level

View file

@ -11,6 +11,12 @@ import sys
from pathlib import Path
from typing import TYPE_CHECKING
if "--subagent" in sys.argv:
os.environ["CODEFLASH_SUBAGENT_MODE"] = "true"
import warnings
warnings.filterwarnings("ignore")
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO, ask_run_end_to_end_test
from codeflash.cli_cmds.console import paneled_text

View file

@ -25,7 +25,6 @@ from pathlib import Path
from re import Pattern
from typing import Any, NamedTuple, Optional, cast
from jedi.api.classes import Name
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, model_validator
from pydantic.dataclasses import dataclass
@ -136,14 +135,14 @@ class CoverReturnCode(IntEnum):
ERROR = 2
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
@dataclass(frozen=True)
class FunctionSource:
file_path: Path
qualified_name: str
fully_qualified_name: str
only_function_name: str
source_code: str
jedi_definition: Name | None = None # None for non-Python languages
definition_type: str | None = None # e.g. "function", "class"; None for non-Python languages
def __eq__(self, other: object) -> bool:
if not isinstance(other, FunctionSource):
@ -380,6 +379,7 @@ class CodeOptimizationContext(BaseModel):
hashing_code_context: str = ""
hashing_code_context_hash: str = ""
helper_functions: list[FunctionSource]
testgen_helper_fqns: list[str] = []
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]]

View file

@ -2,6 +2,7 @@ from __future__ import annotations
import ast
import concurrent.futures
import dataclasses
import logging
import os
import queue
@ -23,14 +24,15 @@ from rich.tree import Tree
from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient
from codeflash.api.cfapi import add_code_context_hash, create_staging, get_cfapi_base_urls, mark_optimization_success
from codeflash.benchmarking.utils import process_benchmark_data
from codeflash.cli_cmds.console import code_print, console, logger, lsp_log, progress_bar
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_extractor import get_opt_review_metrics, is_numerical_code
from codeflash.code_utils.code_replacer import (
add_custom_marker_to_all_tests,
modify_autouse_fixture,
replace_function_definitions_in_module,
from codeflash.cli_cmds.console import (
code_print,
console,
logger,
lsp_log,
progress_bar,
subagent_log_optimization_result,
)
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import (
choose_weights,
cleanup_paths,
@ -58,34 +60,32 @@ from codeflash.code_utils.config_consts import (
get_effort_value,
)
from codeflash.code_utils.deduplicate_code import normalize_code
from codeflash.code_utils.edit_generated_tests import (
add_runtime_comments_to_generated_tests,
disable_ts_check,
inject_test_globals,
normalize_generated_tests_imports,
remove_functions_from_generated_tests,
)
from codeflash.code_utils.env_utils import get_pr_number
from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports
from codeflash.code_utils.git_utils import git_root_dir
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
from codeflash.code_utils.line_profile_utils import add_decorator_imports, contains_jit_decorator
from codeflash.code_utils.shell_utils import make_env_with_project_root
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.discovery.functions_to_optimize import was_function_previously_optimized
from codeflash.either import Failure, Success, is_successful
from codeflash.languages import is_python
from codeflash.languages.base import Language
from codeflash.languages.current import current_language_support, is_typescript
from codeflash.languages.javascript.module_system import detect_module_system
from codeflash.languages.current import current_language_support
from codeflash.languages.javascript.test_runner import clear_created_config_files, get_created_config_files
from codeflash.languages.python.context import code_context_extractor
from codeflash.languages.python.context.unused_definition_remover import (
detect_unused_helper_functions,
revert_unused_helper_functions,
)
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown
from codeflash.languages.python.static_analysis.code_extractor import get_opt_review_metrics, is_numerical_code
from codeflash.languages.python.static_analysis.code_replacer import (
add_custom_marker_to_all_tests,
modify_autouse_fixture,
replace_function_definitions_in_module,
)
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
from codeflash.languages.python.static_analysis.static_analysis import get_first_top_level_function_or_method_ast
from codeflash.lsp.helpers import is_LSP_enabled, is_subagent_mode, report_to_markdown_table, tree_to_markdown
from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import (
@ -139,6 +139,7 @@ if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.either import Result
from codeflash.languages.base import DependencyResolver
from codeflash.models.models import (
BenchmarkKey,
CodeStringsMarkdown,
@ -442,10 +443,14 @@ class FunctionOptimizer:
total_benchmark_timings: dict[BenchmarkKey, int] | None = None,
args: Namespace | None = None,
replay_tests_dir: Path | None = None,
call_graph: DependencyResolver | None = None,
) -> None:
self.project_root = test_cfg.project_root_path
self.project_root = test_cfg.project_root_path.resolve()
self.test_cfg = test_cfg
self.aiservice_client = aiservice_client if aiservice_client else AiServiceClient()
resolved_file_path = function_to_optimize.file_path.resolve()
if resolved_file_path != function_to_optimize.file_path:
function_to_optimize = dataclasses.replace(function_to_optimize, file_path=resolved_file_path)
self.function_to_optimize = function_to_optimize
self.function_to_optimize_source_code = (
function_to_optimize_source_code
@ -484,6 +489,7 @@ class FunctionOptimizer:
self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {}
self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {}
self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None
self.call_graph = call_graph
n_tests = get_effort_value(EffortKeys.N_GENERATED_TESTS, self.effort)
self.executor = concurrent.futures.ThreadPoolExecutor(
max_workers=n_tests + 3 if self.experiment_id is None else n_tests + 4
@ -579,6 +585,7 @@ class FunctionOptimizer:
test_results = self.generate_tests(
testgen_context=code_context.testgen_context,
helper_functions=code_context.helper_functions,
testgen_helper_fqns=code_context.testgen_helper_fqns,
generated_test_paths=generated_test_paths,
generated_perf_test_paths=generated_perf_test_paths,
)
@ -588,16 +595,13 @@ class FunctionOptimizer:
count_tests, generated_tests, function_to_concolic_tests, concolic_test_str = test_results.unwrap()
# Normalize codeflash imports in JS/TS tests to use npm package
if not is_python():
module_system = detect_module_system(self.project_root, self.function_to_optimize.file_path)
if module_system == "esm":
generated_tests = inject_test_globals(generated_tests, self.test_cfg.test_framework)
if is_typescript():
# disable ts check for typescript tests
generated_tests = disable_ts_check(generated_tests)
generated_tests = normalize_generated_tests_imports(generated_tests)
# Language-specific postprocessing for generated tests
generated_tests = self.language_support.postprocess_generated_tests(
generated_tests,
test_framework=self.test_cfg.test_framework,
project_root=self.project_root,
source_file_path=self.function_to_optimize.file_path,
)
logger.debug(f"[PIPELINE] Processing {count_tests} generated tests")
for i, generated_test in enumerate(generated_tests.generated_tests):
@ -1352,6 +1356,8 @@ class FunctionOptimizer:
def log_successful_optimization(
self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
) -> None:
if is_subagent_mode():
return
if is_LSP_enabled():
md_lines = [
"### ⚡️ Optimization Summary",
@ -1450,7 +1456,7 @@ class FunctionOptimizer:
optimized_code = ""
if optimized_context is not None:
file_to_code_context = optimized_context.file_to_path()
optimized_code = file_to_code_context.get(str(path.relative_to(self.project_root)), "")
optimized_code = file_to_code_context.get(str(path.resolve().relative_to(self.project_root)), "")
new_code = format_code(
self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True, exit_on_failure=False
@ -1487,8 +1493,8 @@ class FunctionOptimizer:
self.function_to_optimize.qualified_name
)
for helper_function in code_context.helper_functions:
# Skip class definitions (jedi_definition may be None for non-Python languages)
if helper_function.jedi_definition is None or helper_function.jedi_definition.type != "class":
# Skip class definitions (definition_type may be None for non-Python languages)
if helper_function.definition_type != "class":
read_writable_functions_by_file_path[helper_function.file_path].add(helper_function.qualified_name)
for module_abspath, qualified_names in read_writable_functions_by_file_path.items():
did_update |= replace_function_definitions_in_module(
@ -1509,7 +1515,7 @@ class FunctionOptimizer:
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
try:
new_code_ctx = code_context_extractor.get_code_optimization_context(
self.function_to_optimize, self.project_root
self.function_to_optimize, self.project_root, call_graph=self.call_graph
)
except ValueError as e:
return Failure(str(e))
@ -1521,7 +1527,8 @@ class FunctionOptimizer:
read_only_context_code=new_code_ctx.read_only_context_code,
hashing_code_context=new_code_ctx.hashing_code_context,
hashing_code_context_hash=new_code_ctx.hashing_code_context_hash,
helper_functions=new_code_ctx.helper_functions, # only functions that are read writable
helper_functions=new_code_ctx.helper_functions,
testgen_helper_fqns=new_code_ctx.testgen_helper_fqns,
preexisting_objects=new_code_ctx.preexisting_objects,
)
)
@ -1727,6 +1734,7 @@ class FunctionOptimizer:
self,
testgen_context: CodeStringsMarkdown,
helper_functions: list[FunctionSource],
testgen_helper_fqns: list[str],
generated_test_paths: list[Path],
generated_perf_test_paths: list[Path],
) -> Result[tuple[int, GeneratedTestsList, dict[str, set[FunctionCalledInTest]], str], str]:
@ -1735,23 +1743,29 @@ class FunctionOptimizer:
assert len(generated_test_paths) == n_tests
if not self.args.no_gen_tests:
# Submit test generation tasks
helper_fqns = testgen_helper_fqns or [definition.fully_qualified_name for definition in helper_functions]
future_tests = self.submit_test_generation_tasks(
self.executor,
testgen_context.markdown,
[definition.fully_qualified_name for definition in helper_functions],
generated_test_paths,
generated_perf_test_paths,
self.executor, testgen_context.markdown, helper_fqns, generated_test_paths, generated_perf_test_paths
)
future_concolic_tests = self.executor.submit(
generate_concolic_tests, self.test_cfg, self.args, self.function_to_optimize, self.function_to_optimize_ast
)
if is_subagent_mode():
future_concolic_tests = None
else:
future_concolic_tests = self.executor.submit(
generate_concolic_tests,
self.test_cfg,
self.args,
self.function_to_optimize,
self.function_to_optimize_ast,
)
if not self.args.no_gen_tests:
# Wait for test futures to complete
concurrent.futures.wait([*future_tests, future_concolic_tests])
else:
futures_to_wait = [*future_tests]
if future_concolic_tests is not None:
futures_to_wait.append(future_concolic_tests)
concurrent.futures.wait(futures_to_wait)
elif future_concolic_tests is not None:
concurrent.futures.wait([future_concolic_tests])
# Process test generation results
tests: list[GeneratedTests] = []
@ -1780,7 +1794,10 @@ class FunctionOptimizer:
logger.warning(f"Failed to generate and instrument tests for {self.function_to_optimize.function_name}")
return Failure(f"/!\\ NO TESTS GENERATED for {self.function_to_optimize.function_name}")
function_to_concolic_tests, concolic_test_str = future_concolic_tests.result()
if future_concolic_tests is not None:
function_to_concolic_tests, concolic_test_str = future_concolic_tests.result()
else:
function_to_concolic_tests, concolic_test_str = {}, None
count_tests = len(tests)
if concolic_test_str:
count_tests += 1
@ -2061,8 +2078,8 @@ class FunctionOptimizer:
else "Coverage data not available"
)
generated_tests = remove_functions_from_generated_tests(
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
generated_tests = self.language_support.remove_test_functions_from_generated_tests(
generated_tests, test_functions_to_remove
)
map_gen_test_file_to_no_of_tests = original_code_baseline.behavior_test_results.file_to_no_of_tests(
test_functions_to_remove
@ -2073,7 +2090,7 @@ class FunctionOptimizer:
best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case()
)
generated_tests = add_runtime_comments_to_generated_tests(
generated_tests = self.language_support.add_runtime_comments_to_generated_tests(
generated_tests, original_runtime_by_test, optimized_runtime_by_test, self.test_cfg.tests_project_rootdir
)
@ -2203,7 +2220,20 @@ class FunctionOptimizer:
self.optimization_review = opt_review_result.review
# Display the reviewer result to the user
if opt_review_result.review:
if is_subagent_mode():
subagent_log_optimization_result(
function_name=new_explanation.function_name,
file_path=new_explanation.file_path,
perf_improvement_line=new_explanation.perf_improvement_line,
original_runtime_ns=new_explanation.original_runtime_ns,
best_runtime_ns=new_explanation.best_runtime_ns,
raw_explanation=new_explanation.raw_explanation_message,
original_code=original_code_combined,
new_code=new_code_combined,
review=opt_review_result.review,
test_results=new_explanation.winning_behavior_test_results,
)
elif opt_review_result.review:
review_display = {
"high": ("[bold green]High[/bold green]", "green", "Recommended to merge"),
"medium": ("[bold yellow]Medium[/bold yellow]", "yellow", "Review recommended before merging"),

View file

@ -11,7 +11,13 @@ from typing import TYPE_CHECKING
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
from codeflash.api.cfapi import send_completion_email
from codeflash.cli_cmds.console import console, logger, progress_bar
from codeflash.cli_cmds.console import ( # noqa: F401
call_graph_live_display,
call_graph_summary,
console,
logger,
progress_bar,
)
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file
from codeflash.code_utils.env_utils import get_pr_number, is_pr_draft
@ -24,7 +30,8 @@ from codeflash.code_utils.git_worktree_utils import (
)
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.either import is_successful
from codeflash.languages import is_javascript, set_current_language
from codeflash.languages import current_language_support, is_javascript, set_current_language
from codeflash.lsp.helpers import is_subagent_mode
from codeflash.models.models import ValidCode
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.verification_utils import TestConfig
@ -35,6 +42,7 @@ if TYPE_CHECKING:
from codeflash.benchmarking.function_ranker import FunctionRanker
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import DependencyResolver
from codeflash.models.models import BenchmarkKey, FunctionCalledInTest
from codeflash.optimization.function_optimizer import FunctionOptimizer
@ -241,8 +249,11 @@ class Optimizer:
total_benchmark_timings: dict[BenchmarkKey, float] | None = None,
original_module_ast: ast.Module | None = None,
original_module_path: Path | None = None,
call_graph: DependencyResolver | None = None,
) -> FunctionOptimizer | None:
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
from codeflash.languages.python.static_analysis.static_analysis import (
get_first_top_level_function_or_method_ast,
)
from codeflash.optimization.function_optimizer import FunctionOptimizer
if function_to_optimize_ast is None and original_module_ast is not None:
@ -279,13 +290,14 @@ class Optimizer:
function_benchmark_timings=function_specific_timings,
total_benchmark_timings=total_benchmark_timings if function_specific_timings else None,
replay_tests_dir=self.replay_tests_dir,
call_graph=call_graph,
)
def prepare_module_for_optimization(
self, original_module_path: Path
) -> tuple[dict[Path, ValidCode], ast.Module | None] | None:
from codeflash.code_utils.code_replacer import normalize_code, normalize_node
from codeflash.code_utils.static_analysis import analyze_imported_modules
from codeflash.languages.python.static_analysis.code_replacer import normalize_code, normalize_node
from codeflash.languages.python.static_analysis.static_analysis import analyze_imported_modules
logger.info(f"loading|Examining file {original_module_path!s}")
console.rule()
@ -422,7 +434,10 @@ class Optimizer:
console.print(f"[dim]... and {len(globally_ranked) - display_count} more functions[/dim]")
def rank_all_functions_globally(
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], trace_file_path: Path | None
self,
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]],
trace_file_path: Path | None,
call_graph: DependencyResolver | None = None,
) -> list[tuple[Path, FunctionToOptimize]]:
"""Rank all functions globally across all files based on trace data.
@ -442,8 +457,10 @@ class Optimizer:
for file_path, functions in file_to_funcs_to_optimize.items():
all_functions.extend((file_path, func) for func in functions)
# If no trace file, return in original order
# If no trace file, rank by dependency count if call graph is available
if not trace_file_path or not trace_file_path.exists():
if call_graph is not None:
return self.rank_by_dependency_count(all_functions, call_graph)
logger.debug("No trace file available, using original function order")
return all_functions
@ -494,6 +511,19 @@ class Optimizer:
else:
return globally_ranked
def rank_by_dependency_count(
self, all_functions: list[tuple[Path, FunctionToOptimize]], call_graph: DependencyResolver
) -> list[tuple[Path, FunctionToOptimize]]:
file_to_qns: dict[Path, set[str]] = defaultdict(set)
for file_path, func in all_functions:
file_to_qns[file_path].add(func.qualified_name)
callee_counts = call_graph.count_callees_per_function(dict(file_to_qns))
ranked = sorted(
enumerate(all_functions), key=lambda x: (-callee_counts.get((x[1][0], x[1][1].qualified_name), 0), x[0])
)
logger.debug(f"Ranked {len(ranked)} functions by dependency count (most complex first)")
return [item for _, item in ranked]
def run(self) -> None:
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
@ -536,16 +566,33 @@ class Optimizer:
if self.args.all:
three_min_in_ns = int(1.8e11)
console.rule()
pr_message = (
"\nCodeflash will keep opening pull requests as it finds optimizations." if not self.args.no_pr else ""
)
logger.info(
f"It might take about {humanize_runtime(num_optimizable_functions * three_min_in_ns)} to fully optimize this project.{pr_message}"
f"It might take about {humanize_runtime(num_optimizable_functions * three_min_in_ns)} to fully optimize this project."
)
if not self.args.no_pr:
logger.info("Codeflash will keep opening pull requests as it finds optimizations.")
console.rule()
function_benchmark_timings, total_benchmark_timings = self.run_benchmarks(
file_to_funcs_to_optimize, num_optimizable_functions
)
# Create a language-specific dependency resolver (e.g. Jedi-based call graph for Python)
# Skip in CI — the cache DB doesn't persist between runs on ephemeral runners
lang_support = current_language_support()
resolver = None
# CURRENTLY DISABLED: The resolver is currently not used for anything until i clean up the repo structure for python
# if lang_support and not env_utils.is_ci():
# resolver = lang_support.create_dependency_resolver(self.args.project_root)
# if resolver is not None and lang_support is not None and file_to_funcs_to_optimize:
# supported_exts = lang_support.file_extensions
# source_files = [f for f in file_to_funcs_to_optimize if f.suffix in supported_exts]
# with call_graph_live_display(len(source_files), project_root=self.args.project_root) as on_progress:
# resolver.build_index(source_files, on_progress=on_progress)
# console.rule()
# call_graph_summary(resolver, file_to_funcs_to_optimize)
optimizations_found: int = 0
self.test_cfg.concolic_test_root_dir = Path(
tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
@ -557,11 +604,13 @@ class Optimizer:
return
function_to_tests, _ = self.discover_tests(file_to_funcs_to_optimize)
if self.args.all:
if self.args.all and not getattr(self.args, "agent", False):
self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_root)
# GLOBAL RANKING: Rank all functions together before optimizing
globally_ranked_functions = self.rank_all_functions_globally(file_to_funcs_to_optimize, trace_file_path)
globally_ranked_functions = self.rank_all_functions_globally(
file_to_funcs_to_optimize, trace_file_path, call_graph=resolver
)
# Cache for module preparation (avoid re-parsing same files)
prepared_modules: dict[Path, tuple[dict[Path, ValidCode], ast.Module | None]] = {}
@ -593,6 +642,7 @@ class Optimizer:
total_benchmark_timings=total_benchmark_timings,
original_module_ast=original_module_ast,
original_module_path=original_module_path,
call_graph=resolver,
)
if function_optimizer is None:
continue
@ -608,7 +658,7 @@ class Optimizer:
if is_successful(best_optimization):
optimizations_found += 1
# create a diff patch for successful optimization
if self.current_worktree:
if self.current_worktree and not is_subagent_mode():
best_opt = best_optimization.unwrap()
read_writable_code = best_opt.code_context.read_writable_code
relative_file_paths = [
@ -641,7 +691,12 @@ class Optimizer:
self.functions_checkpoint.cleanup()
if hasattr(self.args, "command") and self.args.command == "optimize":
self.cleanup_replay_tests()
if optimizations_found == 0:
if is_subagent_mode():
if optimizations_found == 0:
import sys
sys.stdout.write("<codeflash-summary>No optimizations found.</codeflash-summary>\n")
elif optimizations_found == 0:
logger.info("❌ No optimizations found.")
elif self.args.all:
logger.info("✨ All functions have been optimized! ✨")
@ -651,6 +706,9 @@ class Optimizer:
else:
logger.warning("⚠️ Failed to send completion email. Status")
finally:
if resolver is not None:
resolver.close()
if function_optimizer:
function_optimizer.cleanup_generated_files()

View file

@ -9,12 +9,12 @@ import git
from codeflash.api import cfapi
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_replacer import is_zero_diff
from codeflash.code_utils.git_utils import check_and_push_branch, get_current_branch, get_repo_owner_and_name
from codeflash.code_utils.github_utils import github_pr_url
from codeflash.code_utils.tabulate import tabulate
from codeflash.code_utils.time_utils import format_perf, format_time
from codeflash.github.PrComment import FileDiffContent, PrComment
from codeflash.languages.python.static_analysis.code_replacer import is_zero_diff
from codeflash.result.critic import performance_gain
if TYPE_CHECKING:
@ -126,10 +126,10 @@ def existing_tests_source_for(
tests_dir_name = test_cfg.tests_project_rootdir.name
if file_path.startswith((tests_dir_name + os.sep, tests_dir_name + "/")):
# Module path includes "tests." - use project root parent
instrumented_abs_path = (test_cfg.tests_project_rootdir.parent / file_path).resolve()
instrumented_abs_path = test_cfg.tests_project_rootdir.parent / file_path
else:
# Module path doesn't include tests dir - use tests root directly
instrumented_abs_path = (test_cfg.tests_project_rootdir / file_path).resolve()
instrumented_abs_path = test_cfg.tests_project_rootdir / file_path
logger.debug(f"[PR-DEBUG] Looking up: {instrumented_abs_path}")
logger.debug(f"[PR-DEBUG] Available keys: {list(instrumented_to_original.keys())[:3]}")
# Try to map instrumented path to original path

View file

@ -2,6 +2,7 @@ import logging
import sentry_sdk
from sentry_sdk.integrations.logging import LoggingIntegration
from sentry_sdk.integrations.stdlib import StdlibIntegration
def init_sentry(*, enabled: bool = False, exclude_errors: bool = False) -> None:
@ -16,12 +17,8 @@ def init_sentry(*, enabled: bool = False, exclude_errors: bool = False) -> None:
sentry_sdk.init(
dsn="https://4b9a1902f9361b48c04376df6483bc96@o4506833230561280.ingest.sentry.io/4506833262477312",
integrations=[sentry_logging],
# Set traces_sample_rate to 1.0 to capture 100%
# of transactions for performance monitoring.
traces_sample_rate=1.0,
# Set profiles_sample_rate to 1.0 to profile 100%
# of sampled transactions.
# We recommend adjusting this value in production.
profiles_sample_rate=1.0,
disabled_integrations=[StdlibIntegration],
traces_sample_rate=0,
profiles_sample_rate=0,
ignore_errors=[KeyboardInterrupt],
)

View file

@ -1,6 +1,7 @@
from __future__ import annotations
import ast
import importlib.util
import subprocess
import tempfile
import time
@ -9,15 +10,17 @@ from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
from codeflash.code_utils.concolic_utils import clean_concolic_tests, is_valid_concolic_test
from codeflash.code_utils.shell_utils import make_env_with_project_root
from codeflash.code_utils.static_analysis import has_typed_parameters
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.languages import is_python
from codeflash.languages.python.static_analysis.concolic_utils import clean_concolic_tests, is_valid_concolic_test
from codeflash.languages.python.static_analysis.static_analysis import has_typed_parameters
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.verification_utils import TestConfig
CROSSHAIR_AVAILABLE = importlib.util.find_spec("crosshair") is not None
if TYPE_CHECKING:
from argparse import Namespace
@ -52,6 +55,10 @@ def generate_concolic_tests(
logger.debug("Skipping concolic test generation for non-Python languages (CrossHair is Python-only)")
return function_to_concolic_tests, concolic_test_suite_code
if not CROSSHAIR_AVAILABLE:
logger.debug("Skipping concolic test generation (crosshair-tool is not installed)")
return function_to_concolic_tests, concolic_test_suite_code
if is_LSP_enabled():
logger.debug("Skipping concolic test generation in LSP mode")
return function_to_concolic_tests, concolic_test_suite_code

View file

@ -7,7 +7,7 @@ import sentry_sdk
from coverage.exceptions import NoDataError
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.coverage_utils import (
from codeflash.languages.python.static_analysis.coverage_utils import (
build_fully_qualified_name,
extract_dependent_function,
generate_candidates,

View file

@ -47,8 +47,24 @@ def parse_func(file_path: Path) -> XMLParser:
return parse(file_path, xml_parser)
matches_re_start = re.compile(r"!\$######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######\$!\n")
matches_re_end = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!")
matches_re_start = re.compile(
r"!\$######([^:]*)" # group 1: module path
r":((?:[^:.]*\.)*)" # group 2: class prefix with trailing dot, or empty
r"([^.:]*)" # group 3: test function name
r":([^:]*)" # group 4: function being tested
r":([^:]*)" # group 5: loop index
r":([^#]*)" # group 6: iteration id
r"######\$!\n"
)
matches_re_end = re.compile(
r"!######([^:]*)" # group 1: module path
r":((?:[^:.]*\.)*)" # group 2: class prefix with trailing dot, or empty
r"([^.:]*)" # group 3: test function name
r":([^:]*)" # group 4: function being tested
r":([^:]*)" # group 5: loop index
r":([^#]*)" # group 6: iteration_id or iteration_id:runtime
r"######!"
)
start_pattern = re.compile(r"!\$######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+)######\$!")
@ -893,7 +909,6 @@ def merge_test_results(
return merged_test_results
FAILURES_HEADER_RE = re.compile(r"=+ FAILURES =+")
TEST_HEADER_RE = re.compile(r"_{3,}\s*(.*?)\s*_{3,}$")
@ -903,7 +918,7 @@ def parse_test_failures_from_stdout(stdout: str) -> dict[str, str]:
start = end = None
for i, line in enumerate(lines):
if FAILURES_HEADER_RE.search(line.strip()):
if "= FAILURES =" in line:
start = i
break

View file

@ -13,9 +13,9 @@ from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
from codeflash.code_utils.coverage_utils import prepare_coverage_files
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
from codeflash.languages import is_python
from codeflash.languages.python.static_analysis.coverage_utils import prepare_coverage_files
from codeflash.languages.registry import get_language_support, get_language_support_by_framework
from codeflash.models.models import TestFiles, TestType

View file

@ -158,6 +158,10 @@ class TestConfig:
_language: Optional[str] = None # Language identifier for multi-language support
js_project_root: Optional[Path] = None # JavaScript project root (directory containing package.json)
def __post_init__(self) -> None:
self.project_root_path = self.project_root_path.resolve()
self.tests_project_rootdir = self.tests_project_rootdir.resolve()
@property
def test_framework(self) -> str:
"""Returns the appropriate test framework based on language.

View file

@ -1,2 +1,2 @@
# These version placeholders will be replaced by uv-dynamic-versioning during build.
__version__ = "0.20.0"
__version__ = "0.20.1"

View file

@ -28,8 +28,8 @@ codeflash/code_utils/__init__.py
codeflash/code_utils/time_utils.py
codeflash/code_utils/env_utils.py
codeflash/code_utils/config_consts.py
codeflash/code_utils/static_analysis.py
codeflash/code_utils/edit_generated_tests.py
codeflash/languages/python/static_analysis/static_analysis.py
codeflash/languages/python/static_analysis/edit_generated_tests.py
codeflash/cli_cmds/console_constants.py
codeflash/cli_cmds/logging_config.py
codeflash/cli_cmds/__init__.py

View file

@ -39,13 +39,14 @@ dependencies = [
"dill>=0.3.8",
"rich>=13.8.1",
"lxml>=5.3.0",
"crosshair-tool>=0.0.78",
"crosshair-tool>=0.0.78; python_version < '3.15'",
"coverage>=7.6.4",
"line_profiler>=4.2.0",
"platformdirs>=4.3.7",
"pygls>=2.0.0,<3.0.0",
"codeflash-benchmark",
"filelock",
"filelock>=3.20.3; python_version >= '3.10'",
"filelock<3.20.3; python_version < '3.10'",
"pytest-asyncio>=0.18.0",
]
@ -94,7 +95,7 @@ tests = [
"xarray>=2024.7.0",
"eval_type_backport",
"numba>=0.60.0",
"tensorflow>=2.20.0",
"tensorflow>=2.20.0; python_version >= '3.10'",
]
[tool.hatch.build.targets.sdist]
@ -207,6 +208,8 @@ warn_unreachable = true
install_types = true
plugins = ["pydantic.mypy"]
exclude = ["tests/", "code_to_optimize/", "pie_test_set/", "experiments/"]
[[tool.mypy.overrides]]
module = ["jedi", "jedi.api.classes", "inquirer", "inquirer.themes", "numba"]
ignore_missing_imports = true
@ -310,6 +313,9 @@ split-on-trailing-comma = false
docstring-code-format = true
skip-magic-trailing-comma = true
[tool.ty.src]
exclude = ["tests", "code_to_optimize", "pie_test_set", "experiments"]
[tool.hatch.version]
source = "uv-dynamic-versioning"

View file

@ -20,7 +20,7 @@
"version": "0.13.0"
},
"tessl/pypi-pydantic": {
"version": "1.10.0"
"version": "2.11.0"
},
"tessl/pypi-humanize": {
"version": "4.13.0"
@ -35,7 +35,7 @@
"version": "3.4.0"
},
"tessl/pypi-sentry-sdk": {
"version": "1.45.0"
"version": "2.36.0"
},
"tessl/pypi-parameterized": {
"version": "0.9.0"
@ -44,10 +44,10 @@
"version": "0.4.0"
},
"tessl/pypi-rich": {
"version": "13.9.0"
"version": "14.1.0"
},
"tessl/pypi-lxml": {
"version": "5.4.0"
"version": "6.0.0"
},
"tessl/pypi-crosshair-tool": {
"version": "0.0.0"
@ -64,17 +64,20 @@
"tessl/pypi-filelock": {
"version": "3.19.0"
},
"codeflash/codeflash-rules": {
"version": "0.1.0"
"tessl/pypi-ipython": {
"version": "9.5.0"
},
"codeflash/codeflash-docs": {
"version": "0.1.0"
"tessl/pypi-mypy": {
"version": "1.17.0"
},
"codeflash/codeflash-skills": {
"version": "0.2.0"
"tessl/pypi-ty": {
"version": "0.0.0"
},
"tessl-labs/tessl-skill-eval-scenarios": {
"version": "0.0.5"
"tessl/pypi-types-jsonschema": {
"version": "3.2.0"
},
"tessl/pypi-uv": {
"version": "0.8.0"
}
}
}

View file

@ -2,7 +2,7 @@ from __future__ import annotations
import pytest
from codeflash.code_utils.concolic_utils import AssertCleanup, is_valid_concolic_test
from codeflash.languages.python.static_analysis.concolic_utils import AssertCleanup, is_valid_concolic_test
class TestFirstTopLevelArg:

View file

@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Any
from codeflash.code_utils.coverage_utils import build_fully_qualified_name, extract_dependent_function
from codeflash.languages.python.static_analysis.coverage_utils import build_fully_qualified_name, extract_dependent_function
from codeflash.models.function_types import FunctionParent
from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown
from codeflash.verification.coverage_utils import CoverageUtils

View file

@ -3,13 +3,13 @@ from pathlib import Path
import libcst as cst
from codeflash.code_utils.code_extractor import (
from codeflash.languages.python.static_analysis.code_extractor import (
DottedImportCollector,
add_needed_imports_from_module,
find_preexisting_objects,
resolve_star_import,
)
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.languages.python.static_analysis.code_replacer import replace_functions_and_add_imports
from codeflash.models.models import FunctionParent
@ -22,7 +22,7 @@ import jedi
import tiktoken
from jedi.api.classes import Name
from pydantic.dataclasses import dataclass
from codeflash.code_utils.code_extractor import get_code, get_code_no_skeleton
from codeflash.languages.python.static_analysis.code_extractor import get_code, get_code_no_skeleton
from codeflash.code_utils.code_utils import path_belongs_to_site_packages
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
@ -76,7 +76,7 @@ import tiktoken
from jedi.api.classes import Name
from pydantic.dataclasses import dataclass
from codeflash.code_utils.code_extractor import get_code, get_code_no_skeleton
from codeflash.languages.python.static_analysis.code_extractor import get_code, get_code_no_skeleton
from codeflash.code_utils.code_utils import path_belongs_to_site_packages
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize

View file

@ -4,7 +4,7 @@ from unittest.mock import Mock
import pytest
from codeflash.code_utils.edit_generated_tests import add_runtime_comments_to_generated_tests
from codeflash.languages.python.static_analysis.edit_generated_tests import add_runtime_comments_to_generated_tests
from codeflash.models.models import (
FunctionTestInvocation,
GeneratedTests,

View file

@ -9,8 +9,8 @@ from pathlib import Path
import pytest
from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_global_assignments
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.languages.python.static_analysis.code_extractor import GlobalAssignmentCollector, add_global_assignments
from codeflash.languages.python.static_analysis.code_replacer import replace_functions_and_add_imports
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.python.context.code_context_extractor import (
collect_type_names_from_annotation,
@ -768,11 +768,204 @@ class HelperClass:
assert hashing_context.strip() == expected_hashing_context.strip()
def test_example_class_token_limit_3(tmp_path: Path) -> None:
def test_example_class_token_limit_1(tmp_path: Path) -> None:
docstring_filler = " ".join(
["This is a long docstring that will be used to fill up the token limit." for _ in range(4000)]
)
code = f"""
class MyClass:
\"\"\"A class with a helper method.
{docstring_filler}\"\"\"
def __init__(self):
self.x = 1
def target_method(self):
\"\"\"Docstring for target method\"\"\"
y = HelperClass().helper_method()
class HelperClass:
\"\"\"A helper class for MyClass.\"\"\"
def __init__(self):
\"\"\"Initialize the HelperClass.\"\"\"
self.x = 1
def __repr__(self):
\"\"\"Return a string representation of the HelperClass.\"\"\"
return "HelperClass" + str(self.x)
def helper_method(self):
return self.x
"""
# Create a temporary Python file using pytest's tmp_path fixture
file_path = tmp_path / "test_code.py"
file_path.write_text(code, encoding="utf-8")
opt = Optimizer(
Namespace(
project_root=file_path.parent.resolve(),
disable_telemetry=True,
tests_root="tests",
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
test_project_root=Path().resolve(),
)
)
function_to_optimize = FunctionToOptimize(
function_name="target_method",
file_path=file_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
starting_line=None,
ending_line=None,
)
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
# In this scenario, the read-only code context is too long, so the read-only docstrings are removed.
expected_read_write_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
def __init__(self):
self.x = 1
def target_method(self):
\"\"\"Docstring for target method\"\"\"
y = HelperClass().helper_method()
class HelperClass:
def __init__(self):
\"\"\"Initialize the HelperClass.\"\"\"
self.x = 1
def helper_method(self):
return self.x
```
"""
expected_read_only_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
pass
class HelperClass:
def __repr__(self):
return "HelperClass" + str(self.x)
```
"""
expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
def target_method(self):
y = HelperClass().helper_method()
class HelperClass:
def helper_method(self):
return self.x
```
"""
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_example_class_token_limit_2(tmp_path: Path) -> None:
string_filler = " ".join(
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
)
code = f"""
class MyClass:
\"\"\"A class with a helper method. \"\"\"
def __init__(self):
self.x = 1
def target_method(self):
\"\"\"Docstring for target method\"\"\"
y = HelperClass().helper_method()
x = '{string_filler}'
class HelperClass:
\"\"\"A helper class for MyClass.\"\"\"
def __init__(self):
\"\"\"Initialize the HelperClass.\"\"\"
self.x = 1
def __repr__(self):
\"\"\"Return a string representation of the HelperClass.\"\"\"
return "HelperClass" + str(self.x)
def helper_method(self):
return self.x
"""
# Create a temporary Python file using pytest's tmp_path fixture
file_path = tmp_path / "test_code.py"
file_path.write_text(code, encoding="utf-8")
opt = Optimizer(
Namespace(
project_root=file_path.parent.resolve(),
disable_telemetry=True,
tests_root="tests",
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
test_project_root=Path().resolve(),
)
)
function_to_optimize = FunctionToOptimize(
function_name="target_method",
file_path=file_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
starting_line=None,
ending_line=None,
)
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
# In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
expected_read_write_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
def __init__(self):
self.x = 1
def target_method(self):
\"\"\"Docstring for target method\"\"\"
y = HelperClass().helper_method()
class HelperClass:
def __init__(self):
\"\"\"Initialize the HelperClass.\"\"\"
self.x = 1
def helper_method(self):
return self.x
```
"""
expected_read_only_context = f'''```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
"""A class with a helper method. """
class HelperClass:
"""A helper class for MyClass."""
def __repr__(self):
"""Return a string representation of the HelperClass."""
return "HelperClass" + str(self.x)
```
'''
expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
def target_method(self):
y = HelperClass().helper_method()
class HelperClass:
def helper_method(self):
return self.x
```
"""
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_example_class_token_limit_3(tmp_path: Path) -> None:
string_filler = " ".join(
["This is a long string that will be used to fill up the token limit." for _ in range(4000)]
)
code = f"""
class MyClass:
\"\"\"A class with a helper method. \"\"\"
def __init__(self):
@ -820,7 +1013,7 @@ class HelperClass:
def test_example_class_token_limit_4(tmp_path: Path) -> None:
string_filler = " ".join(
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
["This is a long string that will be used to fill up the token limit." for _ in range(4000)]
)
code = f"""
class MyClass:
@ -979,7 +1172,12 @@ def test_repo_helper() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
path_to_globals = project_root / "globals.py"
expected_read_write_context = f"""
```python:{path_to_globals.relative_to(project_root)}
# Define a global variable
API_URL = "https://api.example.com/data"
```
```python:{path_to_utils.relative_to(project_root)}
import math
@ -1072,7 +1270,12 @@ def test_repo_helper_of_helper() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
path_to_globals = project_root / "globals.py"
expected_read_write_context = f"""
```python:{path_to_globals.relative_to(project_root)}
# Define a global variable
API_URL = "https://api.example.com/data"
```
```python:{path_to_utils.relative_to(project_root)}
import math
from transform_utils import DataTransformer
@ -1799,6 +2002,8 @@ class Calculator:
"""
expected_read_only_context = """
```python:utility_module.py
import sys
DEFAULT_PRECISION = "medium"
# Try-except block with variable definitions
@ -1809,6 +2014,17 @@ except ImportError:
# Used variable in except block
CALCULATION_BACKEND = "python"
# Nested if-else with variable definitions
if sys.platform.startswith('win'):
# Used variable in outer if
SYSTEM_TYPE = "windows"
elif sys.platform.startswith('linux'):
# Used variable in outer elif
SYSTEM_TYPE = "linux"
else:
# Used variable in outer else
SYSTEM_TYPE = "other"
# Function that will be used in the main code
def select_precision(precision, fallback_precision):
if precision is None:
@ -2015,6 +2231,8 @@ def get_system_details():
relative_path = file_path.relative_to(project_root)
expected_read_write_context = f"""
```python:utility_module.py
import sys
DEFAULT_PRECISION = "medium"
# Try-except block with variable definitions
@ -2025,6 +2243,17 @@ except ImportError:
# Used variable in except block
CALCULATION_BACKEND = "python"
# Nested if-else with variable definitions
if sys.platform.startswith('win'):
# Used variable in outer if
SYSTEM_TYPE = "windows"
elif sys.platform.startswith('linux'):
# Used variable in outer elif
SYSTEM_TYPE = "linux"
else:
# Used variable in outer else
SYSTEM_TYPE = "other"
# Function that will be used in the main code
def select_precision(precision, fallback_precision):
if precision is None:
@ -2065,6 +2294,8 @@ class Calculator:
"""
expected_read_only_context = """
```python:utility_module.py
import sys
DEFAULT_PRECISION = "medium"
# Try-except block with variable definitions
@ -2074,6 +2305,17 @@ try:
except ImportError:
# Used variable in except block
CALCULATION_BACKEND = "python"
# Nested if-else with variable definitions
if sys.platform.startswith('win'):
# Used variable in outer if
SYSTEM_TYPE = "windows"
elif sys.platform.startswith('linux'):
# Used variable in outer elif
SYSTEM_TYPE = "linux"
else:
# Used variable in outer else
SYSTEM_TYPE = "other"
```
"""
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
@ -2629,7 +2871,7 @@ def test_global_function_collector():
"""Test GlobalFunctionCollector correctly collects module-level function definitions."""
import libcst as cst
from codeflash.code_utils.code_extractor import GlobalFunctionCollector
from codeflash.languages.python.static_analysis.code_extractor import GlobalFunctionCollector
source_code = """
# Module-level functions

View file

@ -1,7 +1,7 @@
import tempfile
from pathlib import Path
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
from codeflash.languages.python.static_analysis.code_extractor import add_needed_imports_from_module
def test_add_needed_imports_with_none_aliases():

View file

@ -1,6 +1,5 @@
from __future__ import annotations
import dataclasses
import os
import re
from collections import defaultdict
@ -8,39 +7,23 @@ from pathlib import Path
import libcst as cst
from codeflash.code_utils.code_extractor import delete___future___aliased_imports, find_preexisting_objects
from codeflash.code_utils.code_replacer import (
from codeflash.languages.python.static_analysis.code_extractor import delete___future___aliased_imports, find_preexisting_objects
from codeflash.languages.python.static_analysis.code_replacer import (
AddRequestArgument,
AutouseFixtureModifier,
OptimFunctionCollector,
PytestMarkAdder,
is_zero_diff,
replace_functions_and_add_imports,
replace_functions_in_file,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent, FunctionSource
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
@dataclasses.dataclass
class JediDefinition:
type: str
@dataclasses.dataclass
class FakeFunctionSource:
file_path: Path
qualified_name: str
fully_qualified_name: str
only_function_name: str
source_code: str
jedi_definition: JediDefinition
class Args:
disable_imports_sorting = True
formatter_cmds = ["disabled"]
@ -1137,7 +1120,7 @@ class TestResults(BaseModel):
preexisting_objects = find_preexisting_objects(original_code)
helper_functions = [
FakeFunctionSource(
FunctionSource(
file_path=Path(
"/Users/saurabh/Library/CloudStorage/Dropbox/codeflash/cli/codeflash/verification/test_results.py"
),
@ -1145,7 +1128,7 @@ class TestResults(BaseModel):
fully_qualified_name="codeflash.verification.test_results.TestType",
only_function_name="TestType",
source_code="",
jedi_definition=JediDefinition(type="class"),
definition_type="class",
)
]
@ -1160,7 +1143,7 @@ class TestResults(BaseModel):
helper_functions_by_module_abspath = defaultdict(set)
for helper_function in helper_functions:
if helper_function.jedi_definition.type != "class":
if helper_function.definition_type != "class":
helper_functions_by_module_abspath[helper_function.file_path].add(helper_function.qualified_name)
for module_abspath, qualified_names in helper_functions_by_module_abspath.items():
new_code: str = replace_functions_and_add_imports(
@ -1352,21 +1335,21 @@ def cosine_similarity_top_k(
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
helper_functions = [
FakeFunctionSource(
FunctionSource(
file_path=(Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve(),
qualified_name="Matrix",
fully_qualified_name="code_to_optimize.math_utils.Matrix",
only_function_name="Matrix",
source_code="",
jedi_definition=JediDefinition(type="class"),
definition_type="class",
),
FakeFunctionSource(
FunctionSource(
file_path=(Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve(),
qualified_name="cosine_similarity",
fully_qualified_name="code_to_optimize.math_utils.cosine_similarity",
only_function_name="cosine_similarity",
source_code="",
jedi_definition=JediDefinition(type="function"),
definition_type="function",
),
]
@ -1425,7 +1408,7 @@ def cosine_similarity_top_k(
)
helper_functions_by_module_abspath = defaultdict(set)
for helper_function in helper_functions:
if helper_function.jedi_definition.type != "class":
if helper_function.definition_type != "class":
helper_functions_by_module_abspath[helper_function.file_path].add(helper_function.qualified_name)
for module_abspath, qualified_names in helper_functions_by_module_abspath.items():
new_helper_code: str = replace_functions_and_add_imports(
@ -3492,142 +3475,6 @@ def hydrate_input_text_actions_with_field_names(
assert new_code == expected
# OptimFunctionCollector async function tests
def test_optim_function_collector_with_async_functions():
"""Test OptimFunctionCollector correctly collects async functions."""
import libcst as cst
source_code = """
def sync_function():
return "sync"
async def async_function():
return "async"
class TestClass:
def sync_method(self):
return "sync_method"
async def async_method(self):
return "async_method"
"""
tree = cst.parse_module(source_code)
collector = OptimFunctionCollector(
function_names={
(None, "sync_function"),
(None, "async_function"),
("TestClass", "sync_method"),
("TestClass", "async_method"),
},
preexisting_objects=None,
)
tree.visit(collector)
# Should collect both sync and async functions
assert len(collector.modified_functions) == 4
assert (None, "sync_function") in collector.modified_functions
assert (None, "async_function") in collector.modified_functions
assert ("TestClass", "sync_method") in collector.modified_functions
assert ("TestClass", "async_method") in collector.modified_functions
def test_optim_function_collector_new_async_functions():
"""Test OptimFunctionCollector identifies new async functions not in preexisting objects."""
import libcst as cst
source_code = """
def existing_function():
return "existing"
async def new_async_function():
return "new_async"
def new_sync_function():
return "new_sync"
class ExistingClass:
async def new_class_async_method(self):
return "new_class_async"
"""
# Only existing_function is in preexisting objects
preexisting_objects = {("existing_function", ())}
tree = cst.parse_module(source_code)
collector = OptimFunctionCollector(
function_names=set(), # Not looking for specific functions
preexisting_objects=preexisting_objects,
)
tree.visit(collector)
# Should identify new functions (both sync and async)
assert len(collector.new_functions) == 2
function_names = [func.name.value for func in collector.new_functions]
assert "new_async_function" in function_names
assert "new_sync_function" in function_names
# Should identify new class methods
assert "ExistingClass" in collector.new_class_functions
assert len(collector.new_class_functions["ExistingClass"]) == 1
assert collector.new_class_functions["ExistingClass"][0].name.value == "new_class_async_method"
def test_optim_function_collector_mixed_scenarios():
"""Test OptimFunctionCollector with complex mix of sync/async functions and classes."""
import libcst as cst
source_code = """
# Global functions
def global_sync():
pass
async def global_async():
pass
class ParentClass:
def __init__(self):
pass
def sync_method(self):
pass
async def async_method(self):
pass
class ChildClass:
async def child_async_method(self):
pass
def child_sync_method(self):
pass
"""
# Looking for specific functions
function_names = {
(None, "global_sync"),
(None, "global_async"),
("ParentClass", "sync_method"),
("ParentClass", "async_method"),
("ChildClass", "child_async_method"),
}
tree = cst.parse_module(source_code)
collector = OptimFunctionCollector(function_names=function_names, preexisting_objects=None)
tree.visit(collector)
# Should collect all specified functions (mix of sync and async)
assert len(collector.modified_functions) == 5
assert (None, "global_sync") in collector.modified_functions
assert (None, "global_async") in collector.modified_functions
assert ("ParentClass", "sync_method") in collector.modified_functions
assert ("ParentClass", "async_method") in collector.modified_functions
assert ("ChildClass", "child_async_method") in collector.modified_functions
# Should collect __init__ method
assert "ParentClass" in collector.modified_init_functions
def test_is_zero_diff_async_sleep():
original_code = """
import time

View file

@ -19,8 +19,8 @@ from codeflash.code_utils.code_utils import (
path_belongs_to_site_packages,
validate_python_code,
)
from codeflash.code_utils.concolic_utils import clean_concolic_tests
from codeflash.code_utils.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files
from codeflash.languages.python.static_analysis.concolic_utils import clean_concolic_tests
from codeflash.languages.python.static_analysis.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files
from codeflash.models.models import CodeStringsMarkdown
from codeflash.verification.parse_test_output import resolve_test_file_from_class_path
@ -36,7 +36,7 @@ def multiple_existing_and_non_existing_files(tmp_path: Path) -> list[Path]:
@pytest.fixture
def mock_get_run_tmp_file() -> Generator[MagicMock, None, None]:
with patch("codeflash.code_utils.coverage_utils.get_run_tmp_file") as mock:
with patch("codeflash.languages.python.static_analysis.coverage_utils.get_run_tmp_file") as mock:
yield mock

View file

@ -151,10 +151,9 @@ def test_class_method_dependencies() -> None:
# The code_context above should have the topologicalSortUtil function in it
assert len(code_context.helper_functions) == 1
assert (
code_context.helper_functions[0].jedi_definition.full_name
== "test_function_dependencies.Graph.topologicalSortUtil"
code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.Graph.topologicalSortUtil"
)
assert code_context.helper_functions[0].jedi_definition.name == "topologicalSortUtil"
assert code_context.helper_functions[0].only_function_name == "topologicalSortUtil"
assert (
code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.Graph.topologicalSortUtil"
)

View file

@ -3,7 +3,7 @@ from pathlib import Path
import pytest
from codeflash.code_utils.code_extractor import get_code
from codeflash.languages.python.static_analysis.code_extractor import get_code
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent

View file

@ -2,7 +2,7 @@ import os
from pathlib import Path
from tempfile import TemporaryDirectory
from codeflash.code_utils.line_profile_utils import add_decorator_imports, contains_jit_decorator
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext
from codeflash.optimization.function_optimizer import FunctionOptimizer

View file

@ -15,7 +15,7 @@ from codeflash.code_utils.instrument_existing_tests import (
FunctionImportedAsVisitor,
inject_profiling_into_existing_test,
)
from codeflash.code_utils.line_profile_utils import add_decorator_imports
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import (
CodeOptimizationContext,

View file

@ -2,10 +2,10 @@
from unittest.mock import patch
from codeflash.code_utils.code_extractor import is_numerical_code
from codeflash.languages.python.static_analysis.code_extractor import is_numerical_code
@patch("codeflash.code_utils.code_extractor.has_numba", True)
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestBasicNumpyUsage:
"""Test basic numpy library detection (with numba available)."""
@ -50,7 +50,7 @@ def func(x):
assert is_numerical_code(code, "func") is True
@patch("codeflash.code_utils.code_extractor.has_numba", True)
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestNumpySubmodules:
"""Test numpy submodule imports (with numba available)."""
@ -265,7 +265,7 @@ def func(x):
assert is_numerical_code(code, "func") is True
@patch("codeflash.code_utils.code_extractor.has_numba", True)
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestScipyUsage:
"""Test SciPy library detection (with numba available)."""
@ -302,7 +302,7 @@ def func(f, x0):
assert is_numerical_code(code, "func") is True
@patch("codeflash.code_utils.code_extractor.has_numba", True)
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestMathUsage:
"""Test math standard library detection (with numba available)."""
@ -331,7 +331,7 @@ def calculate(x):
assert is_numerical_code(code, "calculate") is True
@patch("codeflash.code_utils.code_extractor.has_numba", True)
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestClassMethods:
"""Test detection in class methods, staticmethods, and classmethods (with numba available)."""
@ -472,7 +472,7 @@ def func():
assert is_numerical_code(code, "func") is False
@patch("codeflash.code_utils.code_extractor.has_numba", True)
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestEdgeCases:
"""Test edge cases and special scenarios (with numba available)."""
@ -535,7 +535,7 @@ async def async_process(x):
assert is_numerical_code(code, "async_process") is False
@patch("codeflash.code_utils.code_extractor.has_numba", True)
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestStarImports:
"""Test handling of star imports (with numba available).
@ -575,7 +575,7 @@ def func(x):
assert is_numerical_code(code, "func") is False
@patch("codeflash.code_utils.code_extractor.has_numba", True)
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestNestedUsage:
"""Test nested numerical library usage patterns (with numba available)."""
@ -618,7 +618,7 @@ def func(x):
assert is_numerical_code(code, "func") is True
@patch("codeflash.code_utils.code_extractor.has_numba", True)
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestMultipleLibraries:
"""Test code using multiple numerical libraries (with numba available)."""
@ -643,7 +643,7 @@ def analyze(data):
assert is_numerical_code(code, "analyze") is True
@patch("codeflash.code_utils.code_extractor.has_numba", True)
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestQualifiedNames:
"""Test various qualified name patterns (with numba available)."""
@ -689,7 +689,7 @@ class ClassB:
assert is_numerical_code(code, "ClassB.method") is False
@patch("codeflash.code_utils.code_extractor.has_numba", True)
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
class TestEmptyFunctionName:
"""Test behavior when function_name is empty/None.
@ -807,7 +807,7 @@ def broken(
assert is_numerical_code(code, "") is False
@patch("codeflash.code_utils.code_extractor.has_numba", False)
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", False)
class TestEmptyFunctionNameWithoutNumba:
"""Test empty function_name behavior when numba is NOT available.
@ -886,7 +886,7 @@ from scipy import stats
assert is_numerical_code(code, "") is False
@patch("codeflash.code_utils.code_extractor.has_numba", False)
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", False)
class TestNumbaNotAvailable:
"""Test behavior when numba is NOT available in the environment.

View file

@ -22,7 +22,7 @@ from codeflash.languages.javascript.find_references import (
find_references,
)
from codeflash.languages.base import Language, FunctionInfo, ReferenceInfo
from codeflash.code_utils.code_extractor import _format_references_as_markdown
from codeflash.languages.python.static_analysis.code_extractor import _format_references_as_markdown
from codeflash.models.models import FunctionParent

View file

@ -14,7 +14,7 @@ from pathlib import Path
import pytest
from codeflash.code_utils.code_replacer import replace_function_definitions_for_language
from codeflash.languages.python.static_analysis.code_replacer import replace_function_definitions_for_language
from codeflash.languages.base import Language
from codeflash.languages.current import set_current_language
from codeflash.languages.javascript.module_system import (

View file

@ -0,0 +1,189 @@
"""Tests for the regex patterns and string matching in parse_test_output.py."""
from codeflash.verification.parse_test_output import (
matches_re_end,
matches_re_start,
parse_test_failures_from_stdout,
)
# --- matches_re_start tests ---
class TestMatchesReStart:
    """Behaviour of the start-marker regex that delimits per-test captured output."""

    def test_simple_no_class(self) -> None:
        """Without a class prefix, the class group is the empty string."""
        sample = "!$######tests.test_foo:test_bar:target_func:1:abc######$!\n"
        match = matches_re_start.search(sample)
        assert match is not None
        assert match.groups() == ("tests.test_foo", "", "test_bar", "target_func", "1", "abc")

    def test_with_class(self) -> None:
        """A single class prefix is captured with its trailing dot."""
        sample = "!$######tests.test_foo:MyClass.test_bar:target_func:1:abc######$!\n"
        match = matches_re_start.search(sample)
        assert match is not None
        assert match.groups() == ("tests.test_foo", "MyClass.", "test_bar", "target_func", "1", "abc")

    def test_nested_class(self) -> None:
        """Nested class scopes are captured as one dotted prefix."""
        sample = "!$######a.b.c:A.B.test_x:func:3:id123######$!\n"
        match = matches_re_start.search(sample)
        assert match is not None
        assert match.groups() == ("a.b.c", "A.B.", "test_x", "func", "3", "id123")

    def test_empty_class_and_function(self) -> None:
        """Empty class and function segments still produce a match."""
        sample = "!$######mod::func:0:iter######$!\n"
        match = matches_re_start.search(sample)
        assert match is not None
        assert match.groups() == ("mod", "", "", "func", "0", "iter")

    def test_embedded_in_stdout(self) -> None:
        """The marker is found even when surrounded by unrelated output."""
        sample = "some output\n!$######mod:test_fn:f:1:x######$!\nmore output\n"
        match = matches_re_start.search(sample)
        assert match is not None
        assert match.groups() == ("mod", "", "test_fn", "f", "1", "x")

    def test_multiple_matches(self) -> None:
        """Every marker occurrence is found, each with its own groups."""
        sample = (
            "!$######m1:C1.fn1:t1:1:a######$!\n"
            "!$######m2:fn2:t2:2:b######$!\n"
        )
        # findall returns one tuple of groups per occurrence.
        found = matches_re_start.findall(sample)
        assert len(found) == 2
        assert found[0] == ("m1", "C1.", "fn1", "t1", "1", "a")
        assert found[1] == ("m2", "", "fn2", "t2", "2", "b")

    def test_no_match_without_newline(self) -> None:
        """The start marker requires a trailing newline to match."""
        sample = "!$######mod:test_fn:f:1:x######$!"
        assert matches_re_start.search(sample) is None

    def test_dots_in_module_path(self) -> None:
        """Deeply dotted module paths are captured whole in group 1."""
        sample = "!$######a.b.c.d.e:test_fn:f:1:x######$!\n"
        match = matches_re_start.search(sample)
        assert match is not None
        assert match.group(1) == "a.b.c.d.e"
# --- matches_re_end tests ---
class TestMatchesReEnd:
    """Behaviour of the end-marker regex, including the iteration-id:runtime tail."""

    def test_simple_no_class_with_runtime(self) -> None:
        """Runtime suffix stays fused to the iteration id in the last group."""
        sample = "!######tests.test_foo:test_bar:target_func:1:abc:12345######!"
        match = matches_re_end.search(sample)
        assert match is not None
        assert match.groups() == ("tests.test_foo", "", "test_bar", "target_func", "1", "abc:12345")

    def test_with_class_no_runtime(self) -> None:
        """Without a runtime, group 6 is just the iteration id."""
        sample = "!######tests.test_foo:MyClass.test_bar:target_func:1:abc######!"
        match = matches_re_end.search(sample)
        assert match is not None
        assert match.groups() == ("tests.test_foo", "MyClass.", "test_bar", "target_func", "1", "abc")

    def test_nested_class_with_runtime(self) -> None:
        """Nested class prefix and runtime suffix are both captured."""
        sample = "!######mod:A.B.test_x:func:3:id123:99999######!"
        match = matches_re_end.search(sample)
        assert match is not None
        assert match.groups() == ("mod", "A.B.", "test_x", "func", "3", "id123:99999")

    def test_runtime_colon_preserved_in_group6(self) -> None:
        """Group 6 must keep 'iteration_id:runtime' as a single colon-joined string."""
        sample = "!######m:fn:f:1:iter42:98765######!"
        match = matches_re_end.search(sample)
        assert match is not None
        assert match.group(6) == "iter42:98765"

    def test_embedded_in_stdout(self) -> None:
        """The end marker is found inside surrounding captured output."""
        sample = "captured output\n!######mod:test_fn:f:1:x:500######!\nmore"
        match = matches_re_end.search(sample)
        assert match is not None
        assert match.groups() == ("mod", "", "test_fn", "f", "1", "x:500")
# --- Start/End pairing (simulates parse_test_xml matching logic) ---
class TestStartEndPairing:
    """Start/end markers pair on (first 5 groups, iteration id) — mirrors parse_test_xml."""

    def test_paired_markers(self) -> None:
        """An end marker with a runtime suffix still pairs with its start marker."""
        stdout = (
            "!$######mod:Class.test_fn:func:1:iter1######$!\n"
            "test output here\n"
            "!######mod:Class.test_fn:func:1:iter1:54321######!"
        )
        starts = list(matches_re_start.finditer(stdout))
        ends = {}
        for end_match in matches_re_end.finditer(stdout):
            groups = end_match.groups()
            # Group 6 may be "iteration_id:runtime"; strip the runtime before keying.
            iteration_id, sep, _runtime = groups[5].partition(":")
            key = groups[:5] + (iteration_id,) if sep else groups
            ends[key] = end_match
        assert len(starts) == 1
        assert len(ends) == 1
        # The start groups (which never carry a runtime) must be a key in ends.
        assert starts[0].groups() in ends
# --- parse_test_failures_from_stdout tests ---
class TestParseTestFailuresHeader:
    """Detection of the pytest '= FAILURES =' banner in captured stdout."""

    def test_standard_pytest_header(self) -> None:
        """A full-width FAILURES banner is recognised and the failing test extracted."""
        captured = (
            "..F.\n"
            "=================================== FAILURES ===================================\n"
            "_______ test_foo _______\n"
            "\n"
            " def test_foo():\n"
            "> assert False\n"
            "E AssertionError\n"
            "\n"
            "test.py:3: AssertionError\n"
            "=========================== short test summary info ============================\n"
            "FAILED test.py::test_foo\n"
        )
        failures = parse_test_failures_from_stdout(captured)
        assert "test_foo" in failures

    def test_minimal_equals(self) -> None:
        """Even a short '= FAILURES =' header should be detected."""
        captured = (
            "= FAILURES =\n"
            "_______ test_bar _______\n"
            "\n"
            " assert False\n"
            "\n"
            "test.py:1: AssertionError\n"
            "= short test summary info =\n"
        )
        failures = parse_test_failures_from_stdout(captured)
        assert "test_bar" in failures

    def test_no_failures_section(self) -> None:
        """Output with no FAILURES banner yields an empty mapping."""
        captured = "....\n4 passed in 0.1s\n"
        assert parse_test_failures_from_stdout(captured) == {}

    def test_word_failures_without_equals_is_not_matched(self) -> None:
        """'FAILURES' without surrounding '=' signs should not trigger the header detection."""
        captured = (
            "FAILURES detected in module\n"
            "_______ test_baz _______\n"
            "\n"
            " assert False\n"
        )
        assert parse_test_failures_from_stdout(captured) == {}

    def test_failures_in_test_output_not_matched(self) -> None:
        """A test printing 'FAILURES' (no = signs) should not trigger header detection."""
        captured = (
            "Testing FAILURES handling\n"
            "All good\n"
        )
        assert parse_test_failures_from_stdout(captured) == {}

View file

@ -0,0 +1,475 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
if TYPE_CHECKING:
from pathlib import Path
from codeflash.languages.base import IndexResult
from codeflash.languages.python.reference_graph import ReferenceGraph
@pytest.fixture
def project(tmp_path: Path) -> Path:
    """Create and return a fresh per-test project root directory."""
    root = tmp_path / "project"
    root.mkdir()
    return root


@pytest.fixture
def db_path(tmp_path: Path) -> Path:
    """Return the path of the per-test reference-graph cache database."""
    return tmp_path / "cache.db"


def write_file(project: Path, name: str, content: str) -> Path:
    """Write `content` to `name` inside the project root and return its path."""
    target = project / name
    target.write_text(content, encoding="utf-8")
    return target
# ---------------------------------------------------------------------------
# Unit tests
# ---------------------------------------------------------------------------
def test_simple_function_call(project: Path, db_path: Path) -> None:
    """A function calling a sibling in the same module reports it as a callee."""
    write_file(
        project,
        "mod.py",
        """\
def helper():
    return 1

def caller():
    return helper()
""",
    )
    graph = ReferenceGraph(project, db_path=db_path)
    try:
        _, callees = graph.get_callees({project / "mod.py": {"caller"}})
        assert "helper" in {fs.qualified_name for fs in callees}
    finally:
        graph.close()


def test_cross_file_call(project: Path, db_path: Path) -> None:
    """A callee defined in another project file is resolved to that file."""
    write_file(
        project,
        "utils.py",
        """\
def utility():
    return 42
""",
    )
    write_file(
        project,
        "main.py",
        """\
from utils import utility

def caller():
    return utility()
""",
    )
    graph = ReferenceGraph(project, db_path=db_path)
    try:
        _, callees = graph.get_callees({project / "main.py": {"caller"}})
        assert "utility" in {fs.qualified_name for fs in callees}
        # The resolved definition must live in utils.py.
        files = {fs.file_path.resolve() for fs in callees if fs.qualified_name == "utility"}
        assert (project / "utils.py").resolve() in files
    finally:
        graph.close()


def test_class_instantiation(project: Path, db_path: Path) -> None:
    """Instantiating a class counts as a reference to a class definition."""
    write_file(
        project,
        "mod.py",
        """\
class MyClass:
    def __init__(self):
        pass

def caller():
    obj = MyClass()
    return obj
""",
    )
    graph = ReferenceGraph(project, db_path=db_path)
    try:
        _, callees = graph.get_callees({project / "mod.py": {"caller"}})
        assert "class" in {fs.definition_type for fs in callees}
    finally:
        graph.close()


def test_nested_function_excluded(project: Path, db_path: Path) -> None:
    """A function nested inside the caller is not reported as a callee."""
    write_file(
        project,
        "mod.py",
        """\
def caller():
    def inner():
        return 1
    return inner()
""",
    )
    graph = ReferenceGraph(project, db_path=db_path)
    try:
        _, callees = graph.get_callees({project / "mod.py": {"caller"}})
        assert not callees
    finally:
        graph.close()


def test_module_level_not_tracked(project: Path, db_path: Path) -> None:
    """Module-level calls have no enclosing function, so no edges are recorded."""
    write_file(
        project,
        "mod.py",
        """\
def helper():
    return 1

x = helper()
""",
    )
    graph = ReferenceGraph(project, db_path=db_path)
    try:
        # helper itself calls nothing, and the module-level call creates no edge.
        _, callees = graph.get_callees({project / "mod.py": {"helper"}})
        assert not callees
    finally:
        graph.close()


def test_site_packages_excluded(project: Path, db_path: Path) -> None:
    """Stdlib / site-packages targets (os.path.join) do not appear as callees."""
    write_file(
        project,
        "mod.py",
        """\
import os

def caller():
    return os.path.join("a", "b")
""",
    )
    graph = ReferenceGraph(project, db_path=db_path)
    try:
        _, callees = graph.get_callees({project / "mod.py": {"caller"}})
        assert not callees
    finally:
        graph.close()


def test_empty_file(project: Path, db_path: Path) -> None:
    """An empty module yields no callees."""
    write_file(project, "mod.py", "")
    graph = ReferenceGraph(project, db_path=db_path)
    try:
        _, callees = graph.get_callees({project / "mod.py": set()})
        assert not callees
    finally:
        graph.close()


def test_syntax_error_file(project: Path, db_path: Path) -> None:
    """A file that fails to parse is handled gracefully, producing no callees."""
    write_file(project, "mod.py", "def broken(\n")
    graph = ReferenceGraph(project, db_path=db_path)
    try:
        _, callees = graph.get_callees({project / "mod.py": {"broken"}})
        assert not callees
    finally:
        graph.close()
# ---------------------------------------------------------------------------
# Caching tests
# ---------------------------------------------------------------------------
def test_caching_no_reindex(project: Path, db_path: Path) -> None:
    """A second query over an unchanged file reuses the cached content hash."""
    write_file(
        project,
        "mod.py",
        """\
def helper():
    return 1

def caller():
    return helper()
""",
    )
    graph = ReferenceGraph(project, db_path=db_path)
    try:
        graph.get_callees({project / "mod.py": {"caller"}})
        resolved = str((project / "mod.py").resolve())
        assert resolved in graph.indexed_file_hashes
        first_hash = graph.indexed_file_hashes[resolved]
        # Second call: hash unchanged, so the in-memory cache entry is untouched.
        graph.get_callees({project / "mod.py": {"caller"}})
        assert graph.indexed_file_hashes[resolved] == first_hash
    finally:
        graph.close()


def test_incremental_update_on_change(project: Path, db_path: Path) -> None:
    """Editing a file triggers re-indexing, and callees reflect the new content."""
    source_file = write_file(
        project,
        "mod.py",
        """\
def helper():
    return 1

def caller():
    return helper()
""",
    )
    graph = ReferenceGraph(project, db_path=db_path)
    try:
        _, callees = graph.get_callees({project / "mod.py": {"caller"}})
        assert any(fs.qualified_name == "helper" for fs in callees)
        # Rewrite the file — caller now calls new_helper instead of helper.
        source_file.write_text(
            """\
def helper():
    return 1

def new_helper():
    return 2

def caller():
    return new_helper()
""",
            encoding="utf-8",
        )
        _, callees = graph.get_callees({project / "mod.py": {"caller"}})
        assert "new_helper" in {fs.qualified_name for fs in callees}
    finally:
        graph.close()


def test_persistence_across_sessions(project: Path, db_path: Path) -> None:
    """A fresh session reads index data back from the DB without re-indexing."""
    write_file(
        project,
        "mod.py",
        """\
def helper():
    return 1

def caller():
    return helper()
""",
    )
    # First session: index the file.
    first = ReferenceGraph(project, db_path=db_path)
    try:
        _, callees = first.get_callees({project / "mod.py": {"caller"}})
        assert any(fs.qualified_name == "helper" for fs in callees)
    finally:
        first.close()
    # Second session: in-memory cache starts empty; results come from the DB.
    second = ReferenceGraph(project, db_path=db_path)
    try:
        assert len(second.indexed_file_hashes) == 0
        _, callees = second.get_callees({project / "mod.py": {"caller"}})
        assert any(fs.qualified_name == "helper" for fs in callees)
    finally:
        second.close()
def test_build_index_with_progress(project: Path, db_path: Path) -> None:
    """build_index fires the progress callback once per freshly indexed file."""
    write_file(
        project,
        "a.py",
        """\
def helper_a():
    return 1

def caller_a():
    return helper_a()
""",
    )
    write_file(
        project,
        "b.py",
        """\
from a import helper_a

def caller_b():
    return helper_a()
""",
    )
    graph = ReferenceGraph(project, db_path=db_path)
    try:
        reports: list[IndexResult] = []
        files = [project / "a.py", project / "b.py"]
        graph.build_index(files, on_progress=reports.append)
        # One callback per file.
        assert len(reports) == 2
        # Freshly indexed files report real edges and no errors.
        for report in reports:
            assert isinstance(report, IndexResult)
            assert not report.error
            assert not report.cached
            assert report.num_edges > 0
            assert len(report.edges) == report.num_edges
            assert report.cross_file_edges >= 0
        # With the index built, get_callees resolves correctly.
        _, callees = graph.get_callees({project / "a.py": {"caller_a"}})
        assert "helper_a" in {fs.qualified_name for fs in callees}
    finally:
        graph.close()


def test_build_index_cached_results(project: Path, db_path: Path) -> None:
    """A second build_index pass over unchanged files reports cached, empty results."""
    write_file(
        project,
        "a.py",
        """\
def helper_a():
    return 1

def caller_a():
    return helper_a()
""",
    )
    write_file(
        project,
        "b.py",
        """\
from a import helper_a

def caller_b():
    return helper_a()
""",
    )
    graph = ReferenceGraph(project, db_path=db_path)
    try:
        files = [project / "a.py", project / "b.py"]
        # First pass: fresh indexing.
        graph.build_index(files)
        # Second pass: everything should come back as cached.
        cached_reports: list[IndexResult] = []
        graph.build_index(files, on_progress=cached_reports.append)
        assert len(cached_reports) == 2
        for report in cached_reports:
            assert report.cached
            assert not report.error
            assert report.num_edges == 0
            assert report.edges == ()
            assert report.cross_file_edges == 0
    finally:
        graph.close()


def test_cross_file_edges_tracked(project: Path, db_path: Path) -> None:
    """Edges that cross file boundaries are counted and flagged per edge."""
    write_file(
        project,
        "utils.py",
        """\
def utility():
    return 42
""",
    )
    write_file(
        project,
        "main.py",
        """\
from utils import utility

def caller():
    return utility()
""",
    )
    graph = ReferenceGraph(project, db_path=db_path)
    try:
        reports: list[IndexResult] = []
        graph.build_index([project / "utils.py", project / "main.py"], on_progress=reports.append)
        # main.py calls into utils.py, so it must carry cross-file edges.
        main_report = next(r for r in reports if r.file_path.name == "main.py")
        assert main_report.cross_file_edges > 0
        # At least one edge tuple should be flagged is_cross_file=True.
        assert any(flag for _, _, flag in main_report.edges)
    finally:
        graph.close()


def test_count_callees_per_function(project: Path, db_path: Path) -> None:
    """Per-function callee counts match the number of outgoing edges."""
    write_file(
        project,
        "mod.py",
        """\
def helper_a():
    return 1

def helper_b():
    return 2

def caller_one():
    return helper_a() + helper_b()

def caller_two():
    return helper_a()

def leaf():
    return 42
""",
    )
    graph = ReferenceGraph(project, db_path=db_path)
    try:
        graph.build_index([project / "mod.py"])
        mod = project / "mod.py"
        counts = graph.count_callees_per_function({mod: {"caller_one", "caller_two", "leaf"}})
        assert counts[(mod, "caller_one")] == 2
        assert counts[(mod, "caller_two")] == 1
        assert counts[(mod, "leaf")] == 0
    finally:
        graph.close()


def test_same_file_edges_not_cross_file(project: Path, db_path: Path) -> None:
    """Edges that stay within one file are never flagged as cross-file."""
    write_file(
        project,
        "mod.py",
        """\
def helper():
    return 1

def caller():
    return helper()
""",
    )
    graph = ReferenceGraph(project, db_path=db_path)
    try:
        reports: list[IndexResult] = []
        graph.build_index([project / "mod.py"], on_progress=reports.append)
        assert len(reports) == 1
        report = reports[0]
        assert report.cross_file_edges == 0
        # Every edge must carry is_cross_file=False.
        assert all(not flag for _, _, flag in report.edges)
    finally:
        graph.close()

View file

@ -2,7 +2,7 @@ from pathlib import Path
import pytest
from codeflash.code_utils.edit_generated_tests import remove_functions_from_generated_tests
from codeflash.languages.python.static_analysis.edit_generated_tests import remove_functions_from_generated_tests
from codeflash.models.models import GeneratedTests, GeneratedTestsList

View file

@ -1,7 +1,7 @@
import ast
from pathlib import Path
from codeflash.code_utils.static_analysis import (
from codeflash.languages.python.static_analysis.static_analysis import (
FunctionKind,
ImportedInternalModuleAnalysis,
analyze_imported_modules,
@ -23,10 +23,10 @@ from pathlib import *
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from codeflash.code_utils.static_analysis import ImportedInternalModuleAnalysis
from codeflash.languages.python.static_analysis.static_analysis import ImportedInternalModuleAnalysis
def a_function():
from codeflash.code_utils.static_analysis import analyze_imported_modules
from codeflash.languages.python.static_analysis.static_analysis import analyze_imported_modules
from returns.result import Failure, Success
pass
"""
@ -37,8 +37,8 @@ def a_function():
expected_imported_module_analysis = [
ImportedInternalModuleAnalysis(
name="static_analysis",
full_name="codeflash.code_utils.static_analysis",
file_path=project_root / Path("codeflash/code_utils/static_analysis.py"),
full_name="codeflash.languages.python.static_analysis.static_analysis",
file_path=project_root / Path("codeflash/languages/python/static_analysis/static_analysis.py"),
),
ImportedInternalModuleAnalysis(
name="mymodule", full_name="tests.mymodule", file_path=project_root / Path("tests/mymodule.py")

View file

@ -918,7 +918,7 @@ class OuterClass:
"only_function_name": "global_helper_1",
"fully_qualified_name": "main.global_helper_1",
"file_path": main_file,
"jedi_definition": type("MockJedi", (), {"type": "function"})(),
"definition_type": "function",
},
)(),
type(
@ -929,7 +929,7 @@ class OuterClass:
"only_function_name": "global_helper_2",
"fully_qualified_name": "main.global_helper_2",
"file_path": main_file,
"jedi_definition": type("MockJedi", (), {"type": "function"})(),
"definition_type": "function",
},
)(),
]

View file

@ -1,108 +0,0 @@
# AI Service
How codeflash communicates with the AI optimization backend.
## `AiServiceClient` (`api/aiservice.py`)
The client connects to the AI service at `https://app.codeflash.ai` (or `http://localhost:8000` when `CODEFLASH_AIS_SERVER=local`).
Authentication uses Bearer token from `get_codeflash_api_key()`. All requests go through `make_ai_service_request()` which handles JSON serialization via Pydantic encoder.
Timeout: 90s for production, 300s for local.
## Endpoints
### `/ai/optimize` — Generate Candidates
Method: `optimize_code()`
Sends source code + dependency context to generate optimization candidates.
Payload:
- `source_code` — The read-writable code (markdown format)
- `dependency_code` — Read-only context code
- `trace_id` — Unique trace ID for the optimization run
- `language``"python"`, `"javascript"`, or `"typescript"`
- `n_candidates` — Number of candidates to generate (controlled by effort level)
- `is_async` — Whether the function is async
- `is_numerical_code` — Whether the code is numerical (affects optimization strategy)
Returns: `list[OptimizedCandidate]` with `source=OptimizedCandidateSource.OPTIMIZE`
### `/ai/optimize_line_profiler` — Line-Profiler-Guided Candidates
Method: `optimize_python_code_line_profiler()`
Like `/optimize` but includes `line_profiler_results` to guide the LLM toward hot lines.
Returns: candidates with `source=OptimizedCandidateSource.OPTIMIZE_LP`
### `/ai/refine` — Refine Existing Candidate
Method: `refine_code()`
Request type: `AIServiceRefinerRequest`
Sends an existing candidate with runtime data and line profiler results to generate an improved version.
Key fields:
- `original_source_code` / `optimized_source_code` — Before and after
- `original_code_runtime` / `optimized_code_runtime` — Timing data
- `speedup` — Current speedup ratio
- `original_line_profiler_results` / `optimized_line_profiler_results`
Returns: candidates with `source=OptimizedCandidateSource.REFINE` and `parent_id` set to the ID of the candidate being refined
### `/ai/repair` — Fix Failed Candidate
Method: `repair_code()`
Request type: `AIServiceCodeRepairRequest`
Sends a failed candidate with test diffs showing what went wrong.
Key fields:
- `original_source_code` / `modified_source_code`
- `test_diffs: list[TestDiff]` — Each with `scope` (return_value/stdout/did_pass), original vs candidate values, and test source code
Returns: candidates with `source=OptimizedCandidateSource.REPAIR` and `parent_id` set
### `/ai/adaptive_optimize` — Multi-Candidate Adaptive
Method: `adaptive_optimize()`
Request type: `AIServiceAdaptiveOptimizeRequest`
Sends multiple previous candidates with their speedups for the LLM to learn from and generate better candidates.
Key fields:
- `candidates: list[AdaptiveOptimizedCandidate]` — Previous candidates with source code, explanation, source type, and speedup
Returns: candidates with `source=OptimizedCandidateSource.ADAPTIVE`
### `/ai/rewrite_jit` — JIT Rewrite
Method: `get_jit_rewritten_code()`
Rewrites code to use JIT compilation (e.g., Numba).
Returns: candidates with `source=OptimizedCandidateSource.JIT_REWRITE`
## Candidate Parsing
All endpoints return JSON with an `optimizations` array. Each entry has:
- `source_code` — Markdown-formatted code blocks
- `explanation` — LLM explanation
- `optimization_id` — Unique ID
- `parent_id` — Optional parent reference
- `model` — Which LLM model was used
`_get_valid_candidates()` parses the markdown code via `CodeStringsMarkdown.parse_markdown_code()` and filters out entries with empty code blocks.
## `LocalAiServiceClient`
Used when `CODEFLASH_EXPERIMENT_ID` is set. Mirrors `AiServiceClient` but sends to a separate experimental endpoint for A/B testing optimization strategies.
## LLM Call Sequencing
`AiServiceClient` tracks call sequence via `llm_call_counter` (itertools.count). Each request includes a `call_sequence` number, used by the backend to maintain conversation context across multiple calls for the same function.

View file

@ -1,79 +0,0 @@
# Configuration
Key configuration constants, effort levels, and thresholds.
## Constants (`code_utils/config_consts.py`)
### Test Execution
| Constant | Value | Description |
|----------|-------|-------------|
| `MAX_TEST_RUN_ITERATIONS` | 5 | Maximum test loop iterations |
| `INDIVIDUAL_TESTCASE_TIMEOUT` | 15s | Timeout per individual test case |
| `MAX_FUNCTION_TEST_SECONDS` | 60s | Max total time for function testing |
| `MAX_TEST_FUNCTION_RUNS` | 50 | Max test function executions |
| `MAX_CUMULATIVE_TEST_RUNTIME_NANOSECONDS` | 100ms | Max cumulative test runtime |
| `TOTAL_LOOPING_TIME` | 10s | Candidate benchmarking budget |
| `MIN_TESTCASE_PASSED_THRESHOLD` | 6 | Minimum test cases that must pass |
### Performance Thresholds
| Constant | Value | Description |
|----------|-------|-------------|
| `MIN_IMPROVEMENT_THRESHOLD` | 0.05 (5%) | Minimum speedup to accept a candidate |
| `MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD` | 0.10 (10%) | Minimum async throughput improvement |
| `MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD` | 0.20 (20%) | Minimum concurrency ratio improvement |
| `COVERAGE_THRESHOLD` | 60.0% | Minimum test coverage |
### Stability Thresholds
| Constant | Value | Description |
|----------|-------|-------------|
| `STABILITY_WINDOW_SIZE` | 0.35 | 35% of total iteration window |
| `STABILITY_CENTER_TOLERANCE` | 0.0025 | ±0.25% around median |
| `STABILITY_SPREAD_TOLERANCE` | 0.0025 | 0.25% window spread |
### Context Limits
| Constant | Value | Description |
|----------|-------|-------------|
| `OPTIMIZATION_CONTEXT_TOKEN_LIMIT` | 16000 | Max tokens for optimization context |
| `TESTGEN_CONTEXT_TOKEN_LIMIT` | 16000 | Max tokens for test generation context |
| `MAX_CONTEXT_LEN_REVIEW` | 1000 | Max context length for optimization review |
### Other
| Constant | Value | Description |
|----------|-------|-------------|
| `MIN_CORRECT_CANDIDATES` | 2 | Min correct candidates before skipping repair |
| `REPEAT_OPTIMIZATION_PROBABILITY` | 0.1 | Probability of re-optimizing a function |
| `DEFAULT_IMPORTANCE_THRESHOLD` | 0.001 | Minimum addressable time to consider a function |
| `CONCURRENCY_FACTOR` | 10 | Number of concurrent executions for concurrency benchmark |
| `REFINED_CANDIDATE_RANKING_WEIGHTS` | (2, 1) | (runtime, diff) weights — runtime 2x more important |
## Effort Levels
`EffortLevel` enum: `LOW`, `MEDIUM`, `HIGH`
Effort controls the number of candidates, repairs, and refinements:
| Key | LOW | MEDIUM | HIGH |
|-----|-----|--------|------|
| `N_OPTIMIZER_CANDIDATES` | 3 | 5 | 6 |
| `N_OPTIMIZER_LP_CANDIDATES` | 4 | 6 | 7 |
| `N_GENERATED_TESTS` | 2 | 2 | 2 |
| `MAX_CODE_REPAIRS_PER_TRACE` | 2 | 3 | 5 |
| `REPAIR_UNMATCHED_PERCENTAGE_LIMIT` | 0.2 | 0.3 | 0.4 |
| `TOP_VALID_CANDIDATES_FOR_REFINEMENT` | 2 | 3 | 4 |
| `ADAPTIVE_OPTIMIZATION_THRESHOLD` | 0 | 0 | 2 |
| `MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE` | 0 | 0 | 4 |
Use `get_effort_value(EffortKeys.KEY, effort_level)` to retrieve values.
## Project Configuration
Configuration is read from `pyproject.toml` under `[tool.codeflash]`. Key settings are auto-detected by `setup/detector.py`:
- `module-root` — Root of the module to optimize
- `tests-root` — Root of test files
- `test-framework` — pytest, unittest, jest, etc.
- `formatter-cmds` — Code formatting commands

View file

@ -1,60 +0,0 @@
# Context Extraction
How codeflash extracts and limits code context for optimization and test generation.
## Overview
Context extraction (`context/code_context_extractor.py`) builds a `CodeOptimizationContext` containing all code needed for the LLM to understand and optimize a function, split into:
- **Read-writable code** (`CodeContextType.READ_WRITABLE`): The function being optimized plus its helper functions — code the LLM is allowed to modify
- **Read-only context** (`CodeContextType.READ_ONLY`): Dependency code for reference — imports, type definitions, base classes
- **Testgen context** (`CodeContextType.TESTGEN`): Context for test generation, may include imported class definitions and external base class inits
- **Hashing context** (`CodeContextType.HASHING`): Used for deduplication of optimization runs
## Token Limits
Both optimization and test generation contexts are token-limited:
- `OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 16000` tokens
- `TESTGEN_CONTEXT_TOKEN_LIMIT = 16000` tokens
Token counting uses `encoded_tokens_len()` from `code_utils/code_utils.py`. Functions whose context exceeds these limits are skipped.
## Context Building Process
### 1. Helper Discovery
For the target function (`FunctionToOptimize`), the extractor finds:
- **Helpers of the function**: Functions/classes in the same file that the target function calls
- **Helpers of helpers**: Transitive dependencies of the helper functions
These are organized as `dict[Path, set[FunctionSource]]` — mapping file paths to the set of helper functions found in each file.
### 2. Code Extraction
`extract_code_markdown_context_from_files()` builds `CodeStringsMarkdown` from the helper dictionaries. Each file's relevant code is extracted as a `CodeString` with its file path.
### 3. Testgen Context Enrichment
`build_testgen_context()` extends the basic context with:
- Imported class definitions (resolved from imports)
- External base class `__init__` methods
- External class `__init__` methods referenced in the context
### 4. Unused Definition Removal
`detect_unused_helper_functions()` and `remove_unused_definitions_by_function_names()` from `context/unused_definition_remover.py` prune definitions that are not transitively reachable from the target function, reducing token usage.
### 5. Deduplication
The hashing context (`hashing_code_context`) generates a hash (`hashing_code_context_hash`) used to detect when the same function context has already been optimized in a previous run, avoiding redundant work.
## Key Functions
| Function | Location | Purpose |
|----------|----------|---------|
| `build_testgen_context()` | `context/code_context_extractor.py` | Build enriched testgen context |
| `extract_code_markdown_context_from_files()` | `context/code_context_extractor.py` | Convert helper dicts to `CodeStringsMarkdown` |
| `detect_unused_helper_functions()` | `context/unused_definition_remover.py` | Find unused definitions |
| `remove_unused_definitions_by_function_names()` | `context/unused_definition_remover.py` | Remove unused definitions |
| `collect_top_level_defs_with_usages()` | `context/unused_definition_remover.py` | Analyze definition usage |
| `encoded_tokens_len()` | `code_utils/code_utils.py` | Count tokens in code |

View file

@ -1,153 +0,0 @@
# Domain Types
Core data types used throughout the codeflash optimization pipeline.
## Function Representation
### `FunctionToOptimize` (`models/function_types.py`)
The canonical dataclass representing a function candidate for optimization. Works across Python, JavaScript, and TypeScript.
Key fields:
- `function_name: str` — The function name
- `file_path: Path` — Absolute file path where the function is located
- `parents: list[FunctionParent]` — Parent scopes (classes/functions), each with `name` and `type`
- `starting_line / ending_line: Optional[int]` — Line range (1-indexed)
- `is_async: bool` — Whether the function is async
- `is_method: bool` — Whether it belongs to a class
- `language: str` — Programming language (default: `"python"`)
Key properties:
- `qualified_name` — Full dotted name including parent classes (e.g., `MyClass.my_method`)
- `top_level_parent_name` — Name of outermost parent, or function name if no parents
- `class_name` — Immediate parent class name, or `None`
### `FunctionParent` (`models/function_types.py`)
Represents a parent scope: `name: str` (e.g., `"MyClass"`) and `type: str` (e.g., `"ClassDef"`).
### `FunctionSource` (`models/models.py`)
Represents a resolved function with source code. Used for helper functions in context extraction.
Fields: `file_path`, `qualified_name`, `fully_qualified_name`, `only_function_name`, `source_code`, `jedi_definition`.
## Code Representation
### `CodeString` (`models/models.py`)
A single code block with validated syntax:
- `code: str` — The source code
- `file_path: Optional[Path]` — Origin file path
- `language: str` — Language for validation (default: `"python"`)
Validates syntax on construction via `model_validator`.
### `CodeStringsMarkdown` (`models/models.py`)
A collection of `CodeString` blocks — the primary format for passing code through the pipeline.
Key properties:
- `.flat` — Combined source code with file-path comment prefixes (e.g., `# file: path/to/file.py`)
- `.markdown` — Markdown-formatted with fenced code blocks: `` ```python:filepath\ncode\n``` ``
- `.file_to_path()` — Dict mapping file path strings to code
Static method:
- `parse_markdown_code(markdown_code, expected_language)` — Parses markdown code blocks back into `CodeStringsMarkdown`
## Optimization Context
### `CodeOptimizationContext` (`models/models.py`)
Holds all code context needed for optimization:
- `read_writable_code: CodeStringsMarkdown` — Code the LLM can modify
- `read_only_context_code: str` — Reference-only dependency code
- `testgen_context: CodeStringsMarkdown` — Context for test generation
- `hashing_code_context: str` / `hashing_code_context_hash: str` — For deduplication
- `helper_functions: list[FunctionSource]` — Helper functions in the writable code
- `preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]]` — Objects that already exist in the code
### `CodeContextType` enum (`models/models.py`)
Defines context categories: `READ_WRITABLE`, `READ_ONLY`, `TESTGEN`, `HASHING`.
## Candidates
### `OptimizedCandidate` (`models/models.py`)
A generated code variant:
- `source_code: CodeStringsMarkdown` — The optimized code
- `explanation: str` — LLM explanation of the optimization
- `optimization_id: str` — Unique identifier
- `source: OptimizedCandidateSource` — How it was generated
- `parent_id: str | None` — ID of parent candidate (for refinements/repairs)
- `model: str | None` — Which LLM model generated it
### `OptimizedCandidateSource` enum (`models/models.py`)
How a candidate was generated: `OPTIMIZE`, `OPTIMIZE_LP` (line profiler), `REFINE`, `REPAIR`, `ADAPTIVE`, `JIT_REWRITE`.
### `CandidateEvaluationContext` (`models/models.py`)
Tracks state during candidate evaluation:
- `speedup_ratios` / `optimized_runtimes` / `is_correct` — Per-candidate results
- `ast_code_to_id` — Deduplication map (normalized AST → first seen candidate)
- `valid_optimizations` — Candidates that passed all checks
Key methods: `record_failed_candidate()`, `record_successful_candidate()`, `handle_duplicate_candidate()`, `register_new_candidate()`.
## Baseline & Results
### `OriginalCodeBaseline` (`models/models.py`)
Baseline measurements for the original code:
- `behavior_test_results: TestResults` / `benchmarking_test_results: TestResults`
- `line_profile_results: dict`
- `runtime: int` — Total runtime in nanoseconds
- `coverage_results: Optional[CoverageData]`
### `BestOptimization` (`models/models.py`)
The winning candidate after evaluation:
- `candidate: OptimizedCandidate`
- `helper_functions: list[FunctionSource]`
- `code_context: CodeOptimizationContext`
- `runtime: int`
- `winning_behavior_test_results` / `winning_benchmarking_test_results: TestResults`
## Test Types
### `TestType` enum (`models/test_type.py`)
- `EXISTING_UNIT_TEST` (1) — Pre-existing tests from the codebase
- `INSPIRED_REGRESSION` (2) — Tests inspired by existing tests
- `GENERATED_REGRESSION` (3) — AI-generated regression tests
- `REPLAY_TEST` (4) — Tests from recorded benchmark data
- `CONCOLIC_COVERAGE_TEST` (5) — Coverage-guided tests
- `INIT_STATE_TEST` (6) — Class init state verification
### `TestFile` / `TestFiles` (`models/models.py`)
`TestFile` represents a single test file with `instrumented_behavior_file_path`, optional `benchmarking_file_path`, `original_file_path`, `test_type`, and `tests_in_file`.
`TestFiles` is a collection with lookup methods: `get_by_type()`, `get_by_original_file_path()`, `get_test_type_by_instrumented_file_path()`.
### `TestResults` (`models/models.py`)
Collection of `FunctionTestInvocation` results with indexed lookup. Key methods:
- `add(invocation)` — Deduplicated insert
- `total_passed_runtime()` — Sum of minimum runtimes per test case (nanoseconds)
- `number_of_loops()` — Max loop index across all results
- `usable_runtime_data_by_test_case()` — Dict of invocation ID → list of runtimes
## Result Type
### `Result[L, R]` / `Success` / `Failure` (`either.py`)
Functional error handling type:
- `Success(value)` — Wraps a successful result
- `Failure(error)` — Wraps an error
- `result.is_successful()` / `result.is_failure()` — Check type
- `result.unwrap()` — Get success value (raises if Failure)
- `result.failure()` — Get failure value (raises if Success)
- `is_successful(result)` — Module-level helper function

View file

@ -1,41 +0,0 @@
# Codeflash Internal Documentation
CodeFlash is an AI-powered code optimizer for Python, JavaScript, and TypeScript that automatically improves code performance while maintaining correctness. It uses LLMs to generate optimization candidates, verifies correctness through test execution, and benchmarks performance improvements.
## Pipeline Overview
```
Discovery → Ranking → Context Extraction → Test Gen + Optimization → Baseline → Candidate Evaluation → PR
```
1. **Discovery** (`discovery/`): Find optimizable functions across the codebase using `FunctionVisitor`
2. **Ranking** (`benchmarking/function_ranker.py`): Rank functions by addressable time using trace data
3. **Context** (`context/`): Extract code dependencies — split into read-writable (modifiable) and read-only (reference)
4. **Optimization** (`optimization/`, `api/`): Generate candidates via the AI service; this runs concurrently with test generation
5. **Verification** (`verification/`): Run candidates against tests via custom pytest plugin, compare outputs
6. **Benchmarking** (`benchmarking/`): Measure performance, select best candidate by speedup
7. **Result** (`result/`, `github/`): Create PR with winning optimization
## Key Entry Points
| Task | File |
|------|------|
| CLI arguments & commands | `cli_cmds/cli.py` |
| Optimization orchestration | `optimization/optimizer.py``Optimizer.run()` |
| Per-function optimization | `optimization/function_optimizer.py``FunctionOptimizer` |
| Function discovery | `discovery/functions_to_optimize.py` |
| Context extraction | `context/code_context_extractor.py` |
| Test execution | `verification/test_runner.py`, `verification/pytest_plugin.py` |
| Performance ranking | `benchmarking/function_ranker.py` |
| Domain types | `models/models.py`, `models/function_types.py` |
| AI service | `api/aiservice.py``AiServiceClient` |
| Configuration | `code_utils/config_consts.py` |
## Documentation Pages
- [Domain Types](domain-types.md) — Core data types and their relationships
- [Optimization Pipeline](optimization-pipeline.md) — Step-by-step data flow through the pipeline
- [Context Extraction](context-extraction.md) — How code context is extracted and token-limited
- [Verification](verification.md) — Test execution, pytest plugin, deterministic patches
- [AI Service](ai-service.md) — AI service client endpoints and request types
- [Configuration](configuration.md) — Config schema, effort levels, thresholds

View file

@ -1,84 +0,0 @@
# Optimization Pipeline
Step-by-step data flow from function discovery to PR creation.
## 1. Entry Point: `Optimizer.run()` (`optimization/optimizer.py`)
The `Optimizer` class is initialized with CLI args and creates:
- `TestConfig` with test roots, project root, pytest command
- `AiServiceClient` for AI service communication
- Optional `LocalAiServiceClient` for experiments
`run()` orchestrates the full pipeline: discovers functions, optionally ranks them, then optimizes each in turn.
## 2. Function Discovery (`discovery/functions_to_optimize.py`)
`FunctionVisitor` traverses source files to find optimizable functions, producing `FunctionToOptimize` instances. Filters include:
- Skipping functions that are too small or trivial
- Skipping previously optimized functions (via `was_function_previously_optimized()`)
- Applying user-configured include/exclude patterns
## 3. Function Ranking (`benchmarking/function_ranker.py`)
When trace data is available, `FunctionRanker` ranks functions by **addressable time** — the time a function spends that could be optimized (own time + callee time / call count). Functions below `DEFAULT_IMPORTANCE_THRESHOLD=0.001` are skipped.
## 4. Per-Function Optimization: `FunctionOptimizer` (`optimization/function_optimizer.py`)
For each function, `FunctionOptimizer.optimize_function()` runs the full optimization loop:
### 4a. Context Extraction (`context/code_context_extractor.py`)
Extracts `CodeOptimizationContext` containing:
- `read_writable_code` — Code the LLM can modify (the function + helpers)
- `read_only_context_code` — Dependency code for reference only
- `testgen_context` — Context for test generation (may include imported class definitions)
Token limits are enforced: `OPTIMIZATION_CONTEXT_TOKEN_LIMIT=16000` and `TESTGEN_CONTEXT_TOKEN_LIMIT=16000`. Functions exceeding these are rejected.
### 4b. Concurrent Test Generation + LLM Optimization
These run in parallel using `concurrent.futures`:
- **Test generation**: Generates regression tests from the function context
- **LLM optimization**: Sends `read_writable_code.markdown` + `read_only_context_code` to the AI service
The number of candidates depends on effort level (see Configuration docs).
### 4c. Candidate Evaluation
For each `OptimizedCandidate`:
1. **Deduplication**: Normalize code AST and check against `CandidateEvaluationContext.ast_code_to_id`. If duplicate, copy results from previous evaluation.
2. **Code replacement**: Replace the original function with the candidate using `replace_function_definitions_in_module()`.
3. **Behavioral testing**: Run instrumented tests in subprocess. The custom pytest plugin applies deterministic patches. Compare return values, stdout, and pass/fail status against the original baseline.
4. **Benchmarking**: If behavior matches, run performance tests with looping (`TOTAL_LOOPING_TIME=10s`). Calculate speedup ratio.
5. **Validation**: Candidate must beat `MIN_IMPROVEMENT_THRESHOLD=0.05` (5% speedup) and pass stability checks.
### 4d. Refinement & Repair
- **Repair**: If fewer than `MIN_CORRECT_CANDIDATES=2` pass, failed candidates can be repaired via `AIServiceCodeRepairRequest` (sends test diffs to LLM).
- **Refinement**: Top valid candidates are refined via `AIServiceRefinerRequest` (sends runtime data, line profiler results).
- **Adaptive**: At HIGH effort, additional adaptive optimization rounds via `AIServiceAdaptiveOptimizeRequest`.
### 4e. Best Candidate Selection
The winning candidate is selected by:
1. Highest speedup ratio
2. For tied speedups, shortest diff length from original
3. Refinement candidates use weighted ranking: `(2 * runtime_rank + 1 * diff_rank)`
Result is a `BestOptimization` with the candidate, context, test results, and runtime.
## 5. PR Creation (`github/`)
If a winning candidate is found, a PR is created with:
- The optimized code diff
- Performance benchmark details
- Explanation from the LLM
## Worktree Mode
When `--worktree` is enabled, optimization runs in an isolated git worktree (`code_utils/git_worktree_utils.py`). This allows parallel optimization without affecting the working tree. Changes are captured as patch files.

View file

@ -1,93 +0,0 @@
# Verification
How codeflash verifies candidate correctness and measures performance.
## Test Execution Architecture
Tests are executed in a **subprocess** to isolate the test environment from the main codeflash process. The test runner (`verification/test_runner.py`) invokes pytest (or Jest for JS/TS) with specific plugin configurations.
### Plugin Blocklists
- **Behavioral tests**: Block `benchmark`, `codspeed`, `xdist`, `sugar`
- **Benchmarking tests**: Block `codspeed`, `cov`, `benchmark`, `profiling`, `xdist`, `sugar`
These are defined as `BEHAVIORAL_BLOCKLISTED_PLUGINS` and `BENCHMARKING_BLOCKLISTED_PLUGINS` in `verification/test_runner.py`.
## Custom Pytest Plugin (`verification/pytest_plugin.py`)
The plugin is loaded into the test subprocess and provides:
### Deterministic Patches
`_apply_deterministic_patches()` replaces non-deterministic functions with fixed values to ensure reproducible test output:
| Module | Function | Fixed Value |
|--------|----------|-------------|
| `time` | `time()` | `1761717605.108106` |
| `time` | `perf_counter()` | Incrementing by 1ms per call |
| `datetime` | `datetime.now()` | `2021-01-01 02:05:10 UTC` |
| `datetime` | `datetime.utcnow()` | `2021-01-01 02:05:10 UTC` |
| `uuid` | `uuid4()` / `uuid1()` | `12345678-1234-5678-9abc-123456789012` |
| `random` | `random()` | `0.123456789` (seeded with 42) |
| `os` | `urandom(n)` | `b"\x42" * n` |
| `numpy.random` | seed | `42` |
Patches call the original function first to maintain performance characteristics (same call overhead).
### Timing Markers
Test results include timing markers in stdout: `!######<id>:<duration_ns>######!`
The pattern `_TIMING_MARKER_PATTERN` extracts timing data for calculating the function utilization fraction.
### Loop Stability
Performance benchmarking uses configurable stability thresholds:
- `STABILITY_WINDOW_SIZE = 0.35` (35% of total iterations)
- `STABILITY_CENTER_TOLERANCE = 0.0025` (±0.25% around median)
- `STABILITY_SPREAD_TOLERANCE = 0.0025` (0.25% window spread)
### Memory Limits (Linux)
On Linux, the plugin sets `RLIMIT_AS` to 85% of total system memory (RAM + swap) to prevent OOM kills.
## Test Result Processing
### `TestResults` (`models/models.py`)
Collects `FunctionTestInvocation` results with:
- Deduplicated insertion via `unique_invocation_loop_id`
- `total_passed_runtime()` — Sum of minimum runtimes per test case (nanoseconds)
- `number_of_loops()` — Max loop index
- `usable_runtime_data_by_test_case()` — Grouped timing data
### `FunctionTestInvocation`
Each invocation records:
- `loop_index` — Iteration number (starts at 1)
- `id: InvocationId` — Fully qualified test identifier
- `did_pass: bool` — Pass/fail status
- `runtime: Optional[int]` — Time in nanoseconds
- `return_value: Optional[object]` — Captured return value
- `test_type: TestType` — Which test category
### Behavioral vs Performance Testing
1. **Behavioral**: Runs with `TestingMode.BEHAVIOR`. Compares return values and stdout between original and candidate. Any difference = candidate rejected.
2. **Performance**: Runs with `TestingMode.PERFORMANCE`. Loops for `TOTAL_LOOPING_TIME=10s` to get stable timing. Calculates speedup ratio.
3. **Line Profile**: Runs with `TestingMode.LINE_PROFILE`. Collects per-line timing data for refinement.
## Test Types
| TestType | Value | Description |
|----------|-------|-------------|
| `EXISTING_UNIT_TEST` | 1 | Pre-existing tests from the codebase |
| `INSPIRED_REGRESSION` | 2 | Tests inspired by existing tests |
| `GENERATED_REGRESSION` | 3 | AI-generated regression tests |
| `REPLAY_TEST` | 4 | Tests from recorded benchmark data |
| `CONCOLIC_COVERAGE_TEST` | 5 | Coverage-guided tests |
| `INIT_STATE_TEST` | 6 | Class init state verification |
## Coverage
Coverage is measured via `CoverageData` with a threshold of `COVERAGE_THRESHOLD=60.0%`. Low coverage may affect confidence in the optimization's correctness.

View file

@ -1,118 +0,0 @@
{
"package_name": "codeflash-docs",
"total_capabilities": 16,
"capabilities": [
{
"id": 0,
"name": "pipeline-stage-ordering",
"description": "Know the correct ordering of codeflash pipeline stages: Discovery → Ranking → Context Extraction → Test Gen + Optimization (concurrent) → Baseline → Candidate Evaluation → PR",
"complexity": "basic",
"api_elements": ["Optimizer.run()", "FunctionOptimizer.optimize_function()"]
},
{
"id": 1,
"name": "function-to-optimize-fields",
"description": "Know FunctionToOptimize key fields (function_name, file_path, parents, starting_line/ending_line, is_async, is_method, language) and properties (qualified_name, top_level_parent_name, class_name)",
"complexity": "intermediate",
"api_elements": ["FunctionToOptimize", "FunctionParent", "models/function_types.py"]
},
{
"id": 2,
"name": "code-strings-markdown-format",
"description": "Know that code is serialized as markdown fenced blocks with language:filepath syntax (```python:filepath\\ncode\\n```) and parsed via CodeStringsMarkdown.parse_markdown_code()",
"complexity": "intermediate",
"api_elements": ["CodeStringsMarkdown", "CodeString", ".markdown", ".flat", "parse_markdown_code()"]
},
{
"id": 3,
"name": "read-writable-vs-read-only",
"description": "Distinguish read_writable_code (LLM can modify) from read_only_context_code (reference only) in CodeOptimizationContext",
"complexity": "basic",
"api_elements": ["CodeOptimizationContext", "read_writable_code", "read_only_context_code"]
},
{
"id": 4,
"name": "candidate-source-types",
"description": "Know OptimizedCandidateSource variants: OPTIMIZE, OPTIMIZE_LP, REFINE, REPAIR, ADAPTIVE, JIT_REWRITE and when each is used",
"complexity": "intermediate",
"api_elements": ["OptimizedCandidateSource", "OptimizedCandidate"]
},
{
"id": 5,
"name": "candidate-forest-dag",
"description": "Know that candidates form a forest/DAG via parent_id references where refinements and repairs build on previous candidates",
"complexity": "intermediate",
"api_elements": ["parent_id", "OptimizedCandidate", "CandidateForest"]
},
{
"id": 6,
"name": "concurrent-testgen-optimization",
"description": "Know that test generation and LLM optimization run concurrently using concurrent.futures, not sequentially",
"complexity": "intermediate",
"api_elements": ["concurrent.futures", "FunctionOptimizer.optimize_function()"]
},
{
"id": 7,
"name": "deterministic-patch-values",
"description": "Know the specific fixed values used by deterministic patches: time=1761717605.108106, datetime=2021-01-01 02:05:10 UTC, uuid=12345678-1234-5678-9abc-123456789012, random seeded with 42",
"complexity": "advanced",
"api_elements": ["_apply_deterministic_patches()", "pytest_plugin.py"]
},
{
"id": 8,
"name": "test-type-enum",
"description": "Know the 6 TestType variants: EXISTING_UNIT_TEST, INSPIRED_REGRESSION, GENERATED_REGRESSION, REPLAY_TEST, CONCOLIC_COVERAGE_TEST, INIT_STATE_TEST",
"complexity": "basic",
"api_elements": ["TestType", "models/test_type.py"]
},
{
"id": 9,
"name": "ai-service-endpoints",
"description": "Know the AI service endpoints: /ai/optimize, /ai/optimize_line_profiler, /ai/refine, /ai/repair, /ai/adaptive_optimize, /ai/rewrite_jit",
"complexity": "intermediate",
"api_elements": ["AiServiceClient", "api/aiservice.py"]
},
{
"id": 10,
"name": "repair-request-structure",
"description": "Know that AIServiceCodeRepairRequest includes TestDiff objects with scope (RETURN_VALUE/STDOUT/DID_PASS), original vs candidate values, and test source code",
"complexity": "advanced",
"api_elements": ["AIServiceCodeRepairRequest", "TestDiff", "TestDiffScope"]
},
{
"id": 11,
"name": "effort-level-values",
"description": "Know specific effort level values: LOW gets 3 candidates, MEDIUM gets 5, HIGH gets 6 (N_OPTIMIZER_CANDIDATES)",
"complexity": "intermediate",
"api_elements": ["EffortLevel", "N_OPTIMIZER_CANDIDATES", "EFFORT_VALUES"]
},
{
"id": 12,
"name": "context-token-limits",
"description": "Know OPTIMIZATION_CONTEXT_TOKEN_LIMIT=16000 and TESTGEN_CONTEXT_TOKEN_LIMIT=16000 and that encoded_tokens_len() is used for counting",
"complexity": "basic",
"api_elements": ["OPTIMIZATION_CONTEXT_TOKEN_LIMIT", "TESTGEN_CONTEXT_TOKEN_LIMIT", "encoded_tokens_len()"]
},
{
"id": 13,
"name": "best-candidate-selection",
"description": "Know the selection criteria: highest speedup, then shortest diff for ties, and refinement weighted ranking (2*runtime + 1*diff)",
"complexity": "advanced",
"api_elements": ["BestOptimization", "REFINED_CANDIDATE_RANKING_WEIGHTS"]
},
{
"id": 14,
"name": "plugin-blocklists",
"description": "Know behavioral test blocklisted plugins (benchmark, codspeed, xdist, sugar) and benchmarking blocklist (adds cov, profiling)",
"complexity": "intermediate",
"api_elements": ["BEHAVIORAL_BLOCKLISTED_PLUGINS", "BENCHMARKING_BLOCKLISTED_PLUGINS"]
},
{
"id": 15,
"name": "result-type-usage",
"description": "Know that Result[L,R] from either.py uses Success(value)/Failure(error) with is_successful() check before unwrap()",
"complexity": "basic",
"api_elements": ["Result", "Success", "Failure", "is_successful", "either.py"]
}
]
}

View file

@ -1 +0,0 @@
Code serialization format and context splitting

View file

@ -1,21 +0,0 @@
{
"context": "Tests whether the agent knows the CodeStringsMarkdown serialization format and the distinction between read-writable and read-only code context in the codeflash pipeline.",
"type": "weighted_checklist",
"checklist": [
{
"name": "Markdown code block format",
"description": "Uses the correct fenced code block format with language:filepath syntax (```python:path/to/file.py) when constructing code for the AI service, NOT plain code blocks without file paths",
"max_score": 30
},
{
"name": "Read-writable vs read-only split",
"description": "Correctly separates code into read_writable_code (code the LLM can modify) and read_only_context_code (reference-only dependency code), NOT treating all code as modifiable",
"max_score": 35
},
{
"name": "parse_markdown_code usage",
"description": "Uses CodeStringsMarkdown.parse_markdown_code() to parse AI service responses back into structured code, NOT manual string splitting or regex",
"max_score": 35
}
]
}

View file

@ -1,35 +0,0 @@
# Format Code for AI Service Request
## Context
You are working on the codeflash optimization engine. The AI service accepts optimization requests with source code and dependency context. A function `calculate_total` in `analytics/metrics.py` needs to be optimized. It calls a helper `normalize_values` in the same file (both modifiable), and imports `BaseMetric` from `analytics/base.py` (not modifiable, just for reference).
```python
# analytics/metrics.py
from analytics.base import BaseMetric
def normalize_values(data: list[float]) -> list[float]:
max_val = max(data)
return [x / max_val for x in data]
def calculate_total(metrics: list[BaseMetric]) -> float:
values = [m.value for m in metrics]
normalized = normalize_values(values)
return sum(normalized)
```
```python
# analytics/base.py
class BaseMetric:
def __init__(self, name: str, value: float):
self.name = name
self.value = value
```
## Task
Write a Python function `prepare_optimization_payload` that constructs the code payload for an AI service optimization request for `calculate_total`. It should properly format the source code and dependency code. Also provide a companion function that parses the AI service response back into structured code objects.
## Expected Outputs
- A Python file `payload_builder.py` with the payload construction and response parsing logic

View file

@ -1 +0,0 @@
Candidate source types and DAG relationships

View file

@ -1,26 +0,0 @@
{
"context": "Tests whether the agent knows the different OptimizedCandidateSource types and how candidates form a DAG via parent_id references in the codeflash pipeline.",
"type": "weighted_checklist",
"checklist": [
{
"name": "Lists source types",
"description": "Identifies at least 4 of the 6 OptimizedCandidateSource variants: OPTIMIZE, OPTIMIZE_LP, REFINE, REPAIR, ADAPTIVE, JIT_REWRITE",
"max_score": 25
},
{
"name": "Parent ID linkage",
"description": "Explains that REFINE and REPAIR candidates reference their parent via parent_id, creating a DAG/forest structure, NOT independent candidates",
"max_score": 25
},
{
"name": "Refinement uses runtime data",
"description": "States that refinement sends runtime data and line profiler results to the AI service (AIServiceRefinerRequest), NOT just the source code",
"max_score": 25
},
{
"name": "Repair uses test diffs",
"description": "States that repair sends test failure diffs (TestDiff with scope: RETURN_VALUE/STDOUT/DID_PASS) to the AI service, NOT just error messages",
"max_score": 25
}
]
}

View file

@ -1,13 +0,0 @@
# Document the Candidate Lifecycle
## Context
A new engineer is joining the codeflash team and needs to understand how optimization candidates are generated, improved, and related to each other throughout the pipeline. They've asked for a clear explanation of the different ways candidates are produced and how the system iterates on them.
## Task
Write a technical document explaining the full lifecycle of an optimization candidate in codeflash — from initial generation through improvement iterations. Cover all the different ways candidates can be created, what data is sent to the AI service for each type, and how candidates relate to each other structurally.
## Expected Outputs
- A markdown file `candidate-lifecycle.md`

View file

@ -1 +0,0 @@
Deterministic patch values and test execution architecture

View file

@ -1,31 +0,0 @@
{
"context": "Tests whether the agent knows the specific deterministic patch values used in codeflash's pytest plugin and the subprocess-based test execution architecture.",
"type": "weighted_checklist",
"checklist": [
{
"name": "Subprocess isolation",
"description": "States that tests run in a subprocess to isolate the test environment from the main codeflash process, NOT in the same process",
"max_score": 20
},
{
"name": "Fixed time value",
"description": "References the specific fixed timestamp 1761717605.108106 for time.time() or the fixed datetime 2021-01-01 02:05:10 UTC for datetime.now()",
"max_score": 20
},
{
"name": "Fixed UUID value",
"description": "References the specific fixed UUID 12345678-1234-5678-9abc-123456789012 for uuid4/uuid1",
"max_score": 20
},
{
"name": "Random seed",
"description": "States that random is seeded with 42 (NOT a different seed value)",
"max_score": 20
},
{
"name": "Plugin blocklists",
"description": "Mentions that behavioral tests block specific pytest plugins (at least 2 of: benchmark, codspeed, xdist, sugar) to ensure deterministic execution",
"max_score": 20
}
]
}

View file

@ -1,13 +0,0 @@
# Explain Test Reproducibility Guarantees
## Context
A codeflash user notices that their optimization candidate passes behavioral tests on one run but fails on the next. They suspect non-determinism in the test execution. They want to understand what guarantees codeflash provides for test reproducibility and how the system ensures consistent results.
## Task
Write a technical explanation of how codeflash ensures deterministic test execution. Cover the execution environment setup, what sources of non-determinism are controlled, and any specific values or configurations used. Also explain the test execution architecture.
## Expected Outputs
- A markdown file `test-reproducibility.md`

View file

@ -1 +0,0 @@
Effort level configuration and candidate selection criteria

View file

@ -1,26 +0,0 @@
{
"context": "Tests whether the agent knows the specific effort level values for candidate generation and the criteria used to select the best optimization candidate.",
"type": "weighted_checklist",
"checklist": [
{
"name": "Candidate counts by effort",
"description": "States correct N_OPTIMIZER_CANDIDATES values: LOW=3, MEDIUM=5, HIGH=6 (at least 2 of 3 correct)",
"max_score": 25
},
{
"name": "Speedup as primary selector",
"description": "States that the winning candidate is selected primarily by highest speedup ratio",
"max_score": 25
},
{
"name": "Diff length as tiebreaker",
"description": "States that for tied speedups, shortest diff length from original is used as tiebreaker",
"max_score": 25
},
{
"name": "Refinement ranking weights",
"description": "States that refinement candidates use weighted ranking with runtime weighted more heavily than diff (2:1 ratio or REFINED_CANDIDATE_RANKING_WEIGHTS=(2,1))",
"max_score": 25
}
]
}

View file

@ -1,18 +0,0 @@
# Design a Candidate Selection Dashboard
## Context
The codeflash team wants to build a dashboard that shows users how optimization candidates were evaluated and why a particular candidate won. The dashboard needs to display the selection process at each stage, from initial candidate pool through to the final winner.
## Task
Write a specification document for the dashboard that explains:
1. How many candidates are generated at each effort level
2. The exact criteria and order of operations used to pick the winning candidate
3. How refinement candidates are ranked differently from initial candidates
Include concrete examples showing how two hypothetical candidates would be compared.
## Expected Outputs
- A markdown file `selection-dashboard-spec.md`

Some files were not shown because too many files have changed in this diff Show more