chore: merge main into fixes-for-core-unstructured-experimental
This commit is contained in:
commit
c6fbdfa535
137 changed files with 4438 additions and 4106 deletions
4
.codex/config.toml
Normal file
4
.codex/config.toml
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
[mcp_servers.tessl]
|
||||
type = "stdio"
|
||||
command = "tessl"
|
||||
args = [ "mcp", "start" ]
|
||||
2
.codex/skills/.gitignore
vendored
2
.codex/skills/.gitignore
vendored
|
|
@ -1,2 +0,0 @@
|
|||
# Managed by Tessl
|
||||
tessl:*
|
||||
12
.gemini/settings.json
Normal file
12
.gemini/settings.json
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
{
|
||||
"mcpServers": {
|
||||
"tessl": {
|
||||
"type": "stdio",
|
||||
"command": "tessl",
|
||||
"args": [
|
||||
"mcp",
|
||||
"start"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
2
.gemini/skills/.gitignore
vendored
2
.gemini/skills/.gitignore
vendored
|
|
@ -1,2 +0,0 @@
|
|||
# Managed by Tessl
|
||||
tessl:*
|
||||
108
.github/workflows/publish.yml
vendored
108
.github/workflows/publish.yml
vendored
|
|
@ -6,20 +6,48 @@ on:
|
|||
- main
|
||||
paths:
|
||||
- 'codeflash/version.py'
|
||||
- 'codeflash-benchmark/codeflash_benchmark/version.py'
|
||||
|
||||
jobs:
|
||||
publish:
|
||||
detect-changes:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
codeflash: ${{ steps.filter.outputs.codeflash }}
|
||||
benchmark: ${{ steps.filter.outputs.benchmark }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Detect which packages changed
|
||||
id: filter
|
||||
run: |
|
||||
if git diff --name-only HEAD~1 HEAD | grep -q '^codeflash/version.py$'; then
|
||||
echo "codeflash=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "codeflash=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
if git diff --name-only HEAD~1 HEAD | grep -q '^codeflash-benchmark/codeflash_benchmark/version.py$'; then
|
||||
echo "benchmark=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "benchmark=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
publish-codeflash:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.codeflash == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: pypi
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write # Changed from 'read' to 'write' to allow tag creation
|
||||
contents: write
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 0 # Fetch all history for proper versioning
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Extract version from version.py
|
||||
id: extract_version
|
||||
|
|
@ -76,4 +104,76 @@ jobs:
|
|||
prerelease: false
|
||||
generate_release_notes: true
|
||||
files: |
|
||||
dist/*
|
||||
dist/*
|
||||
|
||||
publish-benchmark:
|
||||
needs: detect-changes
|
||||
if: needs.detect-changes.outputs.benchmark == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: pypi
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Extract version from version.py
|
||||
id: extract_version
|
||||
run: |
|
||||
VERSION=$(grep -oP '__version__ = "\K[^"]+' codeflash-benchmark/codeflash_benchmark/version.py)
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
echo "tag=benchmark-v$VERSION" >> $GITHUB_OUTPUT
|
||||
echo "Extracted version: $VERSION"
|
||||
|
||||
- name: Check if tag already exists
|
||||
id: check_tag
|
||||
run: |
|
||||
if git rev-parse "${{ steps.extract_version.outputs.tag }}" >/dev/null 2>&1; then
|
||||
echo "exists=true" >> $GITHUB_OUTPUT
|
||||
echo "Tag ${{ steps.extract_version.outputs.tag }} already exists, skipping release"
|
||||
else
|
||||
echo "exists=false" >> $GITHUB_OUTPUT
|
||||
echo "Tag ${{ steps.extract_version.outputs.tag }} does not exist, proceeding with release"
|
||||
fi
|
||||
|
||||
- name: Create and push git tag
|
||||
if: steps.check_tag.outputs.exists == 'false'
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
git tag -a "${{ steps.extract_version.outputs.tag }}" -m "Release ${{ steps.extract_version.outputs.tag }}"
|
||||
git push origin "${{ steps.extract_version.outputs.tag }}"
|
||||
|
||||
- name: Install uv
|
||||
if: steps.check_tag.outputs.exists == 'false'
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Build
|
||||
if: steps.check_tag.outputs.exists == 'false'
|
||||
run: uv build --package codeflash-benchmark
|
||||
|
||||
- name: Publish to PyPI
|
||||
if: steps.check_tag.outputs.exists == 'false'
|
||||
run: uv publish
|
||||
|
||||
- name: Create GitHub Release
|
||||
if: steps.check_tag.outputs.exists == 'false'
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
tag_name: ${{ steps.extract_version.outputs.tag }}
|
||||
name: codeflash-benchmark ${{ steps.extract_version.outputs.tag }}
|
||||
body: |
|
||||
## What's Changed
|
||||
|
||||
Release ${{ steps.extract_version.outputs.version }} of codeflash-benchmark.
|
||||
|
||||
**Full Changelog**: https://github.com/${{ github.repository }}/commits/${{ steps.extract_version.outputs.tag }}
|
||||
draft: false
|
||||
prerelease: false
|
||||
generate_release_notes: true
|
||||
files: |
|
||||
dist/*
|
||||
|
|
|
|||
20
.github/workflows/unit-tests.yaml
vendored
20
.github/workflows/unit-tests.yaml
vendored
|
|
@ -15,9 +15,25 @@ jobs:
|
|||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
python-version: "3.9"
|
||||
- os: ubuntu-latest
|
||||
python-version: "3.10"
|
||||
- os: ubuntu-latest
|
||||
python-version: "3.11"
|
||||
- os: ubuntu-latest
|
||||
python-version: "3.12"
|
||||
- os: ubuntu-latest
|
||||
python-version: "3.13"
|
||||
- os: ubuntu-latest
|
||||
python-version: "3.14"
|
||||
- os: windows-latest
|
||||
python-version: "3.13"
|
||||
continue-on-error: true
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ${{ matrix.os }}
|
||||
env:
|
||||
PYTHONIOENCODING: utf-8
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
|
|
|
|||
34
.github/workflows/windows-unit-tests.yml
vendored
34
.github/workflows/windows-unit-tests.yml
vendored
|
|
@ -1,34 +0,0 @@
|
|||
name: windows-unit-tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
windows-unit-tests:
|
||||
continue-on-error: true
|
||||
runs-on: windows-latest
|
||||
env:
|
||||
PYTHONIOENCODING: utf-8
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: "3.13"
|
||||
|
||||
- name: install dependencies
|
||||
run: uv sync
|
||||
|
||||
- name: Unit tests
|
||||
run: uv run pytest tests/
|
||||
|
|
@ -28,6 +28,10 @@ Discovery → Ranking → Context Extraction → Test Gen + Optimization → Bas
|
|||
- **Tracer**: Profiling system that records function call trees and timings (`tracing/`, `tracer.py`)
|
||||
- **Worktree mode**: Git worktree-based parallel optimization (`--worktree` flag)
|
||||
|
||||
## PR Reviews
|
||||
|
||||
- GitHub PR comments and review feedback can be stale — they may reference issues already fixed by a later commit. Before acting on review feedback, verify it still applies to the current code. If the issue no longer exists, resolve the conversation in the GitHub UI.
|
||||
|
||||
<!-- Section below is auto-generated by `tessl install` - do not edit manually -->
|
||||
|
||||
# Agent Rules <!-- tessl-managed -->
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -18,6 +18,6 @@
|
|||
"@types/node": "^20.0.0",
|
||||
"codeflash": "file:../../../packages/codeflash",
|
||||
"typescript": "^5.0.0",
|
||||
"vitest": "^2.0.0"
|
||||
"vitest": "^4.0.18"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
"""CodeFlash Benchmark - Pytest benchmarking plugin for codeflash.ai."""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
from codeflash_benchmark.version import __version__ as __version__
|
||||
|
|
|
|||
2
codeflash-benchmark/codeflash_benchmark/version.py
Normal file
2
codeflash-benchmark/codeflash_benchmark/version.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
# These version placeholders will be replaced by uv-dynamic-versioning during build.
|
||||
__version__ = "0.3.0"
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "codeflash-benchmark"
|
||||
version = "0.2.0"
|
||||
dynamic = ["version"]
|
||||
description = "Pytest benchmarking plugin for codeflash.ai - automatic code performance optimization"
|
||||
authors = [{ name = "CodeFlash Inc.", email = "contact@codeflash.ai" }]
|
||||
requires-python = ">=3.9"
|
||||
|
|
@ -25,8 +25,19 @@ Repository = "https://github.com/codeflash-ai/codeflash-benchmark"
|
|||
codeflash-benchmark = "codeflash_benchmark.plugin"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=45", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
requires = ["hatchling", "uv-dynamic-versioning"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["codeflash_benchmark"]
|
||||
[tool.hatch.version]
|
||||
source = "uv-dynamic-versioning"
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
enable = true
|
||||
style = "pep440"
|
||||
vcs = "git"
|
||||
|
||||
[tool.hatch.build.hooks.version]
|
||||
path = "codeflash_benchmark/version.py"
|
||||
template = """# These version placeholders will be replaced by uv-dynamic-versioning during build.
|
||||
__version__ = "{version}"
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -130,9 +130,18 @@ def parse_args() -> Namespace:
|
|||
"--reset-config", action="store_true", help="Remove codeflash configuration from project config file."
|
||||
)
|
||||
parser.add_argument("-y", "--yes", action="store_true", help="Skip confirmation prompts (useful for CI/scripts).")
|
||||
parser.add_argument(
|
||||
"--subagent",
|
||||
action="store_true",
|
||||
help="Subagent mode: skip all interactive prompts with sensible defaults. Designed for AI agent integrations.",
|
||||
)
|
||||
|
||||
args, unknown_args = parser.parse_known_args()
|
||||
sys.argv[:] = [sys.argv[0], *unknown_args]
|
||||
if args.subagent:
|
||||
args.yes = True
|
||||
args.no_pr = True
|
||||
args.worktree = True
|
||||
return process_and_validate_cmd_args(args)
|
||||
|
||||
|
||||
|
|
@ -352,32 +361,52 @@ def _handle_show_config() -> None:
|
|||
from codeflash.setup.detector import detect_project, has_existing_config
|
||||
|
||||
project_root = Path.cwd()
|
||||
detected = detect_project(project_root)
|
||||
config_exists, _ = has_existing_config(project_root)
|
||||
|
||||
# Check if config exists or is auto-detected
|
||||
config_exists, config_file = has_existing_config(project_root)
|
||||
status = "Saved config" if config_exists else "Auto-detected (not saved)"
|
||||
if config_exists:
|
||||
from codeflash.code_utils.config_parser import parse_config_file
|
||||
|
||||
console.print()
|
||||
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
|
||||
if config_exists and config_file:
|
||||
console.print(f"[dim]Config file: {project_root / config_file}[/dim]")
|
||||
console.print()
|
||||
config, config_file_path = parse_config_file()
|
||||
status = "Saved config"
|
||||
|
||||
table = Table(show_header=True, header_style="bold cyan")
|
||||
table.add_column("Setting", style="dim")
|
||||
table.add_column("Value")
|
||||
console.print()
|
||||
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
|
||||
console.print(f"[dim]Config file: {config_file_path}[/dim]")
|
||||
console.print()
|
||||
|
||||
table.add_row("Language", detected.language)
|
||||
table.add_row("Project root", str(detected.project_root))
|
||||
table.add_row("Module root", str(detected.module_root))
|
||||
table.add_row("Tests root", str(detected.tests_root) if detected.tests_root else "(not detected)")
|
||||
table.add_row("Test runner", detected.test_runner or "(not detected)")
|
||||
table.add_row("Formatter", ", ".join(detected.formatter_cmds) if detected.formatter_cmds else "(not detected)")
|
||||
table.add_row(
|
||||
"Ignore paths", ", ".join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else "(none)"
|
||||
)
|
||||
table.add_row("Confidence", f"{detected.confidence:.0%}")
|
||||
table = Table(show_header=True, header_style="bold cyan")
|
||||
table.add_column("Setting", style="dim")
|
||||
table.add_column("Value")
|
||||
|
||||
table.add_row("Project root", str(project_root))
|
||||
table.add_row("Module root", config.get("module_root", "(not set)"))
|
||||
table.add_row("Tests root", config.get("tests_root", "(not set)"))
|
||||
table.add_row("Test runner", config.get("test_framework", config.get("pytest_cmd", "(not set)")))
|
||||
table.add_row("Formatter", ", ".join(config["formatter_cmds"]) if config.get("formatter_cmds") else "(not set)")
|
||||
ignore_paths = config.get("ignore_paths", [])
|
||||
table.add_row("Ignore paths", ", ".join(str(p) for p in ignore_paths) if ignore_paths else "(none)")
|
||||
else:
|
||||
detected = detect_project(project_root)
|
||||
status = "Auto-detected (not saved)"
|
||||
|
||||
console.print()
|
||||
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
|
||||
console.print()
|
||||
|
||||
table = Table(show_header=True, header_style="bold cyan")
|
||||
table.add_column("Setting", style="dim")
|
||||
table.add_column("Value")
|
||||
|
||||
table.add_row("Language", detected.language)
|
||||
table.add_row("Project root", str(detected.project_root))
|
||||
table.add_row("Module root", str(detected.module_root))
|
||||
table.add_row("Tests root", str(detected.tests_root) if detected.tests_root else "(not detected)")
|
||||
table.add_row("Test runner", detected.test_runner or "(not detected)")
|
||||
table.add_row("Formatter", ", ".join(detected.formatter_cmds) if detected.formatter_cmds else "(not detected)")
|
||||
table.add_row(
|
||||
"Ignore paths", ", ".join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else "(none)"
|
||||
)
|
||||
table.add_row("Confidence", f"{detected.confidence:.0%}")
|
||||
|
||||
console.print(table)
|
||||
console.print()
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from itertools import cycle
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
from rich.panel import Panel
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
MofNCompleteColumn,
|
||||
|
|
@ -19,43 +21,73 @@ from rich.progress import (
|
|||
|
||||
from codeflash.cli_cmds.console_constants import SPINNER_TYPES
|
||||
from codeflash.cli_cmds.logging_config import BARE_LOGGING_FORMAT
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.lsp.helpers import is_LSP_enabled, is_subagent_mode
|
||||
from codeflash.lsp.lsp_logger import enhanced_log
|
||||
from codeflash.lsp.lsp_message import LspCodeMessage, LspTextMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Callable, Generator
|
||||
from pathlib import Path
|
||||
|
||||
from rich.progress import TaskID
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import DependencyResolver, IndexResult
|
||||
from codeflash.lsp.lsp_message import LspMessage
|
||||
from codeflash.models.models import TestResults
|
||||
|
||||
DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG
|
||||
|
||||
console = Console()
|
||||
|
||||
if is_LSP_enabled():
|
||||
if is_LSP_enabled() or is_subagent_mode():
|
||||
console.quiet = True
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)],
|
||||
format=BARE_LOGGING_FORMAT,
|
||||
)
|
||||
if is_subagent_mode():
|
||||
import re
|
||||
import sys
|
||||
|
||||
_lsp_prefix_re = re.compile(r"^(?:!?lsp,?|h[2-4]|loading)\|")
|
||||
_subagent_drop_patterns = (
|
||||
"Test log -",
|
||||
"Test failed to load",
|
||||
"Examining file ",
|
||||
"Generated ",
|
||||
"Add custom marker",
|
||||
"Disabling all autouse",
|
||||
"Reverting code and helpers",
|
||||
)
|
||||
|
||||
class _AgentLogFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
record.msg = _lsp_prefix_re.sub("", str(record.msg))
|
||||
msg = record.getMessage()
|
||||
return not any(msg.startswith(p) for p in _subagent_drop_patterns)
|
||||
|
||||
_agent_handler = logging.StreamHandler(sys.stderr)
|
||||
_agent_handler.addFilter(_AgentLogFilter())
|
||||
logging.basicConfig(level=logging.INFO, handlers=[_agent_handler], format="%(levelname)s: %(message)s")
|
||||
else:
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)],
|
||||
format=BARE_LOGGING_FORMAT,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("rich")
|
||||
logging.getLogger("parso").setLevel(logging.WARNING)
|
||||
|
||||
# override the logger to reformat the messages for the lsp
|
||||
for level in ("info", "debug", "warning", "error"):
|
||||
real_fn = getattr(logger, level)
|
||||
setattr(
|
||||
logger,
|
||||
level,
|
||||
lambda msg, *args, _real_fn=real_fn, _level=level, **kwargs: enhanced_log(
|
||||
msg, _real_fn, _level, *args, **kwargs
|
||||
),
|
||||
)
|
||||
if not is_subagent_mode():
|
||||
for level in ("info", "debug", "warning", "error"):
|
||||
real_fn = getattr(logger, level)
|
||||
setattr(
|
||||
logger,
|
||||
level,
|
||||
lambda msg, *args, _real_fn=real_fn, _level=level, **kwargs: enhanced_log(
|
||||
msg, _real_fn, _level, *args, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DummyTask:
|
||||
|
|
@ -82,6 +114,8 @@ def paneled_text(
|
|||
text: str, panel_args: dict[str, str | bool] | None = None, text_args: dict[str, str] | None = None
|
||||
) -> None:
|
||||
"""Print text in a panel."""
|
||||
if is_subagent_mode():
|
||||
return
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
|
|
@ -110,6 +144,8 @@ def code_print(
|
|||
language: Programming language for syntax highlighting ('python', 'javascript', 'typescript')
|
||||
|
||||
"""
|
||||
if is_subagent_mode():
|
||||
return
|
||||
if is_LSP_enabled():
|
||||
lsp_log(
|
||||
LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name, message_id=lsp_message_id)
|
||||
|
|
@ -147,6 +183,10 @@ def progress_bar(
|
|||
"""
|
||||
global _progress_bar_active
|
||||
|
||||
if is_subagent_mode():
|
||||
yield DummyTask().id
|
||||
return
|
||||
|
||||
if is_LSP_enabled():
|
||||
lsp_log(LspTextMessage(text=message, takes_time=True))
|
||||
yield
|
||||
|
|
@ -178,6 +218,10 @@ def progress_bar(
|
|||
@contextmanager
|
||||
def test_files_progress_bar(total: int, description: str) -> Generator[tuple[Progress, TaskID], None, None]:
|
||||
"""Progress bar for test files."""
|
||||
if is_subagent_mode():
|
||||
yield DummyProgress(), DummyTask().id
|
||||
return
|
||||
|
||||
if is_LSP_enabled():
|
||||
lsp_log(LspTextMessage(text=description, takes_time=True))
|
||||
dummy_progress = DummyProgress()
|
||||
|
|
@ -196,3 +240,242 @@ def test_files_progress_bar(total: int, description: str) -> Generator[tuple[Pro
|
|||
) as progress:
|
||||
task_id = progress.add_task(description, total=total)
|
||||
yield progress, task_id
|
||||
|
||||
|
||||
MAX_TREE_ENTRIES = 8
|
||||
|
||||
|
||||
@contextmanager
|
||||
def call_graph_live_display(
|
||||
total: int, project_root: Path | None = None
|
||||
) -> Generator[Callable[[IndexResult], None], None, None]:
|
||||
from rich.console import Group
|
||||
from rich.live import Live
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
from rich.tree import Tree
|
||||
|
||||
if is_subagent_mode():
|
||||
yield lambda _: None
|
||||
return
|
||||
|
||||
if is_LSP_enabled():
|
||||
lsp_log(LspTextMessage(text="Building call graph", takes_time=True))
|
||||
yield lambda _: None
|
||||
return
|
||||
|
||||
progress = Progress(
|
||||
SpinnerColumn(next(spinners)),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(complete_style="cyan", finished_style="green", pulse_style="yellow"),
|
||||
MofNCompleteColumn(),
|
||||
TimeElapsedColumn(),
|
||||
TimeRemainingColumn(),
|
||||
auto_refresh=False,
|
||||
)
|
||||
task_id = progress.add_task("Analyzing files", total=total)
|
||||
|
||||
results: deque[IndexResult] = deque(maxlen=MAX_TREE_ENTRIES)
|
||||
stats = {"indexed": 0, "cached": 0, "edges": 0, "external": 0, "errors": 0}
|
||||
|
||||
tree = Tree("[bold]Recent Files[/bold]")
|
||||
stats_text = Text("0 calls found", style="dim")
|
||||
panel = Panel(
|
||||
Group(progress, Text(""), tree, Text(""), stats_text), title="Building Call Graph", border_style="cyan"
|
||||
)
|
||||
|
||||
def create_tree_node(result: IndexResult) -> Tree:
|
||||
if project_root:
|
||||
try:
|
||||
name = str(result.file_path.resolve().relative_to(project_root.resolve()))
|
||||
except ValueError:
|
||||
name = f"{result.file_path.parent.name}/{result.file_path.name}"
|
||||
else:
|
||||
name = f"{result.file_path.parent.name}/{result.file_path.name}"
|
||||
|
||||
if result.error:
|
||||
return Tree(f"[red]{name} (error)[/red]")
|
||||
|
||||
if result.cached:
|
||||
return Tree(f"[dim]{name} (cached)[/dim]")
|
||||
|
||||
local_edges = result.num_edges - result.cross_file_edges
|
||||
edge_info = []
|
||||
|
||||
if local_edges:
|
||||
edge_info.append(f"{local_edges} calls in same file")
|
||||
if result.cross_file_edges:
|
||||
edge_info.append(f"{result.cross_file_edges} calls from other modules")
|
||||
|
||||
label = ", ".join(edge_info) if edge_info else "no calls"
|
||||
return Tree(f"[cyan]{name}[/cyan] [dim]{label}[/dim]")
|
||||
|
||||
def refresh_display() -> None:
|
||||
tree.children = [create_tree_node(r) for r in results]
|
||||
tree.children.extend([Tree(" ")] * (MAX_TREE_ENTRIES - len(results)))
|
||||
|
||||
# Update stats
|
||||
stat_parts = []
|
||||
if stats["indexed"]:
|
||||
stat_parts.append(f"{stats['indexed']} files analyzed")
|
||||
if stats["cached"]:
|
||||
stat_parts.append(f"{stats['cached']} cached")
|
||||
if stats["errors"]:
|
||||
stat_parts.append(f"{stats['errors']} errors")
|
||||
stat_parts.append(f"{stats['edges']} calls found")
|
||||
if stats["external"]:
|
||||
stat_parts.append(f"{stats['external']} cross-file calls")
|
||||
|
||||
stats_text.truncate(0)
|
||||
stats_text.append(" · ".join(stat_parts), style="dim")
|
||||
|
||||
batch: list[IndexResult] = []
|
||||
|
||||
def process_batch() -> None:
|
||||
for result in batch:
|
||||
results.append(result)
|
||||
|
||||
if result.error:
|
||||
stats["errors"] += 1
|
||||
elif result.cached:
|
||||
stats["cached"] += 1
|
||||
else:
|
||||
stats["indexed"] += 1
|
||||
stats["edges"] += result.num_edges
|
||||
stats["external"] += result.cross_file_edges
|
||||
|
||||
progress.advance(task_id)
|
||||
|
||||
batch.clear()
|
||||
refresh_display()
|
||||
live.refresh()
|
||||
|
||||
def update(result: IndexResult) -> None:
|
||||
batch.append(result)
|
||||
if len(batch) >= 8:
|
||||
process_batch()
|
||||
|
||||
with Live(panel, console=console, transient=False, auto_refresh=False) as live:
|
||||
yield update
|
||||
if batch:
|
||||
process_batch()
|
||||
|
||||
|
||||
def call_graph_summary(call_graph: DependencyResolver, file_to_funcs: dict[Path, list[FunctionToOptimize]]) -> None:
|
||||
total_functions = sum(map(len, file_to_funcs.values()))
|
||||
if not total_functions:
|
||||
return
|
||||
|
||||
if is_subagent_mode():
|
||||
return
|
||||
|
||||
# Build the mapping expected by the dependency resolver
|
||||
file_items = file_to_funcs.items()
|
||||
mapping = {file_path: {func.qualified_name for func in funcs} for file_path, funcs in file_items}
|
||||
|
||||
callee_counts = call_graph.count_callees_per_function(mapping)
|
||||
|
||||
# Use built-in sum for C-level loops to reduce Python overhead
|
||||
total_callees = sum(callee_counts.values())
|
||||
with_context = sum(1 for count in callee_counts.values() if count > 0)
|
||||
|
||||
leaf_functions = total_functions - with_context
|
||||
avg_callees = total_callees / total_functions
|
||||
|
||||
function_label = "function" if total_functions == 1 else "functions"
|
||||
|
||||
summary = (
|
||||
f"{total_functions} {function_label} ready for optimization\n"
|
||||
f"Uses other functions: {with_context} · "
|
||||
f"Standalone: {leaf_functions}"
|
||||
)
|
||||
|
||||
if is_LSP_enabled():
|
||||
lsp_log(LspTextMessage(text=summary))
|
||||
return
|
||||
|
||||
console.print(Panel(summary, title="Call Graph Summary", border_style="cyan"))
|
||||
|
||||
|
||||
def subagent_log_optimization_result(
|
||||
function_name: str,
|
||||
file_path: Path,
|
||||
perf_improvement_line: str,
|
||||
original_runtime_ns: int,
|
||||
best_runtime_ns: int,
|
||||
raw_explanation: str,
|
||||
original_code: dict[Path, str],
|
||||
new_code: dict[Path, str],
|
||||
review: str,
|
||||
test_results: TestResults,
|
||||
) -> None:
|
||||
import sys
|
||||
from xml.sax.saxutils import escape
|
||||
|
||||
from codeflash.code_utils.code_utils import unified_diff_strings
|
||||
from codeflash.code_utils.time_utils import humanize_runtime
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
diff_parts = []
|
||||
for path in original_code:
|
||||
old = original_code.get(path, "")
|
||||
new = new_code.get(path, "")
|
||||
if old != new:
|
||||
diff = unified_diff_strings(old, new, fromfile=str(path), tofile=str(path))
|
||||
if diff:
|
||||
diff_parts.append(diff)
|
||||
|
||||
diff_str = "\n".join(diff_parts)
|
||||
|
||||
original_runtime = humanize_runtime(original_runtime_ns)
|
||||
optimized_runtime = humanize_runtime(best_runtime_ns)
|
||||
|
||||
report = test_results.get_test_pass_fail_report_by_type()
|
||||
verification_rows = []
|
||||
for test_type in TestType:
|
||||
if test_type is TestType.INIT_STATE_TEST:
|
||||
continue
|
||||
name = test_type.to_name()
|
||||
if not name:
|
||||
continue
|
||||
passed = report[test_type]["passed"]
|
||||
failed = report[test_type]["failed"]
|
||||
if passed == 0 and failed == 0:
|
||||
status = "None Found"
|
||||
elif failed > 0:
|
||||
status = f"{failed} Failed, {passed} Passed"
|
||||
else:
|
||||
status = f"{passed} Passed"
|
||||
verification_rows.append(f' <test type="{escape(name)}" status="{escape(status)}"/>')
|
||||
|
||||
xml = [
|
||||
"<codeflash-optimization>",
|
||||
f" <function>{escape(function_name)}</function>",
|
||||
f" <file>{escape(str(file_path))}</file>",
|
||||
f" <performance>{escape(perf_improvement_line)}</performance>",
|
||||
f" <original-runtime>{escape(original_runtime)}</original-runtime>",
|
||||
f" <optimized-runtime>{escape(optimized_runtime)}</optimized-runtime>",
|
||||
]
|
||||
if review:
|
||||
xml.append(f" <reviewer-confidence>{escape(review)}</reviewer-confidence>")
|
||||
xml.append(f" <explanation>{escape(raw_explanation.strip())}</explanation>")
|
||||
xml.append(" <verification>")
|
||||
xml.extend(verification_rows)
|
||||
xml.append(" </verification>")
|
||||
if diff_str:
|
||||
xml.append(f" <diff>{escape(diff_str)}</diff>")
|
||||
for path in new_code:
|
||||
if new_code[path] != original_code.get(path, ""):
|
||||
xml.append(f' <optimized-code file="{escape(str(path))}">{escape(new_code[path])}</optimized-code>')
|
||||
xml.append(" <action>")
|
||||
xml.append(" Present this optimization to the user using AskUserQuestion.")
|
||||
xml.append(" Show the function name, performance improvement, and explanation as the question text.")
|
||||
xml.append(" Provide two options: 'Apply' (write the optimized code to the file) and 'Reject' (do nothing).")
|
||||
xml.append(
|
||||
" Put the full diff in the 'Apply' option's markdown preview so the user can review the exact changes."
|
||||
)
|
||||
xml.append(" If the user chooses 'Apply', write the content from optimized-code to the corresponding file.")
|
||||
xml.append(" </action>")
|
||||
xml.append("</codeflash-optimization>")
|
||||
|
||||
sys.stdout.write("\n".join(xml) + "\n")
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ def determine_js_package_manager(project_root: Path) -> JsPackageManager:
|
|||
"""
|
||||
# Search from project_root up to filesystem root for lock files
|
||||
# This supports monorepo setups where lock file is at workspace root
|
||||
current_dir = project_root.resolve()
|
||||
current_dir = project_root
|
||||
while current_dir != current_dir.parent:
|
||||
if (current_dir / "bun.lockb").exists() or (current_dir / "bun.lock").exists():
|
||||
return JsPackageManager.BUN
|
||||
|
|
@ -161,7 +161,7 @@ def find_node_modules_with_package(project_root: Path, package_name: str) -> Pat
|
|||
Path to the node_modules directory containing the package, or None if not found.
|
||||
|
||||
"""
|
||||
current_dir = project_root.resolve()
|
||||
current_dir = project_root
|
||||
while current_dir != current_dir.parent:
|
||||
node_modules = current_dir / "node_modules"
|
||||
if node_modules.exists():
|
||||
|
|
|
|||
|
|
@ -5,8 +5,18 @@ BARE_LOGGING_FORMAT = "%(message)s"
|
|||
|
||||
def set_level(level: int, *, echo_setting: bool = True) -> None:
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
|
||||
from codeflash.lsp.helpers import is_subagent_mode
|
||||
|
||||
if is_subagent_mode():
|
||||
logging.basicConfig(
|
||||
level=level, handlers=[logging.StreamHandler(sys.stderr)], format="%(levelname)s: %(message)s", force=True
|
||||
)
|
||||
logging.getLogger().setLevel(level)
|
||||
return
|
||||
|
||||
from rich.logging import RichHandler
|
||||
|
||||
from codeflash.cli_cmds.console import console
|
||||
|
|
|
|||
|
|
@ -141,12 +141,18 @@ def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dic
|
|||
|
||||
def ask_should_use_checkpoint_get_functions(args: argparse.Namespace) -> Optional[dict[str, dict[str, str]]]:
|
||||
previous_checkpoint_functions = None
|
||||
if getattr(args, "subagent", False):
|
||||
console.rule()
|
||||
return None
|
||||
if args.all and codeflash_temp_dir.is_dir():
|
||||
previous_checkpoint_functions = get_all_historical_functions(args.module_root, codeflash_temp_dir)
|
||||
if previous_checkpoint_functions and Confirm.ask(
|
||||
"Previous Checkpoint detected from an incomplete optimization run, shall I continue the optimization from that point?",
|
||||
default=True,
|
||||
console=console,
|
||||
if previous_checkpoint_functions and (
|
||||
getattr(args, "yes", False)
|
||||
or Confirm.ask(
|
||||
"Previous Checkpoint detected from an incomplete optimization run, shall I continue the optimization from that point?",
|
||||
default=True,
|
||||
console=console,
|
||||
)
|
||||
):
|
||||
console.rule()
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -2,46 +2,16 @@ import os
|
|||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from platformdirs import user_config_dir
|
||||
|
||||
if TYPE_CHECKING:
|
||||
codeflash_temp_dir: Path
|
||||
codeflash_cache_dir: Path
|
||||
codeflash_cache_db: Path
|
||||
LF: str = os.linesep
|
||||
IS_POSIX: bool = os.name != "nt"
|
||||
SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix()
|
||||
|
||||
codeflash_cache_dir: Path = Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True))
|
||||
|
||||
class Compat:
|
||||
# os-independent newline
|
||||
LF: str = os.linesep
|
||||
codeflash_temp_dir: Path = Path(tempfile.gettempdir()) / "codeflash"
|
||||
codeflash_temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix()
|
||||
|
||||
IS_POSIX: bool = os.name != "nt"
|
||||
|
||||
@property
|
||||
def codeflash_cache_dir(self) -> Path:
|
||||
return Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True))
|
||||
|
||||
@property
|
||||
def codeflash_temp_dir(self) -> Path:
|
||||
temp_dir = Path(tempfile.gettempdir()) / "codeflash"
|
||||
if not temp_dir.exists():
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
return temp_dir
|
||||
|
||||
@property
|
||||
def codeflash_cache_db(self) -> Path:
|
||||
return self.codeflash_cache_dir / "codeflash_cache.db"
|
||||
|
||||
|
||||
_compat = Compat()
|
||||
|
||||
|
||||
codeflash_temp_dir = _compat.codeflash_temp_dir
|
||||
codeflash_cache_dir = _compat.codeflash_cache_dir
|
||||
codeflash_cache_db = _compat.codeflash_cache_db
|
||||
LF = _compat.LF
|
||||
SAFE_SYS_EXECUTABLE = _compat.SAFE_SYS_EXECUTABLE
|
||||
IS_POSIX = _compat.IS_POSIX
|
||||
codeflash_cache_db: Path = codeflash_cache_dir / "codeflash_cache.db"
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ from enum import Enum
|
|||
from typing import Any, Union
|
||||
|
||||
MAX_TEST_RUN_ITERATIONS = 5
|
||||
OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 48000
|
||||
TESTGEN_CONTEXT_TOKEN_LIMIT = 48000
|
||||
OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 64000
|
||||
TESTGEN_CONTEXT_TOKEN_LIMIT = 64000
|
||||
INDIVIDUAL_TESTCASE_TIMEOUT = 15
|
||||
MAX_FUNCTION_TEST_SECONDS = 60
|
||||
MIN_IMPROVEMENT_THRESHOLD = 0.05
|
||||
|
|
|
|||
|
|
@ -709,6 +709,7 @@ def inject_profiling_into_existing_test(
|
|||
tests_project_root: Path,
|
||||
mode: TestingMode = TestingMode.BEHAVIOR,
|
||||
) -> tuple[bool, str | None]:
|
||||
tests_project_root = tests_project_root.resolve()
|
||||
if function_to_optimize.is_async:
|
||||
return inject_async_profiling_into_existing_test(
|
||||
test_path, call_positions, function_to_optimize, tests_project_root, mode
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from codeflash.result.critic import performance_gain
|
||||
|
||||
|
||||
def humanize_runtime(time_in_ns: int) -> str:
|
||||
runtime_human: str = str(time_in_ns)
|
||||
|
|
@ -89,3 +91,13 @@ def format_perf(percentage: float) -> str:
|
|||
if abs_perc >= 1:
|
||||
return f"{percentage:.2f}"
|
||||
return f"{percentage:.3f}"
|
||||
|
||||
|
||||
def format_runtime_comment(original_time_ns: int, optimized_time_ns: int, comment_prefix: str = "#") -> str:
|
||||
perf_gain = format_perf(
|
||||
abs(performance_gain(original_runtime_ns=original_time_ns, optimized_runtime_ns=optimized_time_ns) * 100)
|
||||
)
|
||||
status = "slower" if optimized_time_ns > original_time_ns else "faster"
|
||||
return (
|
||||
f"{comment_prefix} {format_time(original_time_ns)} -> {format_time(optimized_time_ns)} ({perf_gain}% {status})"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -69,8 +69,8 @@ FUNCTION_NAME_REGEX = re.compile(r"([^.]+)\.([a-zA-Z0-9_]+)$")
|
|||
class TestsCache:
|
||||
SCHEMA_VERSION = 1 # Increment this when schema changes
|
||||
|
||||
def __init__(self, project_root_path: str | Path) -> None:
|
||||
self.project_root_path = Path(project_root_path).resolve().as_posix()
|
||||
def __init__(self, project_root_path: Path) -> None:
|
||||
self.project_root_path = project_root_path.resolve().as_posix()
|
||||
self.connection = sqlite3.connect(codeflash_cache_db)
|
||||
self.cur = self.connection.cursor()
|
||||
|
||||
|
|
@ -728,6 +728,10 @@ def discover_tests_pytest(
|
|||
logger.debug(f"Pytest collection exit code: {exitcode}")
|
||||
if pytest_rootdir is not None:
|
||||
cfg.tests_project_rootdir = Path(pytest_rootdir)
|
||||
if discover_only_these_tests:
|
||||
resolved_discover_only = {p.resolve() for p in discover_only_these_tests}
|
||||
else:
|
||||
resolved_discover_only = None
|
||||
file_to_test_map: dict[Path, list[FunctionCalledInTest]] = defaultdict(list)
|
||||
for test in tests:
|
||||
if "__replay_test" in test["test_file"]:
|
||||
|
|
@ -737,13 +741,14 @@ def discover_tests_pytest(
|
|||
else:
|
||||
test_type = TestType.EXISTING_UNIT_TEST
|
||||
|
||||
test_file_path = Path(test["test_file"]).resolve()
|
||||
test_obj = TestsInFile(
|
||||
test_file=Path(test["test_file"]),
|
||||
test_file=test_file_path,
|
||||
test_class=test["test_class"],
|
||||
test_function=test["test_function"],
|
||||
test_type=test_type,
|
||||
)
|
||||
if discover_only_these_tests and test_obj.test_file not in discover_only_these_tests:
|
||||
if resolved_discover_only and test_obj.test_file not in resolved_discover_only:
|
||||
continue
|
||||
file_to_test_map[test_obj.test_file].append(test_obj)
|
||||
# Within these test files, find the project functions they are referring to and return their names/locations
|
||||
|
|
|
|||
|
|
@ -114,38 +114,57 @@ class FunctionVisitor(cst.CSTVisitor):
|
|||
)
|
||||
|
||||
|
||||
class FunctionWithReturnStatement(ast.NodeVisitor):
|
||||
def __init__(self, file_path: Path) -> None:
|
||||
self.functions: list[FunctionToOptimize] = []
|
||||
self.ast_path: list[FunctionParent] = []
|
||||
self.file_path: Path = file_path
|
||||
|
||||
def visit_FunctionDef(self, node: FunctionDef) -> None:
|
||||
if function_has_return_statement(node) and not function_is_a_property(node):
|
||||
self.functions.append(
|
||||
FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
|
||||
)
|
||||
|
||||
def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
|
||||
if function_has_return_statement(node) and not function_is_a_property(node):
|
||||
self.functions.append(
|
||||
FunctionToOptimize(
|
||||
function_name=node.name, file_path=self.file_path, parents=self.ast_path[:], is_async=True
|
||||
def find_functions_with_return_statement(ast_module: ast.Module, file_path: Path) -> list[FunctionToOptimize]:
|
||||
results: list[FunctionToOptimize] = []
|
||||
# (node, parent_path) — iterative DFS avoids RecursionError on deeply nested ASTs
|
||||
stack: list[tuple[ast.AST, list[FunctionParent]]] = [(ast_module, [])]
|
||||
while stack:
|
||||
node, ast_path = stack.pop()
|
||||
if isinstance(node, (FunctionDef, AsyncFunctionDef)):
|
||||
if function_has_return_statement(node) and not function_is_a_property(node):
|
||||
results.append(
|
||||
FunctionToOptimize(
|
||||
function_name=node.name,
|
||||
file_path=file_path,
|
||||
parents=ast_path[:],
|
||||
is_async=isinstance(node, AsyncFunctionDef),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def generic_visit(self, node: ast.AST) -> None:
|
||||
if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):
|
||||
self.ast_path.append(FunctionParent(node.name, node.__class__.__name__))
|
||||
super().generic_visit(node)
|
||||
if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):
|
||||
self.ast_path.pop()
|
||||
# Don't recurse into function bodies (matches original visitor behaviour)
|
||||
continue
|
||||
child_path = (
|
||||
[*ast_path, FunctionParent(node.name, node.__class__.__name__)] if isinstance(node, ClassDef) else ast_path
|
||||
)
|
||||
for child in reversed(list(ast.iter_child_nodes(node))):
|
||||
stack.append((child, child_path))
|
||||
return results
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Multi-language support helpers
|
||||
# =============================================================================
|
||||
|
||||
_VCS_EXCLUDES = frozenset({".git", ".hg", ".svn"})
|
||||
|
||||
|
||||
def parse_dir_excludes(patterns: frozenset[str]) -> tuple[frozenset[str], tuple[str, ...], tuple[str, ...]]:
|
||||
"""Split glob patterns into exact names, prefixes, and suffixes.
|
||||
|
||||
Patterns ending with ``*`` become prefix matches, patterns starting with ``*``
|
||||
become suffix matches, and plain strings become exact matches.
|
||||
"""
|
||||
exact: set[str] = set()
|
||||
prefixes: list[str] = []
|
||||
suffixes: list[str] = []
|
||||
for p in patterns:
|
||||
if p.endswith("*"):
|
||||
prefixes.append(p[:-1])
|
||||
elif p.startswith("*"):
|
||||
suffixes.append(p[1:])
|
||||
else:
|
||||
exact.add(p)
|
||||
return frozenset(exact), tuple(prefixes), tuple(suffixes)
|
||||
|
||||
|
||||
def get_files_for_language(
|
||||
module_root_path: Path, ignore_paths: list[Path] | None = None, language: Language | None = None
|
||||
|
|
@ -164,37 +183,44 @@ def get_files_for_language(
|
|||
if ignore_paths is None:
|
||||
ignore_paths = []
|
||||
|
||||
all_patterns: frozenset[str]
|
||||
if language is not None:
|
||||
support = get_language_support(language)
|
||||
extensions = support.file_extensions
|
||||
all_patterns = support.dir_excludes | _VCS_EXCLUDES
|
||||
else:
|
||||
extensions = tuple(get_supported_extensions())
|
||||
all_patterns = _VCS_EXCLUDES
|
||||
for lang in Language:
|
||||
if is_language_supported(lang):
|
||||
all_patterns = all_patterns | get_language_support(lang).dir_excludes
|
||||
|
||||
# Default directory patterns to always exclude for JS/TS
|
||||
js_ts_default_excludes = {
|
||||
"node_modules",
|
||||
"dist",
|
||||
"build",
|
||||
".next",
|
||||
".nuxt",
|
||||
"coverage",
|
||||
".cache",
|
||||
".turbo",
|
||||
".vercel",
|
||||
"__pycache__",
|
||||
}
|
||||
dir_excludes, prefixes, suffixes = parse_dir_excludes(all_patterns)
|
||||
|
||||
files = []
|
||||
for ext in extensions:
|
||||
pattern = f"*{ext}"
|
||||
for file_path in module_root_path.rglob(pattern):
|
||||
# Check explicit ignore paths
|
||||
if any(file_path.is_relative_to(ignore_path) for ignore_path in ignore_paths):
|
||||
continue
|
||||
# Check default JS/TS excludes in path parts
|
||||
if any(part in js_ts_default_excludes for part in file_path.parts):
|
||||
continue
|
||||
files.append(file_path)
|
||||
ignore_dirs: set[str] = set()
|
||||
ignore_files: set[Path] = set()
|
||||
for p in ignore_paths:
|
||||
p = Path(p) if not isinstance(p, Path) else p
|
||||
if p.is_file():
|
||||
ignore_files.add(p)
|
||||
else:
|
||||
ignore_dirs.add(str(p))
|
||||
|
||||
files: list[Path] = []
|
||||
for dirpath, dirnames, filenames in os.walk(module_root_path):
|
||||
dirnames[:] = [
|
||||
d
|
||||
for d in dirnames
|
||||
if d not in dir_excludes
|
||||
and not (prefixes and d.startswith(prefixes))
|
||||
and not (suffixes and d.endswith(suffixes))
|
||||
and str(Path(dirpath) / d) not in ignore_dirs
|
||||
]
|
||||
for fname in filenames:
|
||||
if fname.endswith(extensions):
|
||||
fpath = Path(dirpath, fname)
|
||||
if fpath not in ignore_files:
|
||||
files.append(fpath)
|
||||
return files
|
||||
|
||||
|
||||
|
|
@ -237,9 +263,7 @@ def _find_all_functions_in_python_file(file_path: Path) -> dict[Path, list[Funct
|
|||
if DEBUG_MODE:
|
||||
logger.exception(e)
|
||||
return functions
|
||||
function_name_visitor = FunctionWithReturnStatement(file_path)
|
||||
function_name_visitor.visit(ast_module)
|
||||
functions[file_path] = function_name_visitor.functions
|
||||
functions[file_path] = find_functions_with_return_statement(ast_module, file_path)
|
||||
return functions
|
||||
|
||||
|
||||
|
|
@ -808,6 +832,7 @@ def filter_functions(
|
|||
*,
|
||||
disable_logs: bool = False,
|
||||
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
|
||||
resolved_project_root = project_root.resolve()
|
||||
filtered_modified_functions: dict[str, list[FunctionToOptimize]] = {}
|
||||
blocklist_funcs = get_blocklisted_functions()
|
||||
logger.debug(f"Blocklisted functions: {blocklist_funcs}")
|
||||
|
|
@ -884,7 +909,7 @@ def filter_functions(
|
|||
lang_support = get_language_support(Path(file_path))
|
||||
if lang_support.language == Language.PYTHON:
|
||||
try:
|
||||
ast.parse(f"import {module_name_from_file_path(Path(file_path), project_root)}")
|
||||
ast.parse(f"import {module_name_from_file_path(Path(file_path), resolved_project_root)}")
|
||||
except SyntaxError:
|
||||
malformed_paths_count += 1
|
||||
continue
|
||||
|
|
@ -906,7 +931,10 @@ def filter_functions(
|
|||
if previous_checkpoint_functions:
|
||||
functions_tmp = []
|
||||
for function in _functions:
|
||||
if function.qualified_name_with_modules_from_root(project_root) in previous_checkpoint_functions:
|
||||
if (
|
||||
function.qualified_name_with_modules_from_root(resolved_project_root)
|
||||
in previous_checkpoint_functions
|
||||
):
|
||||
previous_checkpoint_functions_removed_count += 1
|
||||
continue
|
||||
functions_tmp.append(function)
|
||||
|
|
@ -960,12 +988,21 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
|
|||
|
||||
def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
|
||||
# Custom DFS, return True as soon as a Return node is found
|
||||
stack: list[ast.AST] = [function_node]
|
||||
stack: list[ast.AST] = list(function_node.body)
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
if isinstance(node, ast.Return):
|
||||
return True
|
||||
stack.extend(ast.iter_child_nodes(node))
|
||||
# Only push child nodes that are statements; Return nodes are statements,
|
||||
# so this preserves correctness while avoiding unnecessary traversal into expr/Name/etc.
|
||||
for field in getattr(node, "_fields", ()):
|
||||
child = getattr(node, field, None)
|
||||
if isinstance(child, list):
|
||||
for item in child:
|
||||
if isinstance(item, ast.stmt):
|
||||
stack.append(item)
|
||||
elif isinstance(child, ast.stmt):
|
||||
stack.append(child)
|
||||
return False
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,9 @@ Usage:
|
|||
|
||||
from codeflash.languages.base import (
|
||||
CodeContext,
|
||||
DependencyResolver,
|
||||
HelperFunction,
|
||||
IndexResult,
|
||||
Language,
|
||||
LanguageSupport,
|
||||
ParentInfo,
|
||||
|
|
@ -82,8 +84,10 @@ def __getattr__(name: str):
|
|||
|
||||
__all__ = [
|
||||
"CodeContext",
|
||||
"DependencyResolver",
|
||||
"FunctionInfo",
|
||||
"HelperFunction",
|
||||
"IndexResult",
|
||||
"Language",
|
||||
"LanguageSupport",
|
||||
"ParentInfo",
|
||||
|
|
|
|||
|
|
@ -11,10 +11,11 @@ from dataclasses import dataclass, field
|
|||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId
|
||||
|
||||
from codeflash.languages.language_enum import Language
|
||||
from codeflash.models.function_types import FunctionParent
|
||||
|
|
@ -34,6 +35,16 @@ def __getattr__(name: str) -> Any:
|
|||
raise AttributeError(msg)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IndexResult:
|
||||
file_path: Path
|
||||
cached: bool
|
||||
num_edges: int
|
||||
edges: tuple[tuple[str, str, bool], ...] # (caller_qn, callee_name, is_cross_file)
|
||||
cross_file_edges: int
|
||||
error: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class HelperFunction:
|
||||
"""A helper function that is a dependency of the target function.
|
||||
|
|
@ -192,6 +203,35 @@ class ReferenceInfo:
|
|||
caller_function: str | None = None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class DependencyResolver(Protocol):
|
||||
"""Protocol for language-specific dependency resolution.
|
||||
|
||||
Implementations analyze source files to discover call-graph edges
|
||||
between functions so the optimizer can extract richer context.
|
||||
"""
|
||||
|
||||
def build_index(self, file_paths: Iterable[Path], on_progress: Callable[[IndexResult], None] | None = None) -> None:
|
||||
"""Pre-index a batch of files."""
|
||||
...
|
||||
|
||||
def get_callees(
|
||||
self, file_path_to_qualified_names: dict[Path, set[str]]
|
||||
) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]:
|
||||
"""Return callees for the given functions."""
|
||||
...
|
||||
|
||||
def count_callees_per_function(
|
||||
self, file_path_to_qualified_names: dict[Path, set[str]]
|
||||
) -> dict[tuple[Path, str], int]:
|
||||
"""Return the number of callees for each (file_path, qualified_name) pair."""
|
||||
...
|
||||
|
||||
def close(self) -> None:
|
||||
"""Release resources (e.g. database connections)."""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class LanguageSupport(Protocol):
|
||||
"""Protocol defining what a language implementation must provide.
|
||||
|
|
@ -254,6 +294,14 @@ class LanguageSupport(Protocol):
|
|||
"""Like # or //."""
|
||||
...
|
||||
|
||||
@property
|
||||
def dir_excludes(self) -> frozenset[str]:
|
||||
"""Directory name patterns to skip during file discovery.
|
||||
|
||||
Supports glob wildcards: "name" for exact, "prefix*" for startswith, "*suffix" for endswith.
|
||||
"""
|
||||
...
|
||||
|
||||
# === Discovery ===
|
||||
|
||||
def discover_functions(
|
||||
|
|
@ -490,6 +538,87 @@ class LanguageSupport(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def postprocess_generated_tests(
|
||||
self, generated_tests: GeneratedTestsList, test_framework: str, project_root: Path, source_file_path: Path
|
||||
) -> GeneratedTestsList:
|
||||
"""Apply language-specific postprocessing to generated tests.
|
||||
|
||||
Args:
|
||||
generated_tests: Generated tests to update.
|
||||
test_framework: Test framework used for the project.
|
||||
project_root: Project root directory.
|
||||
source_file_path: Path to the source file under optimization.
|
||||
|
||||
Returns:
|
||||
Updated generated tests.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def remove_test_functions_from_generated_tests(
|
||||
self, generated_tests: GeneratedTestsList, functions_to_remove: list[str]
|
||||
) -> GeneratedTestsList:
|
||||
"""Remove specific test functions from generated tests.
|
||||
|
||||
Args:
|
||||
generated_tests: Generated tests to update.
|
||||
functions_to_remove: List of function names to remove.
|
||||
|
||||
Returns:
|
||||
Updated generated tests.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def add_runtime_comments_to_generated_tests(
|
||||
self,
|
||||
generated_tests: GeneratedTestsList,
|
||||
original_runtimes: dict[InvocationId, list[int]],
|
||||
optimized_runtimes: dict[InvocationId, list[int]],
|
||||
tests_project_rootdir: Path | None = None,
|
||||
) -> GeneratedTestsList:
|
||||
"""Add runtime comments to generated tests.
|
||||
|
||||
Args:
|
||||
generated_tests: Generated tests to update.
|
||||
original_runtimes: Mapping of invocation IDs to original runtimes.
|
||||
optimized_runtimes: Mapping of invocation IDs to optimized runtimes.
|
||||
tests_project_rootdir: Root directory for tests (if applicable).
|
||||
|
||||
Returns:
|
||||
Updated generated tests.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def add_global_declarations(self, optimized_code: str, original_source: str, module_abspath: Path) -> str:
|
||||
"""Add new global declarations from optimized code to original source.
|
||||
|
||||
Args:
|
||||
optimized_code: The optimized code that may contain new declarations.
|
||||
original_source: The original source code.
|
||||
module_abspath: Path to the module file (for parser selection).
|
||||
|
||||
Returns:
|
||||
Original source with new declarations added.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def extract_calling_function_source(self, source_code: str, function_name: str, ref_line: int) -> str | None:
|
||||
"""Extract the source code of a calling function.
|
||||
|
||||
Args:
|
||||
source_code: Full source code of the file.
|
||||
function_name: Name of the function to extract.
|
||||
ref_line: Line number where the reference is.
|
||||
|
||||
Returns:
|
||||
Source code of the function, or None if not found.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
# === Test Result Comparison ===
|
||||
|
||||
def compare_test_results(
|
||||
|
|
@ -556,6 +685,15 @@ class LanguageSupport(Protocol):
|
|||
# Default implementation: just copy runtime files
|
||||
return False
|
||||
|
||||
def create_dependency_resolver(self, project_root: Path) -> DependencyResolver | None:
|
||||
"""Create a language-specific dependency resolver, if available.
|
||||
|
||||
Returns:
|
||||
A DependencyResolver instance, or None if not supported.
|
||||
|
||||
"""
|
||||
return None
|
||||
|
||||
def instrument_existing_test(
|
||||
self,
|
||||
test_path: Path,
|
||||
|
|
|
|||
217
codeflash/languages/javascript/code_replacer.py
Normal file
217
codeflash/languages/javascript/code_replacer.py
Normal file
|
|
@ -0,0 +1,217 @@
|
|||
"""JavaScript/TypeScript code replacement helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer
|
||||
|
||||
|
||||
# Author: ali <mohammed18200118@gmail.com>
|
||||
def _add_global_declarations_for_language(
|
||||
optimized_code: str, original_source: str, module_abspath: Path, language: Language
|
||||
) -> str:
|
||||
"""Add new global declarations from optimized code to original source.
|
||||
|
||||
Finds module-level declarations (const, let, var, class, type, interface, enum)
|
||||
in the optimized code that don't exist in the original source and adds them.
|
||||
|
||||
New declarations are inserted after any existing declarations they depend on.
|
||||
For example, if optimized code has `const _has = FOO.bar.bind(FOO)`, and `FOO`
|
||||
is already declared in the original source, `_has` will be inserted after `FOO`.
|
||||
|
||||
Args:
|
||||
optimized_code: The optimized code that may contain new declarations.
|
||||
original_source: The original source code.
|
||||
module_abspath: Path to the module file (for parser selection).
|
||||
language: The language of the code.
|
||||
|
||||
Returns:
|
||||
Original source with new declarations added in dependency order.
|
||||
|
||||
"""
|
||||
from codeflash.languages.base import Language
|
||||
|
||||
if language not in (Language.JAVASCRIPT, Language.TYPESCRIPT):
|
||||
return original_source
|
||||
|
||||
try:
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(module_abspath)
|
||||
|
||||
original_declarations = analyzer.find_module_level_declarations(original_source)
|
||||
optimized_declarations = analyzer.find_module_level_declarations(optimized_code)
|
||||
|
||||
if not optimized_declarations:
|
||||
return original_source
|
||||
|
||||
existing_names = _get_existing_names(original_declarations, analyzer, original_source)
|
||||
new_declarations = _filter_new_declarations(optimized_declarations, existing_names)
|
||||
|
||||
if not new_declarations:
|
||||
return original_source
|
||||
|
||||
# Build a map of existing declaration names to their end lines (1-indexed)
|
||||
existing_decl_end_lines = {decl.name: decl.end_line for decl in original_declarations}
|
||||
|
||||
# Insert each new declaration after its dependencies
|
||||
result = original_source
|
||||
for decl in new_declarations:
|
||||
result = _insert_declaration_after_dependencies(
|
||||
result, decl, existing_decl_end_lines, analyzer, module_abspath
|
||||
)
|
||||
# Update the map with the newly inserted declaration for subsequent insertions
|
||||
# Re-parse to get accurate line numbers after insertion
|
||||
updated_declarations = analyzer.find_module_level_declarations(result)
|
||||
existing_decl_end_lines = {d.name: d.end_line for d in updated_declarations}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error adding global declarations: {e}")
|
||||
return original_source
|
||||
|
||||
|
||||
# Author: ali <mohammed18200118@gmail.com>
|
||||
def _get_existing_names(original_declarations: list, analyzer: TreeSitterAnalyzer, original_source: str) -> set[str]:
|
||||
"""Get all names that already exist in the original source (declarations + imports)."""
|
||||
existing_names = {decl.name for decl in original_declarations}
|
||||
|
||||
original_imports = analyzer.find_imports(original_source)
|
||||
for imp in original_imports:
|
||||
if imp.default_import:
|
||||
existing_names.add(imp.default_import)
|
||||
for name, alias in imp.named_imports:
|
||||
existing_names.add(alias if alias else name)
|
||||
if imp.namespace_import:
|
||||
existing_names.add(imp.namespace_import)
|
||||
|
||||
return existing_names
|
||||
|
||||
|
||||
# Author: ali <mohammed18200118@gmail.com>
|
||||
def _filter_new_declarations(optimized_declarations: list, existing_names: set[str]) -> list:
|
||||
"""Filter declarations to only those that don't exist in the original source."""
|
||||
new_declarations = []
|
||||
seen_sources: set[str] = set()
|
||||
|
||||
# Sort by line number to maintain order from optimized code
|
||||
sorted_declarations = sorted(optimized_declarations, key=lambda d: d.start_line)
|
||||
|
||||
for decl in sorted_declarations:
|
||||
if decl.name not in existing_names and decl.source_code not in seen_sources:
|
||||
new_declarations.append(decl)
|
||||
seen_sources.add(decl.source_code)
|
||||
|
||||
return new_declarations
|
||||
|
||||
|
||||
# Author: ali <mohammed18200118@gmail.com>
|
||||
def _insert_declaration_after_dependencies(
|
||||
source: str,
|
||||
declaration,
|
||||
existing_decl_end_lines: dict[str, int],
|
||||
analyzer: TreeSitterAnalyzer,
|
||||
module_abspath: Path,
|
||||
) -> str:
|
||||
"""Insert a declaration after the last existing declaration it depends on.
|
||||
|
||||
Args:
|
||||
source: Current source code.
|
||||
declaration: The declaration to insert.
|
||||
existing_decl_end_lines: Map of existing declaration names to their end lines.
|
||||
analyzer: TreeSitter analyzer.
|
||||
module_abspath: Path to the module file.
|
||||
|
||||
Returns:
|
||||
Source code with the declaration inserted at the correct position.
|
||||
|
||||
"""
|
||||
# Find identifiers referenced in this declaration
|
||||
referenced_names = analyzer.find_referenced_identifiers(declaration.source_code)
|
||||
|
||||
# Find the latest end line among all referenced declarations
|
||||
insertion_line = _find_insertion_line_for_declaration(source, referenced_names, existing_decl_end_lines, analyzer)
|
||||
|
||||
lines = source.splitlines(keepends=True)
|
||||
|
||||
# Ensure proper spacing
|
||||
decl_code = declaration.source_code
|
||||
if not decl_code.endswith("\n"):
|
||||
decl_code += "\n"
|
||||
|
||||
# Add blank line before if inserting after content
|
||||
if insertion_line > 0 and lines[insertion_line - 1].strip():
|
||||
decl_code = "\n" + decl_code
|
||||
|
||||
before = lines[:insertion_line]
|
||||
after = lines[insertion_line:]
|
||||
|
||||
return "".join([*before, decl_code, *after])
|
||||
|
||||
|
||||
# Author: ali <mohammed18200118@gmail.com>
|
||||
def _find_insertion_line_for_declaration(
|
||||
source: str, referenced_names: set[str], existing_decl_end_lines: dict[str, int], analyzer: TreeSitterAnalyzer
|
||||
) -> int:
|
||||
"""Find the line where a declaration should be inserted based on its dependencies.
|
||||
|
||||
Args:
|
||||
source: Source code.
|
||||
referenced_names: Names referenced by the declaration.
|
||||
existing_decl_end_lines: Map of declaration names to their end lines (1-indexed).
|
||||
analyzer: TreeSitter analyzer.
|
||||
|
||||
Returns:
|
||||
Line index (0-based) where the declaration should be inserted.
|
||||
|
||||
"""
|
||||
# Find the maximum end line among referenced declarations
|
||||
max_dependency_line = 0
|
||||
for name in referenced_names:
|
||||
if name in existing_decl_end_lines:
|
||||
max_dependency_line = max(max_dependency_line, existing_decl_end_lines[name])
|
||||
|
||||
if max_dependency_line > 0:
|
||||
# Insert after the last dependency (end_line is 1-indexed, we need 0-indexed)
|
||||
return max_dependency_line
|
||||
|
||||
# No dependencies found - insert after imports
|
||||
lines = source.splitlines(keepends=True)
|
||||
return _find_line_after_imports(lines, analyzer, source)
|
||||
|
||||
|
||||
# Author: ali <mohammed18200118@gmail.com>
|
||||
def _find_line_after_imports(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int:
|
||||
"""Find the line index after all imports.
|
||||
|
||||
Args:
|
||||
lines: Source lines.
|
||||
analyzer: TreeSitter analyzer.
|
||||
source: Full source code.
|
||||
|
||||
Returns:
|
||||
Line index (0-based) for insertion after imports.
|
||||
|
||||
"""
|
||||
try:
|
||||
imports = analyzer.find_imports(source)
|
||||
if imports:
|
||||
return max(imp.end_line for imp in imports)
|
||||
except Exception as exc:
|
||||
logger.debug(f"Exception in _find_line_after_imports: {exc}")
|
||||
|
||||
# Default: insert at beginning (after shebang/directive comments)
|
||||
for i, line in enumerate(lines):
|
||||
stripped = line.strip()
|
||||
if stripped and not stripped.startswith("//") and not stripped.startswith("#!"):
|
||||
return i
|
||||
|
||||
return 0
|
||||
|
|
@ -6,29 +6,13 @@ including adding runtime comments and removing test functions.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.time_utils import format_perf, format_time
|
||||
from codeflash.result.critic import performance_gain
|
||||
|
||||
|
||||
def format_runtime_comment(original_time: int, optimized_time: int) -> str:
|
||||
"""Format a runtime comparison comment for JavaScript.
|
||||
|
||||
Args:
|
||||
original_time: Original runtime in nanoseconds.
|
||||
optimized_time: Optimized runtime in nanoseconds.
|
||||
|
||||
Returns:
|
||||
Formatted comment string with // prefix.
|
||||
|
||||
"""
|
||||
perf_gain = format_perf(
|
||||
abs(performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time) * 100)
|
||||
)
|
||||
status = "slower" if optimized_time > original_time else "faster"
|
||||
return f"// {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
|
||||
from codeflash.code_utils.time_utils import format_runtime_comment
|
||||
from codeflash.models.models import GeneratedTests, GeneratedTestsList
|
||||
|
||||
|
||||
def add_runtime_comments(source: str, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]) -> str:
|
||||
|
|
@ -117,7 +101,7 @@ def add_runtime_comments(source: str, original_runtimes: dict[str, int], optimiz
|
|||
# Only add comment if line has a function call and doesn't already have a comment
|
||||
if func_call_pattern.search(line) and "//" not in line and "expect(" in line:
|
||||
orig_time, opt_time = timing_by_full_name[current_matched_full_name]
|
||||
comment = format_runtime_comment(orig_time, opt_time)
|
||||
comment = format_runtime_comment(orig_time, opt_time, comment_prefix="//")
|
||||
logger.debug(f"[js-annotations] Adding comment to test '{current_test_name}': {comment}")
|
||||
# Add comment at end of line
|
||||
line = f"{line.rstrip()} {comment}"
|
||||
|
|
@ -130,6 +114,165 @@ def add_runtime_comments(source: str, original_runtimes: dict[str, int], optimiz
|
|||
return "\n".join(modified_lines)
|
||||
|
||||
|
||||
JS_TEST_EXTENSIONS = (
|
||||
".test.ts",
|
||||
".test.js",
|
||||
".test.tsx",
|
||||
".test.jsx",
|
||||
".spec.ts",
|
||||
".spec.js",
|
||||
".spec.tsx",
|
||||
".spec.jsx",
|
||||
".ts",
|
||||
".js",
|
||||
".tsx",
|
||||
".jsx",
|
||||
".mjs",
|
||||
".mts",
|
||||
)
|
||||
|
||||
|
||||
# TODO:{self} Needs cleanup for jest logic in else block
|
||||
# Author: Sarthak Agarwal <sarthak.saga@gmail.com>
|
||||
def is_js_test_module_path(test_module_path: str) -> bool:
|
||||
"""Return True when the module path looks like a JS/TS test path."""
|
||||
return any(test_module_path.endswith(ext) for ext in JS_TEST_EXTENSIONS)
|
||||
|
||||
|
||||
# Author: Sarthak Agarwal <sarthak.saga@gmail.com>
|
||||
def resolve_js_test_module_path(test_module_path: str, tests_project_rootdir: Path) -> Path:
|
||||
"""Resolve a JS/TS test module path to a concrete file path."""
|
||||
if "/" in test_module_path or "\\" in test_module_path:
|
||||
return tests_project_rootdir / Path(test_module_path)
|
||||
|
||||
matched_ext = None
|
||||
for ext in JS_TEST_EXTENSIONS:
|
||||
if test_module_path.endswith(ext):
|
||||
matched_ext = ext
|
||||
break
|
||||
|
||||
if matched_ext:
|
||||
base_path = test_module_path[: -len(matched_ext)]
|
||||
file_path = base_path.replace(".", os.sep) + matched_ext
|
||||
tests_dir_name = tests_project_rootdir.name
|
||||
if file_path.startswith((tests_dir_name + os.sep, tests_dir_name + "/")):
|
||||
return tests_project_rootdir.parent / Path(file_path)
|
||||
return tests_project_rootdir / Path(file_path)
|
||||
|
||||
return tests_project_rootdir / Path(test_module_path)
|
||||
|
||||
|
||||
# Patterns for normalizing codeflash imports (legacy -> npm package)
|
||||
# Author: Sarthak Agarwal <sarthak.saga@gmail.com>
|
||||
_CODEFLASH_REQUIRE_PATTERN = re.compile(
|
||||
r"(const|let|var)\s+(\w+)\s*=\s*require\s*\(\s*['\"]\.?/?codeflash-jest-helper['\"]\s*\)"
|
||||
)
|
||||
_CODEFLASH_IMPORT_PATTERN = re.compile(r"import\s+(?:\*\s+as\s+)?(\w+)\s+from\s+['\"]\.?/?codeflash-jest-helper['\"]")
|
||||
|
||||
|
||||
# Author: Sarthak Agarwal <sarthak.saga@gmail.com>
|
||||
def normalize_codeflash_imports(source: str) -> str:
|
||||
"""Normalize codeflash imports to use the npm package.
|
||||
|
||||
Replaces legacy local file imports:
|
||||
const codeflash = require('./codeflash-jest-helper')
|
||||
import codeflash from './codeflash-jest-helper'
|
||||
|
||||
With npm package imports:
|
||||
const codeflash = require('codeflash')
|
||||
|
||||
Args:
|
||||
source: JavaScript/TypeScript source code.
|
||||
|
||||
Returns:
|
||||
Source code with normalized imports.
|
||||
|
||||
"""
|
||||
# Replace CommonJS require
|
||||
source = _CODEFLASH_REQUIRE_PATTERN.sub(r"\1 \2 = require('codeflash')", source)
|
||||
# Replace ES module import
|
||||
return _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source)
|
||||
|
||||
|
||||
# Author: ali <mohammed18200118@gmail.com>
|
||||
def inject_test_globals(generated_tests: GeneratedTestsList, test_framework: str = "jest") -> GeneratedTestsList:
|
||||
# TODO: inside the prompt tell the llm if it should import jest functions or it's already injected in the global window
|
||||
"""Inject test globals into all generated tests.
|
||||
|
||||
Args:
|
||||
generated_tests: List of generated tests.
|
||||
test_framework: The test framework being used ("jest", "vitest", or "mocha").
|
||||
|
||||
Returns:
|
||||
Generated tests with test globals injected.
|
||||
|
||||
"""
|
||||
# we only inject test globals for esm modules
|
||||
# Use vitest imports for vitest projects, jest imports for jest projects
|
||||
if test_framework == "vitest":
|
||||
global_import = "import { vi, describe, it, expect, beforeEach, afterEach, beforeAll, test } from 'vitest'\n"
|
||||
else:
|
||||
# Default to jest imports for jest and other frameworks
|
||||
global_import = (
|
||||
"import { jest, describe, it, expect, beforeEach, afterEach, beforeAll, test } from '@jest/globals'\n"
|
||||
)
|
||||
|
||||
for test in generated_tests.generated_tests:
|
||||
test.generated_original_test_source = global_import + test.generated_original_test_source
|
||||
test.instrumented_behavior_test_source = global_import + test.instrumented_behavior_test_source
|
||||
test.instrumented_perf_test_source = global_import + test.instrumented_perf_test_source
|
||||
return generated_tests
|
||||
|
||||
|
||||
# Author: ali <mohammed18200118@gmail.com>
|
||||
def disable_ts_check(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
|
||||
"""Disable TypeScript type checking in all generated tests.
|
||||
|
||||
Args:
|
||||
generated_tests: List of generated tests.
|
||||
|
||||
Returns:
|
||||
Generated tests with TypeScript type checking disabled.
|
||||
|
||||
"""
|
||||
# we only inject test globals for esm modules
|
||||
ts_nocheck = "// @ts-nocheck\n"
|
||||
|
||||
for test in generated_tests.generated_tests:
|
||||
test.generated_original_test_source = ts_nocheck + test.generated_original_test_source
|
||||
test.instrumented_behavior_test_source = ts_nocheck + test.instrumented_behavior_test_source
|
||||
test.instrumented_perf_test_source = ts_nocheck + test.instrumented_perf_test_source
|
||||
return generated_tests
|
||||
|
||||
|
||||
# Author: Sarthak Agarwal <sarthak.saga@gmail.com>
|
||||
def normalize_generated_tests_imports(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
|
||||
"""Normalize codeflash imports in all generated tests.
|
||||
|
||||
Args:
|
||||
generated_tests: List of generated tests.
|
||||
|
||||
Returns:
|
||||
Generated tests with normalized imports.
|
||||
|
||||
"""
|
||||
normalized_tests = []
|
||||
for test in generated_tests.generated_tests:
|
||||
# Only normalize JS/TS files
|
||||
if test.behavior_file_path.suffix in (".js", ".ts", ".jsx", ".tsx", ".mjs", ".mts"):
|
||||
normalized_test = GeneratedTests(
|
||||
generated_original_test_source=normalize_codeflash_imports(test.generated_original_test_source),
|
||||
instrumented_behavior_test_source=normalize_codeflash_imports(test.instrumented_behavior_test_source),
|
||||
instrumented_perf_test_source=normalize_codeflash_imports(test.instrumented_perf_test_source),
|
||||
behavior_file_path=test.behavior_file_path,
|
||||
perf_file_path=test.perf_file_path,
|
||||
)
|
||||
normalized_tests.append(normalized_test)
|
||||
else:
|
||||
normalized_tests.append(test)
|
||||
return GeneratedTestsList(generated_tests=normalized_tests)
|
||||
|
||||
|
||||
def remove_test_functions(source: str, functions_to_remove: list[str]) -> str:
|
||||
"""Remove specific test functions from JavaScript test source code.
|
||||
|
||||
|
|
|
|||
|
|
@ -44,8 +44,7 @@ class ImportResolver:
|
|||
project_root: Root directory of the project.
|
||||
|
||||
"""
|
||||
# Resolve to real path to handle macOS symlinks like /var -> /private/var
|
||||
self.project_root = project_root.resolve()
|
||||
self.project_root = project_root
|
||||
self._resolution_cache: dict[tuple[Path, str], Path | None] = {}
|
||||
|
||||
def resolve_import(self, import_info: ImportInfo, source_file: Path) -> ResolvedImport | None:
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ if TYPE_CHECKING:
|
|||
|
||||
from codeflash.languages.base import ReferenceInfo
|
||||
from codeflash.languages.javascript.treesitter import TypeDefinition
|
||||
from codeflash.models.models import GeneratedTestsList, InvocationId
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -63,6 +64,10 @@ class JavaScriptSupport:
|
|||
def comment_prefix(self) -> str:
|
||||
return "//"
|
||||
|
||||
@property
|
||||
def dir_excludes(self) -> frozenset[str]:
|
||||
return frozenset({"node_modules", "dist", "build", ".next", ".nuxt", "coverage", ".cache", ".turbo", ".vercel"})
|
||||
|
||||
# === Discovery ===
|
||||
|
||||
def discover_functions(
|
||||
|
|
@ -1774,6 +1779,116 @@ class JavaScriptSupport:
|
|||
|
||||
return remove_test_functions(test_source, functions_to_remove)
|
||||
|
||||
def postprocess_generated_tests(
|
||||
self, generated_tests: GeneratedTestsList, test_framework: str, project_root: Path, source_file_path: Path
|
||||
) -> GeneratedTestsList:
|
||||
"""Apply language-specific postprocessing to generated tests."""
|
||||
from codeflash.languages.javascript.edit_tests import (
|
||||
disable_ts_check,
|
||||
inject_test_globals,
|
||||
normalize_generated_tests_imports,
|
||||
)
|
||||
from codeflash.languages.javascript.module_system import detect_module_system
|
||||
|
||||
module_system = detect_module_system(project_root, source_file_path)
|
||||
if module_system == "esm":
|
||||
generated_tests = inject_test_globals(generated_tests, test_framework)
|
||||
if self.language == Language.TYPESCRIPT:
|
||||
generated_tests = disable_ts_check(generated_tests)
|
||||
return normalize_generated_tests_imports(generated_tests)
|
||||
|
||||
def remove_test_functions_from_generated_tests(
|
||||
self, generated_tests: GeneratedTestsList, functions_to_remove: list[str]
|
||||
) -> GeneratedTestsList:
|
||||
"""Remove specific test functions from generated tests."""
|
||||
from codeflash.models.models import GeneratedTests, GeneratedTestsList
|
||||
|
||||
updated_tests: list[GeneratedTests] = []
|
||||
for test in generated_tests.generated_tests:
|
||||
updated_tests.append(
|
||||
GeneratedTests(
|
||||
generated_original_test_source=self.remove_test_functions(
|
||||
test.generated_original_test_source, functions_to_remove
|
||||
),
|
||||
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
|
||||
instrumented_perf_test_source=test.instrumented_perf_test_source,
|
||||
behavior_file_path=test.behavior_file_path,
|
||||
perf_file_path=test.perf_file_path,
|
||||
)
|
||||
)
|
||||
return GeneratedTestsList(generated_tests=updated_tests)
|
||||
|
||||
def add_runtime_comments_to_generated_tests(
|
||||
self,
|
||||
generated_tests: GeneratedTestsList,
|
||||
original_runtimes: dict[InvocationId, list[int]],
|
||||
optimized_runtimes: dict[InvocationId, list[int]],
|
||||
tests_project_rootdir: Path | None = None,
|
||||
) -> GeneratedTestsList:
|
||||
"""Add runtime comments to generated tests."""
|
||||
from codeflash.models.models import GeneratedTests, GeneratedTestsList
|
||||
|
||||
tests_root = tests_project_rootdir or Path()
|
||||
original_runtimes_dict = self._build_runtime_map(original_runtimes, tests_root)
|
||||
optimized_runtimes_dict = self._build_runtime_map(optimized_runtimes, tests_root)
|
||||
|
||||
modified_tests: list[GeneratedTests] = []
|
||||
for test in generated_tests.generated_tests:
|
||||
modified_source = self.add_runtime_comments(
|
||||
test.generated_original_test_source, original_runtimes_dict, optimized_runtimes_dict
|
||||
)
|
||||
modified_tests.append(
|
||||
GeneratedTests(
|
||||
generated_original_test_source=modified_source,
|
||||
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
|
||||
instrumented_perf_test_source=test.instrumented_perf_test_source,
|
||||
behavior_file_path=test.behavior_file_path,
|
||||
perf_file_path=test.perf_file_path,
|
||||
)
|
||||
)
|
||||
return GeneratedTestsList(generated_tests=modified_tests)
|
||||
|
||||
def add_global_declarations(self, optimized_code: str, original_source: str, module_abspath: Path) -> str:
|
||||
from codeflash.languages.javascript.code_replacer import _add_global_declarations_for_language
|
||||
|
||||
return _add_global_declarations_for_language(optimized_code, original_source, module_abspath, self.language)
|
||||
|
||||
def extract_calling_function_source(self, source_code: str, function_name: str, ref_line: int) -> str | None:
|
||||
from codeflash.languages.javascript.treesitter import extract_calling_function_source
|
||||
|
||||
return extract_calling_function_source(source_code, function_name, ref_line)
|
||||
|
||||
def _build_runtime_map(
|
||||
self, inv_id_runtimes: dict[InvocationId, list[int]], tests_project_rootdir: Path
|
||||
) -> dict[str, int]:
|
||||
from codeflash.languages.javascript.edit_tests import resolve_js_test_module_path
|
||||
|
||||
unique_inv_ids: dict[str, int] = {}
|
||||
for inv_id, runtimes in inv_id_runtimes.items():
|
||||
test_qualified_name = (
|
||||
inv_id.test_class_name + "." + inv_id.test_function_name # type: ignore[operator]
|
||||
if inv_id.test_class_name
|
||||
else inv_id.test_function_name
|
||||
)
|
||||
if not test_qualified_name:
|
||||
continue
|
||||
abs_path = resolve_js_test_module_path(inv_id.test_module_path, tests_project_rootdir)
|
||||
|
||||
abs_path_str = str(abs_path.resolve().with_suffix(""))
|
||||
if "__unit_test_" not in abs_path_str and "__perf_test_" not in abs_path_str:
|
||||
continue
|
||||
|
||||
key = test_qualified_name + "#" + abs_path_str
|
||||
parts = inv_id.iteration_id.split("_").__len__() # type: ignore[union-attr]
|
||||
cur_invid = (
|
||||
inv_id.iteration_id.split("_")[0] if parts < 3 else "_".join(inv_id.iteration_id.split("_")[:-1])
|
||||
) # type: ignore[union-attr]
|
||||
match_key = key + "#" + cur_invid
|
||||
if match_key not in unique_inv_ids:
|
||||
unique_inv_ids[match_key] = 0
|
||||
unique_inv_ids[match_key] += min(runtimes)
|
||||
return unique_inv_ids
|
||||
|
||||
# === Test Result Comparison ===
|
||||
|
||||
def compare_test_results(
|
||||
|
|
@ -2000,6 +2115,9 @@ class JavaScriptSupport:
|
|||
logger.error("Could not install codeflash. Please run: npm install --save-dev codeflash")
|
||||
return False
|
||||
|
||||
def create_dependency_resolver(self, project_root: Path) -> None:
|
||||
return None
|
||||
|
||||
def instrument_existing_test(
|
||||
self,
|
||||
test_path: Path,
|
||||
|
|
|
|||
|
|
@ -1788,3 +1788,39 @@ def get_analyzer_for_file(file_path: Path) -> TreeSitterAnalyzer:
|
|||
return TreeSitterAnalyzer(TreeSitterLanguage.TSX)
|
||||
# Default to JavaScript for .js, .jsx, .mjs, .cjs
|
||||
return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
|
||||
|
||||
|
||||
# Author: Saurabh Misra <misra.saurabh1@gmail.com>
|
||||
def extract_calling_function_source(source_code: str, function_name: str, ref_line: int) -> str | None:
|
||||
"""Extract the source code of a calling function in JavaScript/TypeScript.
|
||||
|
||||
Args:
|
||||
source_code: Full source code of the file.
|
||||
function_name: Name of the function to extract.
|
||||
ref_line: Line number where the reference is (helps identify the right function).
|
||||
|
||||
Returns:
|
||||
Source code of the function, or None if not found.
|
||||
|
||||
"""
|
||||
try:
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
|
||||
# Try TypeScript first, fall back to JavaScript
|
||||
for lang in [TreeSitterLanguage.TYPESCRIPT, TreeSitterLanguage.TSX, TreeSitterLanguage.JAVASCRIPT]:
|
||||
try:
|
||||
analyzer = TreeSitterAnalyzer(lang)
|
||||
functions = analyzer.find_functions(source_code, include_methods=True)
|
||||
|
||||
for func in functions:
|
||||
if func.name == function_name:
|
||||
# Check if the reference line is within this function
|
||||
if func.start_line <= ref_line <= func.end_line:
|
||||
return func.source_text
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Python-specific implementations (LibCST, Jedi, pytest, etc.) to conform
|
|||
to the LanguageSupport protocol.
|
||||
"""
|
||||
|
||||
from codeflash.languages.python.reference_graph import ReferenceGraph
|
||||
from codeflash.languages.python.support import PythonSupport
|
||||
|
||||
__all__ = ["PythonSupport"]
|
||||
__all__ = ["PythonSupport", "ReferenceGraph"]
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from typing import TYPE_CHECKING
|
|||
import libcst as cst
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
|
||||
from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages
|
||||
from codeflash.code_utils.config_consts import OPTIMIZATION_CONTEXT_TOKEN_LIMIT, TESTGEN_CONTEXT_TOKEN_LIMIT
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
|
||||
|
|
@ -24,6 +23,10 @@ from codeflash.languages.python.context.unused_definition_remover import (
|
|||
recurse_sections,
|
||||
remove_unused_definitions_by_function_names,
|
||||
)
|
||||
from codeflash.languages.python.static_analysis.code_extractor import (
|
||||
add_needed_imports_from_module,
|
||||
find_preexisting_objects,
|
||||
)
|
||||
from codeflash.models.models import (
|
||||
CodeContextType,
|
||||
CodeOptimizationContext,
|
||||
|
|
@ -38,7 +41,7 @@ if TYPE_CHECKING:
|
|||
|
||||
from jedi.api.classes import Name
|
||||
|
||||
from codeflash.languages.base import HelperFunction
|
||||
from codeflash.languages.base import DependencyResolver, HelperFunction
|
||||
from codeflash.languages.python.context.unused_definition_remover import UsageInfo
|
||||
|
||||
# Error message constants
|
||||
|
|
@ -87,6 +90,7 @@ def get_code_optimization_context(
|
|||
project_root_path: Path,
|
||||
optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
|
||||
testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT,
|
||||
call_graph: DependencyResolver | None = None,
|
||||
) -> CodeOptimizationContext:
|
||||
# Route to language-specific implementation for non-Python languages
|
||||
if not is_python():
|
||||
|
|
@ -95,9 +99,11 @@ def get_code_optimization_context(
|
|||
)
|
||||
|
||||
# Get FunctionSource representation of helpers of FTO
|
||||
helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi(
|
||||
{function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path
|
||||
)
|
||||
fto_input = {function_to_optimize.file_path: {function_to_optimize.qualified_name}}
|
||||
if call_graph is not None:
|
||||
helpers_of_fto_dict, helpers_of_fto_list = call_graph.get_callees(fto_input)
|
||||
else:
|
||||
helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi(fto_input, project_root_path)
|
||||
|
||||
# Add function to optimize into helpers of FTO dict, as they'll be processed together
|
||||
fto_as_function_source = get_function_to_optimize_as_function_source(function_to_optimize, project_root_path)
|
||||
|
|
@ -113,8 +119,7 @@ def get_code_optimization_context(
|
|||
for qualified_names in helpers_of_fto_qualified_names_dict.values():
|
||||
qualified_names.update({f"{qn.rsplit('.', 1)[0]}.__init__" for qn in qualified_names if "." in qn})
|
||||
|
||||
# Get FunctionSource representation of helpers of helpers of FTO
|
||||
helpers_of_helpers_dict, _helpers_of_helpers_list = get_function_sources_from_jedi(
|
||||
helpers_of_helpers_dict, helpers_of_helpers_list = get_function_sources_from_jedi(
|
||||
helpers_of_fto_qualified_names_dict, project_root_path
|
||||
)
|
||||
|
||||
|
|
@ -198,6 +203,8 @@ def get_code_optimization_context(
|
|||
code_hash_context = hashing_code_context.markdown
|
||||
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
|
||||
|
||||
all_helper_fqns = list({fs.fully_qualified_name for fs in helpers_of_fto_list + helpers_of_helpers_list})
|
||||
|
||||
return CodeOptimizationContext(
|
||||
testgen_context=testgen_context,
|
||||
read_writable_code=final_read_writable_code,
|
||||
|
|
@ -205,6 +212,7 @@ def get_code_optimization_context(
|
|||
hashing_code_context=code_hash_context,
|
||||
hashing_code_context_hash=code_hash,
|
||||
helper_functions=helpers_of_fto_list,
|
||||
testgen_helper_fqns=all_helper_fqns,
|
||||
preexisting_objects=preexisting_objects,
|
||||
)
|
||||
|
||||
|
|
@ -266,7 +274,6 @@ def get_code_optimization_context_for_language(
|
|||
fully_qualified_name=helper.qualified_name,
|
||||
only_function_name=helper.name,
|
||||
source_code=helper.source_code,
|
||||
jedi_definition=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -335,13 +342,12 @@ def get_code_optimization_context_for_language(
|
|||
return CodeOptimizationContext(
|
||||
testgen_context=testgen_context,
|
||||
read_writable_code=read_writable_code,
|
||||
# Pass type definitions and globals as read-only context for the AI
|
||||
# This way the AI sees them as context but doesn't include them in optimized output
|
||||
read_only_context_code=code_context.read_only_context,
|
||||
hashing_code_context=read_writable_code.flat,
|
||||
hashing_code_context_hash=code_hash,
|
||||
helper_functions=helper_function_sources,
|
||||
preexisting_objects=set(), # Not implemented for non-Python yet
|
||||
testgen_helper_fqns=[fs.fully_qualified_name for fs in helper_function_sources],
|
||||
preexisting_objects=set(),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -496,7 +502,6 @@ def get_function_to_optimize_as_function_source(
|
|||
fully_qualified_name=name.full_name,
|
||||
only_function_name=name.name,
|
||||
source_code=name.get_line_code(),
|
||||
jedi_definition=name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error while getting function source: {e}")
|
||||
|
|
@ -533,6 +538,10 @@ def get_function_sources_from_jedi(
|
|||
# TODO: there can be multiple definitions, see how to handle such cases
|
||||
definition = definitions[0]
|
||||
definition_path = definition.module_path
|
||||
if definition_path is not None:
|
||||
rel = safe_relative_to(definition_path, project_root_path)
|
||||
if not rel.is_absolute():
|
||||
definition_path = project_root_path / rel
|
||||
|
||||
# The definition is part of this project and not defined within the original function
|
||||
is_valid_definition = (
|
||||
|
|
@ -543,15 +552,16 @@ def get_function_sources_from_jedi(
|
|||
and not belongs_to_function_qualified(definition, qualified_function_name)
|
||||
and definition.full_name.startswith(definition.module_name)
|
||||
)
|
||||
if is_valid_definition and definition.type in ("function", "class"):
|
||||
if is_valid_definition and definition.type in ("function", "class", "statement"):
|
||||
if definition.type == "function":
|
||||
fqn = definition.full_name
|
||||
func_name = definition.name
|
||||
else:
|
||||
# When a class is instantiated (e.g., MyClass()), track its __init__ as a helper
|
||||
# This ensures the class definition with constructor is included in testgen context
|
||||
elif definition.type == "class":
|
||||
fqn = f"{definition.full_name}.__init__"
|
||||
func_name = "__init__"
|
||||
else:
|
||||
fqn = definition.full_name
|
||||
func_name = definition.name
|
||||
qualified_name = get_qualified_name(definition.module_name, fqn)
|
||||
# Avoid nested functions or classes. Only class.function is allowed
|
||||
if len(qualified_name.split(".")) <= 2:
|
||||
|
|
@ -561,7 +571,6 @@ def get_function_sources_from_jedi(
|
|||
fully_qualified_name=fqn,
|
||||
only_function_name=func_name,
|
||||
source_code=definition.get_line_code(),
|
||||
jedi_definition=definition,
|
||||
)
|
||||
file_path_to_function_source[definition_path].add(function_source)
|
||||
function_source_list.append(function_source)
|
||||
|
|
@ -1006,6 +1015,255 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
|
|||
return CodeStringsMarkdown(code_strings=code_strings)
|
||||
|
||||
|
||||
def resolve_classes_from_modules(candidates: set[tuple[str, str]]) -> list[tuple[type, str]]:
|
||||
"""Import modules and resolve candidate (class_name, module_name) pairs to class objects."""
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
resolved: list[tuple[type, str]] = []
|
||||
module_cache: dict[str, object] = {}
|
||||
|
||||
for class_name, module_name in candidates:
|
||||
try:
|
||||
module = module_cache.get(module_name)
|
||||
if module is None:
|
||||
module = importlib.import_module(module_name)
|
||||
module_cache[module_name] = module
|
||||
|
||||
cls = getattr(module, class_name, None)
|
||||
if cls is not None and inspect.isclass(cls):
|
||||
resolved.append((cls, class_name))
|
||||
except (ImportError, ModuleNotFoundError, AttributeError):
|
||||
logger.debug(f"Failed to import {module_name}.{class_name}")
|
||||
|
||||
return resolved
|
||||
|
||||
|
||||
MAX_TRANSITIVE_DEPTH = 5
|
||||
|
||||
|
||||
def extract_classes_from_type_hint(hint: object) -> list[type]:
|
||||
"""Recursively extract concrete class objects from a type annotation.
|
||||
|
||||
Unwraps Optional, Union, List, Dict, Callable, Annotated, etc.
|
||||
Filters out builtins and typing module types.
|
||||
"""
|
||||
import typing
|
||||
|
||||
classes: list[type] = []
|
||||
origin = getattr(hint, "__origin__", None)
|
||||
args = getattr(hint, "__args__", None)
|
||||
|
||||
if origin is not None and args:
|
||||
for arg in args:
|
||||
classes.extend(extract_classes_from_type_hint(arg))
|
||||
elif isinstance(hint, type):
|
||||
module = getattr(hint, "__module__", "")
|
||||
if module not in ("builtins", "typing", "typing_extensions", "types"):
|
||||
classes.append(hint)
|
||||
# Handle typing.Annotated on older Pythons where __origin__ may not be set
|
||||
if hasattr(typing, "get_args") and origin is None and args is None:
|
||||
try:
|
||||
inner_args = typing.get_args(hint)
|
||||
if inner_args:
|
||||
for arg in inner_args:
|
||||
classes.extend(extract_classes_from_type_hint(arg))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return classes
|
||||
|
||||
|
||||
def resolve_transitive_type_deps(cls: type) -> list[type]:
|
||||
"""Find external classes referenced in cls.__init__ type annotations.
|
||||
|
||||
Returns classes from site-packages that have a custom __init__.
|
||||
"""
|
||||
import inspect
|
||||
import typing
|
||||
|
||||
try:
|
||||
init_method = getattr(cls, "__init__")
|
||||
hints = typing.get_type_hints(init_method)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
deps: list[type] = []
|
||||
for param_name, hint in hints.items():
|
||||
if param_name == "return":
|
||||
continue
|
||||
for dep_cls in extract_classes_from_type_hint(hint):
|
||||
if dep_cls is cls:
|
||||
continue
|
||||
init_method = getattr(dep_cls, "__init__", None)
|
||||
if init_method is None or init_method is object.__init__:
|
||||
continue
|
||||
try:
|
||||
class_file = Path(inspect.getfile(dep_cls))
|
||||
except (OSError, TypeError):
|
||||
continue
|
||||
if not path_belongs_to_site_packages(class_file):
|
||||
continue
|
||||
deps.append(dep_cls)
|
||||
|
||||
return deps
|
||||
|
||||
|
||||
def extract_init_stub(cls: type, class_name: str, require_site_packages: bool = True) -> CodeString | None:
|
||||
"""Extract a stub containing the class definition with only its __init__ method.
|
||||
|
||||
Args:
|
||||
cls: The class object to extract __init__ from
|
||||
class_name: Name to use for the class in the stub
|
||||
require_site_packages: If True, only extract from site-packages. If False, include stdlib too.
|
||||
|
||||
"""
|
||||
import inspect
|
||||
import textwrap
|
||||
|
||||
init_method = getattr(cls, "__init__", None)
|
||||
if init_method is None or init_method is object.__init__:
|
||||
return None
|
||||
|
||||
try:
|
||||
class_file = Path(inspect.getfile(cls))
|
||||
except (OSError, TypeError):
|
||||
return None
|
||||
|
||||
if require_site_packages and not path_belongs_to_site_packages(class_file):
|
||||
return None
|
||||
|
||||
try:
|
||||
init_source = inspect.getsource(init_method)
|
||||
init_source = textwrap.dedent(init_source)
|
||||
except (OSError, TypeError):
|
||||
return None
|
||||
|
||||
parts = class_file.parts
|
||||
if "site-packages" in parts:
|
||||
idx = parts.index("site-packages")
|
||||
class_file = Path(*parts[idx + 1 :])
|
||||
|
||||
class_source = f"class {class_name}:\n" + textwrap.indent(init_source, " ")
|
||||
return CodeString(code=class_source, file_path=class_file)
|
||||
|
||||
|
||||
def _is_project_module_cached(module_name: str, project_root_path: Path, cache: dict[str, bool]) -> bool:
|
||||
cached = cache.get(module_name)
|
||||
if cached is not None:
|
||||
return cached
|
||||
is_project = _is_project_module(module_name, project_root_path)
|
||||
cache[module_name] = is_project
|
||||
return is_project
|
||||
|
||||
|
||||
def is_project_path(module_path: Path | None, project_root_path: Path) -> bool:
|
||||
if module_path is None:
|
||||
return False
|
||||
# site-packages must be checked first because .venv/site-packages is under project root
|
||||
if path_belongs_to_site_packages(module_path):
|
||||
return False
|
||||
try:
|
||||
module_path.resolve().relative_to(project_root_path.resolve())
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _is_project_module(module_name: str, project_root_path: Path) -> bool:
|
||||
"""Check if a module is part of the project (not external/stdlib)."""
|
||||
import importlib.util
|
||||
|
||||
try:
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
except (ImportError, ModuleNotFoundError, ValueError):
|
||||
return False
|
||||
else:
|
||||
if spec is None or spec.origin is None:
|
||||
return False
|
||||
return is_project_path(Path(spec.origin), project_root_path)
|
||||
|
||||
|
||||
def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
|
||||
"""Extract import statements needed for a class definition.
|
||||
|
||||
This extracts imports for base classes, decorators, and type annotations.
|
||||
"""
|
||||
needed_names: set[str] = set()
|
||||
|
||||
# Get base class names
|
||||
for base in class_node.bases:
|
||||
if isinstance(base, ast.Name):
|
||||
needed_names.add(base.id)
|
||||
elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name):
|
||||
# For things like abc.ABC, we need the module name
|
||||
needed_names.add(base.value.id)
|
||||
|
||||
# Get decorator names (e.g., dataclass, field)
|
||||
for decorator in class_node.decorator_list:
|
||||
if isinstance(decorator, ast.Name):
|
||||
needed_names.add(decorator.id)
|
||||
elif isinstance(decorator, ast.Call):
|
||||
if isinstance(decorator.func, ast.Name):
|
||||
needed_names.add(decorator.func.id)
|
||||
elif isinstance(decorator.func, ast.Attribute) and isinstance(decorator.func.value, ast.Name):
|
||||
needed_names.add(decorator.func.value.id)
|
||||
|
||||
# Get type annotation names from class body (for dataclass fields)
|
||||
for item in class_node.body:
|
||||
if isinstance(item, ast.AnnAssign) and item.annotation:
|
||||
collect_names_from_annotation(item.annotation, needed_names)
|
||||
# Also check for field() calls which are common in dataclasses
|
||||
elif isinstance(item, ast.Assign) and isinstance(item.value, ast.Call):
|
||||
if isinstance(item.value.func, ast.Name):
|
||||
needed_names.add(item.value.func.id)
|
||||
|
||||
# Find imports that provide these names
|
||||
import_lines: list[str] = []
|
||||
source_lines = module_source.split("\n")
|
||||
added_imports: set[int] = set() # Track line numbers to avoid duplicates
|
||||
|
||||
for node in module_tree.body:
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name.split(".")[0]
|
||||
if name in needed_names and node.lineno not in added_imports:
|
||||
import_lines.append(source_lines[node.lineno - 1])
|
||||
added_imports.add(node.lineno)
|
||||
break
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
if name in needed_names and node.lineno not in added_imports:
|
||||
import_lines.append(source_lines[node.lineno - 1])
|
||||
added_imports.add(node.lineno)
|
||||
break
|
||||
|
||||
return "\n".join(import_lines)
|
||||
|
||||
|
||||
def collect_names_from_annotation(node: ast.expr, names: set[str]) -> None:
|
||||
"""Recursively collect type annotation names from an AST node."""
|
||||
if isinstance(node, ast.Name):
|
||||
names.add(node.id)
|
||||
elif isinstance(node, ast.Subscript):
|
||||
collect_names_from_annotation(node.value, names)
|
||||
collect_names_from_annotation(node.slice, names)
|
||||
elif isinstance(node, ast.Tuple):
|
||||
for elt in node.elts:
|
||||
collect_names_from_annotation(elt, names)
|
||||
elif isinstance(node, ast.BinOp): # For Union types with | syntax
|
||||
collect_names_from_annotation(node.left, names)
|
||||
collect_names_from_annotation(node.right, names)
|
||||
elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
|
||||
names.add(node.value.id)
|
||||
|
||||
|
||||
def is_dunder_method(name: str) -> bool:
|
||||
return len(name) > 4 and name.isascii() and name.startswith("__") and name.endswith("__")
|
||||
|
||||
|
||||
|
||||
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
|
||||
"""Removes the docstring from an indented block if it exists."""
|
||||
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ from typing import TYPE_CHECKING, Optional, Union
|
|||
import libcst as cst
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
|
||||
from codeflash.languages import is_javascript
|
||||
from codeflash.languages import is_python
|
||||
from codeflash.languages.python.static_analysis.code_replacer import replace_function_definitions_in_module
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -587,17 +587,20 @@ def revert_unused_helper_functions(
|
|||
|
||||
logger.debug(f"Reverting {len(unused_helpers)} unused helper function(s) to original definitions")
|
||||
|
||||
# Resolve all path keys for consistent comparison (Windows 8.3 short names may differ from Jedi-resolved paths)
|
||||
resolved_original_helper_code = {p.resolve(): code for p, code in original_helper_code.items()}
|
||||
|
||||
# Group unused helpers by file path
|
||||
unused_helpers_by_file = defaultdict(list)
|
||||
for helper in unused_helpers:
|
||||
unused_helpers_by_file[helper.file_path].append(helper)
|
||||
unused_helpers_by_file[helper.file_path.resolve()].append(helper)
|
||||
|
||||
# For each file, revert the unused helper functions to their original definitions
|
||||
for file_path, helpers_in_file in unused_helpers_by_file.items():
|
||||
if file_path in original_helper_code:
|
||||
if file_path in resolved_original_helper_code:
|
||||
try:
|
||||
# Get original code for this file
|
||||
original_code = original_helper_code[file_path]
|
||||
original_code = resolved_original_helper_code[file_path]
|
||||
|
||||
# Use the code replacer to selectively revert only the unused helper functions
|
||||
helper_names = [helper.qualified_name for helper in helpers_in_file]
|
||||
|
|
@ -640,15 +643,31 @@ def _analyze_imports_in_optimized_code(
|
|||
helpers_by_file_and_func = defaultdict(dict)
|
||||
helpers_by_file = defaultdict(list) # preserved for "import module"
|
||||
for helper in code_context.helper_functions:
|
||||
jedi_type = helper.jedi_definition.type if helper.jedi_definition else None
|
||||
if jedi_type != "class": # Include when jedi_definition is None (non-Python)
|
||||
jedi_type = helper.definition_type
|
||||
if jedi_type != "class": # Include when definition_type is None (non-Python)
|
||||
func_name = helper.only_function_name
|
||||
module_name = helper.file_path.stem
|
||||
# Cache function lookup for this (module, func)
|
||||
helpers_by_file_and_func[module_name].setdefault(func_name, []).append(helper)
|
||||
helpers_by_file[module_name].append(helper)
|
||||
|
||||
for node in ast.walk(optimized_ast):
|
||||
# Collect only import nodes to avoid per-node isinstance checks across the whole AST
|
||||
class _ImportCollector(ast.NodeVisitor):
|
||||
def __init__(self) -> None:
|
||||
self.nodes: list[ast.AST] = []
|
||||
|
||||
def visit_Import(self, node: ast.Import) -> None:
|
||||
self.nodes.append(node)
|
||||
# No need to recurse further for import nodes
|
||||
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
|
||||
self.nodes.append(node)
|
||||
# No need to recurse further for import-from nodes
|
||||
|
||||
collector = _ImportCollector()
|
||||
collector.visit(optimized_ast)
|
||||
|
||||
for node in collector.nodes:
|
||||
if isinstance(node, ast.ImportFrom):
|
||||
# Handle "from module import function" statements
|
||||
module_name = node.module
|
||||
|
|
@ -728,8 +747,8 @@ def detect_unused_helper_functions(
|
|||
|
||||
"""
|
||||
# Skip this analysis for non-Python languages since we use Python's ast module
|
||||
if is_javascript():
|
||||
logger.debug("Skipping unused helper function detection for JavaScript/TypeScript")
|
||||
if not is_python():
|
||||
logger.debug("Skipping unused helper function detection for non-Python languages")
|
||||
return []
|
||||
|
||||
if isinstance(optimized_code, CodeStringsMarkdown) and len(optimized_code.code_strings) > 0:
|
||||
|
|
@ -799,8 +818,8 @@ def detect_unused_helper_functions(
|
|||
unused_helpers = []
|
||||
entrypoint_file_path = function_to_optimize.file_path
|
||||
for helper_function in code_context.helper_functions:
|
||||
jedi_type = helper_function.jedi_definition.type if helper_function.jedi_definition else None
|
||||
if jedi_type != "class": # Include when jedi_definition is None (non-Python)
|
||||
jedi_type = helper_function.definition_type
|
||||
if jedi_type != "class": # Include when definition_type is None (non-Python)
|
||||
# Check if the helper function is called using multiple name variants
|
||||
helper_qualified_name = helper_function.qualified_name
|
||||
helper_simple_name = helper_function.only_function_name
|
||||
|
|
|
|||
544
codeflash/languages/python/reference_graph.py
Normal file
544
codeflash/languages/python/reference_graph.py
Normal file
|
|
@ -0,0 +1,544 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import sqlite3
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
|
||||
from codeflash.languages.base import IndexResult
|
||||
from codeflash.models.models import FunctionSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Iterable
|
||||
|
||||
from jedi.api.classes import Name
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level helpers (must be top-level for ProcessPoolExecutor pickling)
|
||||
# ---------------------------------------------------------------------------
|
||||
# TODO: create call graph.
|
||||
|
||||
_PARALLEL_THRESHOLD = 8
|
||||
|
||||
# Per-worker state, initialised by _init_index_worker in child processes
|
||||
_worker_jedi_project: object | None = None
|
||||
_worker_project_root_str: str | None = None
|
||||
|
||||
|
||||
def _init_index_worker(project_root: str) -> None:
|
||||
import jedi
|
||||
|
||||
global _worker_jedi_project, _worker_project_root_str
|
||||
_worker_jedi_project = jedi.Project(path=project_root)
|
||||
_worker_project_root_str = project_root
|
||||
|
||||
|
||||
def _resolve_definitions(ref: Name) -> list[Name]:
|
||||
try:
|
||||
inferred = ref.infer()
|
||||
valid = [d for d in inferred if d.type in ("function", "class")]
|
||||
if valid:
|
||||
return valid
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
result: list[Name] = ref.goto(follow_imports=True, follow_builtin_imports=False)
|
||||
return result
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def _is_valid_definition(definition: Name, caller_qualified_name: str, project_root_str: str) -> bool:
|
||||
definition_path = definition.module_path
|
||||
if definition_path is None:
|
||||
return False
|
||||
|
||||
if not str(definition_path).startswith(project_root_str + os.sep):
|
||||
return False
|
||||
|
||||
if path_belongs_to_site_packages(definition_path):
|
||||
return False
|
||||
|
||||
if not definition.full_name or not definition.full_name.startswith(definition.module_name):
|
||||
return False
|
||||
|
||||
if definition.type not in ("function", "class"):
|
||||
return False
|
||||
|
||||
try:
|
||||
def_qn = get_qualified_name(definition.module_name, definition.full_name)
|
||||
if def_qn == caller_qualified_name:
|
||||
return False
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
try:
|
||||
from codeflash.optimization.function_context import belongs_to_function_qualified
|
||||
|
||||
if belongs_to_function_qualified(definition, caller_qualified_name):
|
||||
return False
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _get_enclosing_function_qn(ref: Name) -> str | None:
|
||||
try:
|
||||
parent = ref.parent()
|
||||
if parent is None or parent.type != "function":
|
||||
return None
|
||||
if not parent.full_name or not parent.full_name.startswith(parent.module_name):
|
||||
return None
|
||||
return get_qualified_name(parent.module_name, parent.full_name)
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
|
||||
|
||||
def _analyze_file(file_path: Path, jedi_project: object, project_root_str: str) -> tuple[set[tuple[str, ...]], bool]:
|
||||
"""Pure Jedi analysis — no DB access. Returns (edges, had_error)."""
|
||||
import jedi
|
||||
|
||||
resolved = str(file_path.resolve())
|
||||
|
||||
try:
|
||||
script = jedi.Script(path=file_path, project=jedi_project)
|
||||
refs = script.get_names(all_scopes=True, definitions=False, references=True)
|
||||
except Exception:
|
||||
return set(), True
|
||||
|
||||
edges: set[tuple[str, ...]] = set()
|
||||
|
||||
for ref in refs:
|
||||
try:
|
||||
caller_qn = _get_enclosing_function_qn(ref)
|
||||
if caller_qn is None:
|
||||
continue
|
||||
|
||||
definitions = _resolve_definitions(ref)
|
||||
if not definitions:
|
||||
continue
|
||||
|
||||
definition = definitions[0]
|
||||
definition_path = definition.module_path
|
||||
if definition_path is None:
|
||||
continue
|
||||
|
||||
if not _is_valid_definition(definition, caller_qn, project_root_str):
|
||||
continue
|
||||
|
||||
edge_base = (resolved, caller_qn, str(definition_path))
|
||||
|
||||
if definition.type == "function":
|
||||
callee_qn = get_qualified_name(definition.module_name, definition.full_name)
|
||||
if len(callee_qn.split(".")) > 2:
|
||||
continue
|
||||
edges.add(
|
||||
(
|
||||
*edge_base,
|
||||
callee_qn,
|
||||
definition.full_name,
|
||||
definition.name,
|
||||
definition.type,
|
||||
definition.get_line_code(),
|
||||
)
|
||||
)
|
||||
elif definition.type == "class":
|
||||
init_qn = get_qualified_name(definition.module_name, f"{definition.full_name}.__init__")
|
||||
if len(init_qn.split(".")) > 2:
|
||||
continue
|
||||
edges.add(
|
||||
(
|
||||
*edge_base,
|
||||
init_qn,
|
||||
f"{definition.full_name}.__init__",
|
||||
"__init__",
|
||||
definition.type,
|
||||
definition.get_line_code(),
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return edges, False
|
||||
|
||||
|
||||
def _index_file_worker(args: tuple[str, str]) -> tuple[str, str, set[tuple[str, ...]], bool]:
|
||||
"""Worker entry point for ProcessPoolExecutor."""
|
||||
file_path_str, file_hash = args
|
||||
assert _worker_project_root_str is not None
|
||||
edges, had_error = _analyze_file(Path(file_path_str), _worker_jedi_project, _worker_project_root_str)
|
||||
return file_path_str, file_hash, edges, had_error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ReferenceGraph:
|
||||
SCHEMA_VERSION = 2
|
||||
|
||||
def __init__(self, project_root: Path, language: str = "python", db_path: Path | None = None) -> None:
|
||||
import jedi
|
||||
|
||||
self.project_root = project_root.resolve()
|
||||
self.project_root_str = str(self.project_root)
|
||||
self.language = language
|
||||
self.jedi_project = jedi.Project(path=self.project_root)
|
||||
|
||||
if db_path is None:
|
||||
from codeflash.code_utils.compat import codeflash_cache_db
|
||||
|
||||
db_path = codeflash_cache_db
|
||||
|
||||
self.conn = sqlite3.connect(str(db_path))
|
||||
self.conn.execute("PRAGMA journal_mode=WAL")
|
||||
self.indexed_file_hashes: dict[str, str] = {}
|
||||
self._init_schema()
|
||||
|
||||
def _init_schema(self) -> None:
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("CREATE TABLE IF NOT EXISTS cg_schema_version (version INTEGER PRIMARY KEY)")
|
||||
|
||||
row = cur.execute("SELECT version FROM cg_schema_version LIMIT 1").fetchone()
|
||||
if row is None:
|
||||
cur.execute("INSERT INTO cg_schema_version (version) VALUES (?)", (self.SCHEMA_VERSION,))
|
||||
elif row[0] != self.SCHEMA_VERSION:
|
||||
for table in [
|
||||
"cg_call_edges",
|
||||
"cg_indexed_files",
|
||||
"cg_languages",
|
||||
"cg_projects",
|
||||
"cg_project_meta",
|
||||
"indexed_files",
|
||||
"call_edges",
|
||||
]:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {table}")
|
||||
cur.execute("DELETE FROM cg_schema_version")
|
||||
cur.execute("INSERT INTO cg_schema_version (version) VALUES (?)", (self.SCHEMA_VERSION,))
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS indexed_files (
|
||||
project_root TEXT NOT NULL,
|
||||
language TEXT NOT NULL,
|
||||
file_path TEXT NOT NULL,
|
||||
file_hash TEXT NOT NULL,
|
||||
PRIMARY KEY (project_root, language, file_path)
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS call_edges (
|
||||
project_root TEXT NOT NULL,
|
||||
language TEXT NOT NULL,
|
||||
caller_file TEXT NOT NULL,
|
||||
caller_qualified_name TEXT NOT NULL,
|
||||
callee_file TEXT NOT NULL,
|
||||
callee_qualified_name TEXT NOT NULL,
|
||||
callee_fully_qualified_name TEXT NOT NULL,
|
||||
callee_only_function_name TEXT NOT NULL,
|
||||
callee_definition_type TEXT NOT NULL,
|
||||
callee_source_line TEXT NOT NULL,
|
||||
PRIMARY KEY (project_root, language, caller_file, caller_qualified_name,
|
||||
callee_file, callee_qualified_name)
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_call_edges_caller
|
||||
ON call_edges (project_root, language, caller_file, caller_qualified_name)
|
||||
"""
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def get_callees(
|
||||
self, file_path_to_qualified_names: dict[Path, set[str]]
|
||||
) -> tuple[dict[Path, set[FunctionSource]], list[FunctionSource]]:
|
||||
file_path_to_function_source: dict[Path, set[FunctionSource]] = defaultdict(set)
|
||||
function_source_list: list[FunctionSource] = []
|
||||
|
||||
all_caller_keys: list[tuple[str, str]] = []
|
||||
for file_path, qualified_names in file_path_to_qualified_names.items():
|
||||
resolved = str(file_path.resolve())
|
||||
self.ensure_file_indexed(file_path, resolved)
|
||||
all_caller_keys.extend((resolved, qn) for qn in qualified_names)
|
||||
|
||||
if not all_caller_keys:
|
||||
return file_path_to_function_source, function_source_list
|
||||
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("CREATE TEMP TABLE IF NOT EXISTS _caller_keys (caller_file TEXT, caller_qualified_name TEXT)")
|
||||
cur.execute("DELETE FROM _caller_keys")
|
||||
cur.executemany("INSERT INTO _caller_keys VALUES (?, ?)", all_caller_keys)
|
||||
|
||||
rows = cur.execute(
|
||||
"""
|
||||
SELECT ce.callee_file, ce.callee_qualified_name, ce.callee_fully_qualified_name,
|
||||
ce.callee_only_function_name, ce.callee_definition_type, ce.callee_source_line
|
||||
FROM call_edges ce
|
||||
INNER JOIN _caller_keys ck
|
||||
ON ce.caller_file = ck.caller_file AND ce.caller_qualified_name = ck.caller_qualified_name
|
||||
WHERE ce.project_root = ? AND ce.language = ?
|
||||
""",
|
||||
(self.project_root_str, self.language),
|
||||
).fetchall()
|
||||
|
||||
for callee_file, callee_qn, callee_fqn, callee_name, callee_type, callee_src in rows:
|
||||
callee_path = Path(callee_file)
|
||||
fs = FunctionSource(
|
||||
file_path=callee_path,
|
||||
qualified_name=callee_qn,
|
||||
fully_qualified_name=callee_fqn,
|
||||
only_function_name=callee_name,
|
||||
source_code=callee_src,
|
||||
definition_type=callee_type,
|
||||
)
|
||||
file_path_to_function_source[callee_path].add(fs)
|
||||
function_source_list.append(fs)
|
||||
|
||||
return file_path_to_function_source, function_source_list
|
||||
|
||||
def count_callees_per_function(
|
||||
self, file_path_to_qualified_names: dict[Path, set[str]]
|
||||
) -> dict[tuple[Path, str], int]:
|
||||
all_caller_keys: list[tuple[Path, str, str]] = []
|
||||
for file_path, qualified_names in file_path_to_qualified_names.items():
|
||||
resolved = str(file_path.resolve())
|
||||
self.ensure_file_indexed(file_path, resolved)
|
||||
all_caller_keys.extend((file_path, resolved, qn) for qn in qualified_names)
|
||||
|
||||
if not all_caller_keys:
|
||||
return {}
|
||||
|
||||
cur = self.conn.cursor()
|
||||
cur.execute("CREATE TEMP TABLE IF NOT EXISTS _count_keys (caller_file TEXT, caller_qualified_name TEXT)")
|
||||
cur.execute("DELETE FROM _count_keys")
|
||||
cur.executemany(
|
||||
"INSERT INTO _count_keys VALUES (?, ?)", [(resolved, qn) for _, resolved, qn in all_caller_keys]
|
||||
)
|
||||
|
||||
rows = cur.execute(
|
||||
"""
|
||||
SELECT ck.caller_file, ck.caller_qualified_name, COUNT(ce.rowid)
|
||||
FROM _count_keys ck
|
||||
LEFT JOIN call_edges ce
|
||||
ON ce.caller_file = ck.caller_file AND ce.caller_qualified_name = ck.caller_qualified_name
|
||||
AND ce.project_root = ? AND ce.language = ?
|
||||
GROUP BY ck.caller_file, ck.caller_qualified_name
|
||||
""",
|
||||
(self.project_root_str, self.language),
|
||||
).fetchall()
|
||||
|
||||
resolved_to_path: dict[str, Path] = {resolved: fp for fp, resolved, _ in all_caller_keys}
|
||||
counts: dict[tuple[Path, str], int] = {}
|
||||
for caller_file, caller_qn, cnt in rows:
|
||||
counts[(resolved_to_path[caller_file], caller_qn)] = cnt
|
||||
|
||||
return counts
|
||||
|
||||
def ensure_file_indexed(self, file_path: Path, resolved: str | None = None) -> IndexResult:
|
||||
if resolved is None:
|
||||
resolved = str(file_path.resolve())
|
||||
|
||||
# Always read and hash the file before checking the cache so we detect on-disk changes
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
return IndexResult(file_path=file_path, cached=False, num_edges=0, edges=(), cross_file_edges=0, error=True)
|
||||
|
||||
file_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()
|
||||
|
||||
if self._is_file_cached(resolved, file_hash):
|
||||
return IndexResult(file_path=file_path, cached=True, num_edges=0, edges=(), cross_file_edges=0, error=False)
|
||||
|
||||
return self.index_file(file_path, file_hash, resolved)
|
||||
|
||||
def index_file(self, file_path: Path, file_hash: str, resolved: str | None = None) -> IndexResult:
|
||||
if resolved is None:
|
||||
resolved = str(file_path.resolve())
|
||||
edges, had_error = _analyze_file(file_path, self.jedi_project, self.project_root_str)
|
||||
if had_error:
|
||||
logger.debug(f"ReferenceGraph: failed to parse {file_path}")
|
||||
return self._persist_edges(file_path, resolved, file_hash, edges, had_error)
|
||||
|
||||
def _persist_edges(
|
||||
self, file_path: Path, resolved: str, file_hash: str, edges: set[tuple[str, ...]], had_error: bool
|
||||
) -> IndexResult:
|
||||
cur = self.conn.cursor()
|
||||
scope = (self.project_root_str, self.language)
|
||||
|
||||
# Clear existing data for this file
|
||||
cur.execute(
|
||||
"DELETE FROM call_edges WHERE project_root = ? AND language = ? AND caller_file = ?", (*scope, resolved)
|
||||
)
|
||||
cur.execute(
|
||||
"DELETE FROM indexed_files WHERE project_root = ? AND language = ? AND file_path = ?", (*scope, resolved)
|
||||
)
|
||||
|
||||
# Insert new edges if parsing succeeded
|
||||
if not had_error and edges:
|
||||
cur.executemany(
|
||||
"""
|
||||
INSERT OR REPLACE INTO call_edges
|
||||
(project_root, language, caller_file, caller_qualified_name,
|
||||
callee_file, callee_qualified_name, callee_fully_qualified_name,
|
||||
callee_only_function_name, callee_definition_type, callee_source_line)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
[(*scope, *edge) for edge in edges],
|
||||
)
|
||||
|
||||
# Record that this file has been indexed
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO indexed_files (project_root, language, file_path, file_hash) VALUES (?, ?, ?, ?)",
|
||||
(*scope, resolved, file_hash),
|
||||
)
|
||||
|
||||
self.conn.commit()
|
||||
self.indexed_file_hashes[resolved] = file_hash
|
||||
|
||||
# Build summary for return value
|
||||
edges_summary = tuple(
|
||||
(caller_qn, callee_name, caller_file != callee_file)
|
||||
for (caller_file, caller_qn, callee_file, _, _, callee_name, _, _) in edges
|
||||
)
|
||||
cross_file_count = sum(is_cross_file for _, _, is_cross_file in edges_summary)
|
||||
|
||||
return IndexResult(
|
||||
file_path=file_path,
|
||||
cached=False,
|
||||
num_edges=len(edges),
|
||||
edges=edges_summary,
|
||||
cross_file_edges=cross_file_count,
|
||||
error=had_error,
|
||||
)
|
||||
|
||||
def build_index(self, file_paths: Iterable[Path], on_progress: Callable[[IndexResult], None] | None = None) -> None:
|
||||
"""Pre-index a batch of files, using multiprocessing for large uncached batches."""
|
||||
to_index: list[tuple[Path, str, str]] = []
|
||||
|
||||
for file_path in file_paths:
|
||||
resolved = str(file_path.resolve())
|
||||
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
self._report_progress(
|
||||
on_progress,
|
||||
IndexResult(
|
||||
file_path=file_path, cached=False, num_edges=0, edges=(), cross_file_edges=0, error=True
|
||||
),
|
||||
)
|
||||
continue
|
||||
|
||||
file_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()
|
||||
|
||||
# Check if already cached (in-memory or DB)
|
||||
if self._is_file_cached(resolved, file_hash):
|
||||
self._report_progress(
|
||||
on_progress,
|
||||
IndexResult(
|
||||
file_path=file_path, cached=True, num_edges=0, edges=(), cross_file_edges=0, error=False
|
||||
),
|
||||
)
|
||||
continue
|
||||
|
||||
to_index.append((file_path, resolved, file_hash))
|
||||
|
||||
if not to_index:
|
||||
return
|
||||
|
||||
# Index uncached files
|
||||
if len(to_index) >= _PARALLEL_THRESHOLD:
|
||||
self._build_index_parallel(to_index, on_progress)
|
||||
else:
|
||||
for file_path, resolved, file_hash in to_index:
|
||||
result = self.index_file(file_path, file_hash, resolved)
|
||||
self._report_progress(on_progress, result)
|
||||
|
||||
def _is_file_cached(self, resolved: str, file_hash: str) -> bool:
|
||||
"""Check if file is cached in memory or DB."""
|
||||
if self.indexed_file_hashes.get(resolved) == file_hash:
|
||||
return True
|
||||
|
||||
row = self.conn.execute(
|
||||
"SELECT file_hash FROM indexed_files WHERE project_root = ? AND language = ? AND file_path = ?",
|
||||
(self.project_root_str, self.language, resolved),
|
||||
).fetchone()
|
||||
|
||||
if row and row[0] == file_hash:
|
||||
self.indexed_file_hashes[resolved] = file_hash
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _report_progress(self, on_progress: Callable[[IndexResult], None] | None, result: IndexResult) -> None:
|
||||
"""Report progress if callback provided."""
|
||||
if on_progress is not None:
|
||||
on_progress(result)
|
||||
|
||||
def _build_index_parallel(
|
||||
self, to_index: list[tuple[Path, str, str]], on_progress: Callable[[IndexResult], None] | None
|
||||
) -> None:
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
|
||||
max_workers = min(os.cpu_count() or 1, len(to_index), 8)
|
||||
path_info: dict[str, tuple[Path, str]] = {resolved: (fp, fh) for fp, resolved, fh in to_index}
|
||||
worker_args = [(resolved, fh) for _fp, resolved, fh in to_index]
|
||||
|
||||
logger.debug(f"ReferenceGraph: indexing {len(to_index)} files across {max_workers} workers")
|
||||
|
||||
try:
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=max_workers, initializer=_init_index_worker, initargs=(self.project_root_str,)
|
||||
) as executor:
|
||||
futures = {executor.submit(_index_file_worker, args): args[0] for args in worker_args}
|
||||
|
||||
for future in as_completed(futures):
|
||||
resolved = futures[future]
|
||||
file_path, file_hash = path_info[resolved]
|
||||
|
||||
try:
|
||||
_, _, edges, had_error = future.result()
|
||||
except Exception:
|
||||
logger.debug(f"ReferenceGraph: worker failed for {file_path}")
|
||||
self._persist_edges(file_path, resolved, file_hash, set(), had_error=True)
|
||||
self._report_progress(
|
||||
on_progress,
|
||||
IndexResult(
|
||||
file_path=file_path, cached=False, num_edges=0, edges=(), cross_file_edges=0, error=True
|
||||
),
|
||||
)
|
||||
continue
|
||||
|
||||
if had_error:
|
||||
logger.debug(f"ReferenceGraph: failed to parse {file_path}")
|
||||
|
||||
result = self._persist_edges(file_path, resolved, file_hash, edges, had_error)
|
||||
self._report_progress(on_progress, result)
|
||||
|
||||
except Exception:
|
||||
logger.debug("ReferenceGraph: parallel indexing failed, falling back to sequential")
|
||||
self._fallback_sequential_index(to_index, on_progress)
|
||||
|
||||
def _fallback_sequential_index(
|
||||
self, to_index: list[tuple[Path, str, str]], on_progress: Callable[[IndexResult], None] | None
|
||||
) -> None:
|
||||
"""Fallback to sequential indexing when parallel processing fails."""
|
||||
for file_path, resolved, file_hash in to_index:
|
||||
# Skip files already persisted before the failure
|
||||
if resolved in self.indexed_file_hashes:
|
||||
continue
|
||||
result = self.index_file(file_path, file_hash, resolved)
|
||||
self._report_progress(on_progress, result)
|
||||
|
||||
def close(self) -> None:
|
||||
self.conn.close()
|
||||
0
codeflash/languages/python/static_analysis/__init__.py
Normal file
0
codeflash/languages/python/static_analysis/__init__.py
Normal file
|
|
@ -1659,6 +1659,13 @@ def _format_references_as_markdown(references: list, file_path: Path, project_ro
|
|||
refs_by_file[ref.file_path] = []
|
||||
refs_by_file[ref.file_path].append(ref)
|
||||
|
||||
from codeflash.languages.registry import get_language_support
|
||||
|
||||
try:
|
||||
lang_support = get_language_support(language)
|
||||
except Exception:
|
||||
lang_support = None
|
||||
|
||||
fn_call_context = ""
|
||||
context_len = 0
|
||||
|
||||
|
|
@ -1700,7 +1707,11 @@ def _format_references_as_markdown(references: list, file_path: Path, project_ro
|
|||
# Extract context around the reference
|
||||
if ref.caller_function:
|
||||
# Try to extract the full calling function
|
||||
func_code = _extract_calling_function(file_content, ref.caller_function, ref.line, language)
|
||||
func_code = None
|
||||
if lang_support is not None:
|
||||
func_code = lang_support.extract_calling_function_source(
|
||||
file_content, ref.caller_function, ref.line
|
||||
)
|
||||
if func_code:
|
||||
caller_contexts.append(func_code)
|
||||
context_len += len(func_code)
|
||||
|
|
@ -1718,77 +1729,3 @@ def _format_references_as_markdown(references: list, file_path: Path, project_ro
|
|||
fn_call_context += "\n```\n"
|
||||
|
||||
return fn_call_context
|
||||
|
||||
|
||||
def _extract_calling_function(source_code: str, function_name: str, ref_line: int, language: Language) -> str | None:
|
||||
"""Extract the source code of a calling function.
|
||||
|
||||
Args:
|
||||
source_code: Full source code of the file.
|
||||
function_name: Name of the function to extract.
|
||||
ref_line: Line number where the reference is.
|
||||
language: The programming language.
|
||||
|
||||
Returns:
|
||||
Source code of the function, or None if not found.
|
||||
|
||||
"""
|
||||
if language == Language.PYTHON:
|
||||
return _extract_calling_function_python(source_code, function_name, ref_line)
|
||||
return _extract_calling_function_js(source_code, function_name, ref_line)
|
||||
|
||||
|
||||
def _extract_calling_function_python(source_code: str, function_name: str, ref_line: int) -> str | None:
|
||||
"""Extract the source code of a calling function in Python."""
|
||||
try:
|
||||
import ast
|
||||
|
||||
tree = ast.parse(source_code)
|
||||
lines = source_code.splitlines()
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
if node.name == function_name:
|
||||
# Check if the reference line is within this function
|
||||
start_line = node.lineno
|
||||
end_line = node.end_lineno or start_line
|
||||
if start_line <= ref_line <= end_line:
|
||||
return "\n".join(lines[start_line - 1 : end_line])
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _extract_calling_function_js(source_code: str, function_name: str, ref_line: int) -> str | None:
|
||||
"""Extract the source code of a calling function in JavaScript/TypeScript.
|
||||
|
||||
Args:
|
||||
source_code: Full source code of the file.
|
||||
function_name: Name of the function to extract.
|
||||
ref_line: Line number where the reference is (helps identify the right function).
|
||||
|
||||
Returns:
|
||||
Source code of the function, or None if not found.
|
||||
|
||||
"""
|
||||
try:
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
|
||||
# Try TypeScript first, fall back to JavaScript
|
||||
for lang in [TreeSitterLanguage.TYPESCRIPT, TreeSitterLanguage.TSX, TreeSitterLanguage.JAVASCRIPT]:
|
||||
try:
|
||||
analyzer = TreeSitterAnalyzer(lang)
|
||||
functions = analyzer.find_functions(source_code, include_methods=True)
|
||||
|
||||
for func in functions:
|
||||
if func.name == function_name:
|
||||
# Check if the reference line is within this function
|
||||
if func.start_line <= ref_line <= func.end_line:
|
||||
return func.source_text
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
|
@ -10,23 +10,22 @@ import libcst as cst
|
|||
from libcst.metadata import PositionProvider
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_extractor import (
|
||||
from codeflash.code_utils.config_parser import find_conftest_files
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
from codeflash.languages import is_python
|
||||
from codeflash.languages.python.static_analysis.code_extractor import (
|
||||
add_global_assignments,
|
||||
add_needed_imports_from_module,
|
||||
find_insertion_index_after_imports,
|
||||
)
|
||||
from codeflash.code_utils.config_parser import find_conftest_files
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
from codeflash.code_utils.line_profile_utils import ImportAdder
|
||||
from codeflash.languages import is_python
|
||||
from codeflash.languages.python.static_analysis.line_profile_utils import ImportAdder
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import Language, LanguageSupport
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer
|
||||
from codeflash.languages.base import LanguageSupport
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, OptimizedCandidate, ValidCode
|
||||
|
||||
ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST)
|
||||
|
|
@ -240,149 +239,6 @@ def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None:
|
|||
test_path.write_text(modified_module.code, encoding="utf-8")
|
||||
|
||||
|
||||
class OptimFunctionCollector(cst.CSTVisitor):
|
||||
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] | None = None,
|
||||
function_names: set[tuple[str | None, str]] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.preexisting_objects = preexisting_objects if preexisting_objects is not None else set()
|
||||
|
||||
self.function_names = function_names # set of (class_name, function_name)
|
||||
self.modified_functions: dict[
|
||||
tuple[str | None, str], cst.FunctionDef
|
||||
] = {} # keys are (class_name, function_name)
|
||||
self.new_functions: list[cst.FunctionDef] = []
|
||||
self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list)
|
||||
self.new_classes: list[cst.ClassDef] = []
|
||||
self.current_class = None
|
||||
self.modified_init_functions: dict[str, cst.FunctionDef] = {}
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
|
||||
if (self.current_class, node.name.value) in self.function_names:
|
||||
self.modified_functions[(self.current_class, node.name.value)] = node
|
||||
elif self.current_class and node.name.value == "__init__":
|
||||
self.modified_init_functions[self.current_class] = node
|
||||
elif (
|
||||
self.preexisting_objects
|
||||
and (node.name.value, ()) not in self.preexisting_objects
|
||||
and self.current_class is None
|
||||
):
|
||||
self.new_functions.append(node)
|
||||
return False
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
|
||||
if self.current_class:
|
||||
return False # If already in a class, do not recurse deeper
|
||||
self.current_class = node.name.value
|
||||
|
||||
parents = (FunctionParent(name=node.name.value, type="ClassDef"),)
|
||||
|
||||
if (node.name.value, ()) not in self.preexisting_objects:
|
||||
self.new_classes.append(node)
|
||||
|
||||
for child_node in node.body.body:
|
||||
if (
|
||||
self.preexisting_objects
|
||||
and isinstance(child_node, cst.FunctionDef)
|
||||
and (child_node.name.value, parents) not in self.preexisting_objects
|
||||
):
|
||||
self.new_class_functions[node.name.value].append(child_node)
|
||||
|
||||
return True
|
||||
|
||||
def leave_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
if self.current_class:
|
||||
self.current_class = None
|
||||
|
||||
|
||||
class OptimFunctionReplacer(cst.CSTTransformer):
|
||||
def __init__(
|
||||
self,
|
||||
modified_functions: Optional[dict[tuple[str | None, str], cst.FunctionDef]] = None,
|
||||
new_classes: Optional[list[cst.ClassDef]] = None,
|
||||
new_functions: Optional[list[cst.FunctionDef]] = None,
|
||||
new_class_functions: Optional[dict[str, list[cst.FunctionDef]]] = None,
|
||||
modified_init_functions: Optional[dict[str, cst.FunctionDef]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.modified_functions = modified_functions if modified_functions is not None else {}
|
||||
self.new_functions = new_functions if new_functions is not None else []
|
||||
self.new_classes = new_classes if new_classes is not None else []
|
||||
self.new_class_functions = new_class_functions if new_class_functions is not None else defaultdict(list)
|
||||
self.modified_init_functions: dict[str, cst.FunctionDef] = (
|
||||
modified_init_functions if modified_init_functions is not None else {}
|
||||
)
|
||||
self.current_class = None
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
|
||||
return False
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
|
||||
if (self.current_class, original_node.name.value) in self.modified_functions:
|
||||
node = self.modified_functions[(self.current_class, original_node.name.value)]
|
||||
return updated_node.with_changes(body=node.body, decorators=node.decorators)
|
||||
if original_node.name.value == "__init__" and self.current_class in self.modified_init_functions:
|
||||
return self.modified_init_functions[self.current_class]
|
||||
|
||||
return updated_node
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
|
||||
if self.current_class:
|
||||
return False # If already in a class, do not recurse deeper
|
||||
self.current_class = node.name.value
|
||||
return True
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
|
||||
if self.current_class and self.current_class == original_node.name.value:
|
||||
self.current_class = None
|
||||
if original_node.name.value in self.new_class_functions:
|
||||
return updated_node.with_changes(
|
||||
body=updated_node.body.with_changes(
|
||||
body=(list(updated_node.body.body) + list(self.new_class_functions[original_node.name.value]))
|
||||
)
|
||||
)
|
||||
return updated_node
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
node = updated_node
|
||||
max_function_index = None
|
||||
max_class_index = None
|
||||
for index, _node in enumerate(node.body):
|
||||
if isinstance(_node, cst.FunctionDef):
|
||||
max_function_index = index
|
||||
if isinstance(_node, cst.ClassDef):
|
||||
max_class_index = index
|
||||
|
||||
if self.new_classes:
|
||||
existing_class_names = {_node.name.value for _node in node.body if isinstance(_node, cst.ClassDef)}
|
||||
|
||||
unique_classes = [
|
||||
new_class for new_class in self.new_classes if new_class.name.value not in existing_class_names
|
||||
]
|
||||
if unique_classes:
|
||||
new_classes_insertion_idx = max_class_index or find_insertion_index_after_imports(node)
|
||||
new_body = list(
|
||||
chain(node.body[:new_classes_insertion_idx], unique_classes, node.body[new_classes_insertion_idx:])
|
||||
)
|
||||
node = node.with_changes(body=new_body)
|
||||
|
||||
if max_function_index is not None:
|
||||
node = node.with_changes(
|
||||
body=(*node.body[: max_function_index + 1], *self.new_functions, *node.body[max_function_index + 1 :])
|
||||
)
|
||||
elif max_class_index is not None:
|
||||
node = node.with_changes(
|
||||
body=(*node.body[: max_class_index + 1], *self.new_functions, *node.body[max_class_index + 1 :])
|
||||
)
|
||||
else:
|
||||
node = node.with_changes(body=(*self.new_functions, *node.body))
|
||||
return node
|
||||
|
||||
|
||||
def replace_functions_in_file(
|
||||
source_code: str,
|
||||
original_function_names: list[str],
|
||||
|
|
@ -401,23 +257,114 @@ def replace_functions_in_file(
|
|||
return source_code
|
||||
parsed_function_names.append((class_name, function_name))
|
||||
|
||||
# Collect functions we want to modify from the optimized code
|
||||
optimized_module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
|
||||
# Collect functions from optimized code without using MetadataWrapper
|
||||
optimized_module = cst.parse_module(optimized_code)
|
||||
modified_functions: dict[tuple[str | None, str], cst.FunctionDef] = {}
|
||||
new_functions: list[cst.FunctionDef] = []
|
||||
new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list)
|
||||
new_classes: list[cst.ClassDef] = []
|
||||
modified_init_functions: dict[str, cst.FunctionDef] = {}
|
||||
|
||||
function_names_set = set(parsed_function_names)
|
||||
|
||||
for node in optimized_module.body:
|
||||
if isinstance(node, cst.FunctionDef):
|
||||
key = (None, node.name.value)
|
||||
if key in function_names_set:
|
||||
modified_functions[key] = node
|
||||
elif preexisting_objects and (node.name.value, ()) not in preexisting_objects:
|
||||
new_functions.append(node)
|
||||
|
||||
elif isinstance(node, cst.ClassDef):
|
||||
class_name = node.name.value
|
||||
parents = (FunctionParent(name=class_name, type="ClassDef"),)
|
||||
|
||||
if (class_name, ()) not in preexisting_objects:
|
||||
new_classes.append(node)
|
||||
|
||||
for child in node.body.body:
|
||||
if isinstance(child, cst.FunctionDef):
|
||||
method_key = (class_name, child.name.value)
|
||||
if method_key in function_names_set:
|
||||
modified_functions[method_key] = child
|
||||
elif (
|
||||
child.name.value == "__init__"
|
||||
and preexisting_objects
|
||||
and (class_name, ()) in preexisting_objects
|
||||
):
|
||||
modified_init_functions[class_name] = child
|
||||
elif preexisting_objects and (child.name.value, parents) not in preexisting_objects:
|
||||
new_class_functions[class_name].append(child)
|
||||
|
||||
original_module = cst.parse_module(source_code)
|
||||
|
||||
visitor = OptimFunctionCollector(preexisting_objects, set(parsed_function_names))
|
||||
optimized_module.visit(visitor)
|
||||
max_function_index = None
|
||||
max_class_index = None
|
||||
for index, _node in enumerate(original_module.body):
|
||||
if isinstance(_node, cst.FunctionDef):
|
||||
max_function_index = index
|
||||
if isinstance(_node, cst.ClassDef):
|
||||
max_class_index = index
|
||||
|
||||
# Replace these functions in the original code
|
||||
transformer = OptimFunctionReplacer(
|
||||
modified_functions=visitor.modified_functions,
|
||||
new_classes=visitor.new_classes,
|
||||
new_functions=visitor.new_functions,
|
||||
new_class_functions=visitor.new_class_functions,
|
||||
modified_init_functions=visitor.modified_init_functions,
|
||||
)
|
||||
modified_tree = original_module.visit(transformer)
|
||||
return modified_tree.code
|
||||
new_body: list[cst.CSTNode] = []
|
||||
existing_class_names = set()
|
||||
|
||||
for node in original_module.body:
|
||||
if isinstance(node, cst.FunctionDef):
|
||||
key = (None, node.name.value)
|
||||
if key in modified_functions:
|
||||
modified_func = modified_functions[key]
|
||||
new_body.append(node.with_changes(body=modified_func.body, decorators=modified_func.decorators))
|
||||
else:
|
||||
new_body.append(node)
|
||||
|
||||
elif isinstance(node, cst.ClassDef):
|
||||
class_name = node.name.value
|
||||
existing_class_names.add(class_name)
|
||||
|
||||
new_members: list[cst.CSTNode] = []
|
||||
for child in node.body.body:
|
||||
if isinstance(child, cst.FunctionDef):
|
||||
key = (class_name, child.name.value)
|
||||
if key in modified_functions:
|
||||
modified_func = modified_functions[key]
|
||||
new_members.append(
|
||||
child.with_changes(body=modified_func.body, decorators=modified_func.decorators)
|
||||
)
|
||||
elif child.name.value == "__init__" and class_name in modified_init_functions:
|
||||
new_members.append(modified_init_functions[class_name])
|
||||
else:
|
||||
new_members.append(child)
|
||||
else:
|
||||
new_members.append(child)
|
||||
|
||||
if class_name in new_class_functions:
|
||||
new_members.extend(new_class_functions[class_name])
|
||||
|
||||
new_body.append(node.with_changes(body=node.body.with_changes(body=new_members)))
|
||||
else:
|
||||
new_body.append(node)
|
||||
|
||||
if new_classes:
|
||||
unique_classes = [nc for nc in new_classes if nc.name.value not in existing_class_names]
|
||||
if unique_classes:
|
||||
new_classes_insertion_idx = (
|
||||
max_class_index if max_class_index is not None else find_insertion_index_after_imports(original_module)
|
||||
)
|
||||
new_body = list(
|
||||
chain(new_body[:new_classes_insertion_idx], unique_classes, new_body[new_classes_insertion_idx:])
|
||||
)
|
||||
|
||||
if new_functions:
|
||||
if max_function_index is not None:
|
||||
new_body = [*new_body[: max_function_index + 1], *new_functions, *new_body[max_function_index + 1 :]]
|
||||
elif max_class_index is not None:
|
||||
new_body = [*new_body[: max_class_index + 1], *new_functions, *new_body[max_class_index + 1 :]]
|
||||
else:
|
||||
new_body = [*new_functions, *new_body]
|
||||
|
||||
updated_module = original_module.with_changes(body=new_body)
|
||||
return updated_module.code
|
||||
|
||||
|
||||
def replace_functions_and_add_imports(
|
||||
|
|
@ -509,11 +456,8 @@ def replace_function_definitions_for_language(
|
|||
lang_support = get_language_support(language)
|
||||
|
||||
# Add any new global declarations from the optimized code to the original source
|
||||
original_source_code = _add_global_declarations_for_language(
|
||||
optimized_code=code_to_apply,
|
||||
original_source=original_source_code,
|
||||
module_abspath=module_abspath,
|
||||
language=language,
|
||||
original_source_code = lang_support.add_global_declarations(
|
||||
optimized_code=code_to_apply, original_source=original_source_code, module_abspath=module_abspath
|
||||
)
|
||||
|
||||
# If we have function_to_optimize with line info and this is the main file, use it for precise replacement
|
||||
|
|
@ -612,204 +556,6 @@ def _extract_function_from_code(
|
|||
return None
|
||||
|
||||
|
||||
def _add_global_declarations_for_language(
|
||||
optimized_code: str, original_source: str, module_abspath: Path, language: Language
|
||||
) -> str:
|
||||
"""Add new global declarations from optimized code to original source.
|
||||
|
||||
Finds module-level declarations (const, let, var, class, type, interface, enum)
|
||||
in the optimized code that don't exist in the original source and adds them.
|
||||
|
||||
New declarations are inserted after any existing declarations they depend on.
|
||||
For example, if optimized code has `const _has = FOO.bar.bind(FOO)`, and `FOO`
|
||||
is already declared in the original source, `_has` will be inserted after `FOO`.
|
||||
|
||||
Args:
|
||||
optimized_code: The optimized code that may contain new declarations.
|
||||
original_source: The original source code.
|
||||
module_abspath: Path to the module file (for parser selection).
|
||||
language: The language of the code.
|
||||
|
||||
Returns:
|
||||
Original source with new declarations added in dependency order.
|
||||
|
||||
"""
|
||||
from codeflash.languages.base import Language
|
||||
|
||||
if language not in (Language.JAVASCRIPT, Language.TYPESCRIPT):
|
||||
return original_source
|
||||
|
||||
try:
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(module_abspath)
|
||||
|
||||
original_declarations = analyzer.find_module_level_declarations(original_source)
|
||||
optimized_declarations = analyzer.find_module_level_declarations(optimized_code)
|
||||
|
||||
if not optimized_declarations:
|
||||
return original_source
|
||||
|
||||
existing_names = _get_existing_names(original_declarations, analyzer, original_source)
|
||||
new_declarations = _filter_new_declarations(optimized_declarations, existing_names)
|
||||
|
||||
if not new_declarations:
|
||||
return original_source
|
||||
|
||||
# Build a map of existing declaration names to their end lines (1-indexed)
|
||||
existing_decl_end_lines = {decl.name: decl.end_line for decl in original_declarations}
|
||||
|
||||
# Insert each new declaration after its dependencies
|
||||
result = original_source
|
||||
for decl in new_declarations:
|
||||
result = _insert_declaration_after_dependencies(
|
||||
result, decl, existing_decl_end_lines, analyzer, module_abspath
|
||||
)
|
||||
# Update the map with the newly inserted declaration for subsequent insertions
|
||||
# Re-parse to get accurate line numbers after insertion
|
||||
updated_declarations = analyzer.find_module_level_declarations(result)
|
||||
existing_decl_end_lines = {d.name: d.end_line for d in updated_declarations}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error adding global declarations: {e}")
|
||||
return original_source
|
||||
|
||||
|
||||
def _get_existing_names(original_declarations: list, analyzer: TreeSitterAnalyzer, original_source: str) -> set[str]:
|
||||
"""Get all names that already exist in the original source (declarations + imports)."""
|
||||
existing_names = {decl.name for decl in original_declarations}
|
||||
|
||||
original_imports = analyzer.find_imports(original_source)
|
||||
for imp in original_imports:
|
||||
if imp.default_import:
|
||||
existing_names.add(imp.default_import)
|
||||
for name, alias in imp.named_imports:
|
||||
existing_names.add(alias if alias else name)
|
||||
if imp.namespace_import:
|
||||
existing_names.add(imp.namespace_import)
|
||||
|
||||
return existing_names
|
||||
|
||||
|
||||
def _filter_new_declarations(optimized_declarations: list, existing_names: set[str]) -> list:
|
||||
"""Filter declarations to only those that don't exist in the original source."""
|
||||
new_declarations = []
|
||||
seen_sources: set[str] = set()
|
||||
|
||||
# Sort by line number to maintain order from optimized code
|
||||
sorted_declarations = sorted(optimized_declarations, key=lambda d: d.start_line)
|
||||
|
||||
for decl in sorted_declarations:
|
||||
if decl.name not in existing_names and decl.source_code not in seen_sources:
|
||||
new_declarations.append(decl)
|
||||
seen_sources.add(decl.source_code)
|
||||
|
||||
return new_declarations
|
||||
|
||||
|
||||
def _insert_declaration_after_dependencies(
|
||||
source: str,
|
||||
declaration,
|
||||
existing_decl_end_lines: dict[str, int],
|
||||
analyzer: TreeSitterAnalyzer,
|
||||
module_abspath: Path,
|
||||
) -> str:
|
||||
"""Insert a declaration after the last existing declaration it depends on.
|
||||
|
||||
Args:
|
||||
source: Current source code.
|
||||
declaration: The declaration to insert.
|
||||
existing_decl_end_lines: Map of existing declaration names to their end lines.
|
||||
analyzer: TreeSitter analyzer.
|
||||
module_abspath: Path to the module file.
|
||||
|
||||
Returns:
|
||||
Source code with the declaration inserted at the correct position.
|
||||
|
||||
"""
|
||||
# Find identifiers referenced in this declaration
|
||||
referenced_names = analyzer.find_referenced_identifiers(declaration.source_code)
|
||||
|
||||
# Find the latest end line among all referenced declarations
|
||||
insertion_line = _find_insertion_line_for_declaration(source, referenced_names, existing_decl_end_lines, analyzer)
|
||||
|
||||
lines = source.splitlines(keepends=True)
|
||||
|
||||
# Ensure proper spacing
|
||||
decl_code = declaration.source_code
|
||||
if not decl_code.endswith("\n"):
|
||||
decl_code += "\n"
|
||||
|
||||
# Add blank line before if inserting after content
|
||||
if insertion_line > 0 and lines[insertion_line - 1].strip():
|
||||
decl_code = "\n" + decl_code
|
||||
|
||||
before = lines[:insertion_line]
|
||||
after = lines[insertion_line:]
|
||||
|
||||
return "".join([*before, decl_code, *after])
|
||||
|
||||
|
||||
def _find_insertion_line_for_declaration(
|
||||
source: str, referenced_names: set[str], existing_decl_end_lines: dict[str, int], analyzer: TreeSitterAnalyzer
|
||||
) -> int:
|
||||
"""Find the line where a declaration should be inserted based on its dependencies.
|
||||
|
||||
Args:
|
||||
source: Source code.
|
||||
referenced_names: Names referenced by the declaration.
|
||||
existing_decl_end_lines: Map of declaration names to their end lines (1-indexed).
|
||||
analyzer: TreeSitter analyzer.
|
||||
|
||||
Returns:
|
||||
Line index (0-based) where the declaration should be inserted.
|
||||
|
||||
"""
|
||||
# Find the maximum end line among referenced declarations
|
||||
max_dependency_line = 0
|
||||
for name in referenced_names:
|
||||
if name in existing_decl_end_lines:
|
||||
max_dependency_line = max(max_dependency_line, existing_decl_end_lines[name])
|
||||
|
||||
if max_dependency_line > 0:
|
||||
# Insert after the last dependency (end_line is 1-indexed, we need 0-indexed)
|
||||
return max_dependency_line
|
||||
|
||||
# No dependencies found - insert after imports
|
||||
lines = source.splitlines(keepends=True)
|
||||
return _find_line_after_imports(lines, analyzer, source)
|
||||
|
||||
|
||||
def _find_line_after_imports(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int:
|
||||
"""Find the line index after all imports.
|
||||
|
||||
Args:
|
||||
lines: Source lines.
|
||||
analyzer: TreeSitter analyzer.
|
||||
source: Full source code.
|
||||
|
||||
Returns:
|
||||
Line index (0-based) for insertion after imports.
|
||||
|
||||
"""
|
||||
try:
|
||||
imports = analyzer.find_imports(source)
|
||||
if imports:
|
||||
return max(imp.end_line for imp in imports)
|
||||
except Exception as exc:
|
||||
logger.debug(f"Exception in _find_line_after_imports: {exc}")
|
||||
|
||||
# Default: insert at beginning (after shebang/directive comments)
|
||||
for i, line in enumerate(lines):
|
||||
stripped = line.strip()
|
||||
if stripped and not stripped.startswith("//") and not stripped.startswith("#!"):
|
||||
return i
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStringsMarkdown) -> str:
|
||||
file_to_code_context = optimized_code.file_to_path()
|
||||
module_optimized_code = file_to_code_context.get(str(relative_path))
|
||||
|
|
@ -871,8 +617,7 @@ def replace_optimized_code(
|
|||
[
|
||||
callee.qualified_name
|
||||
for callee in code_context.helper_functions
|
||||
if callee.file_path == module_path
|
||||
and (callee.jedi_definition is None or callee.jedi_definition.type != "class")
|
||||
if callee.file_path == module_path and callee.definition_type != "class"
|
||||
]
|
||||
),
|
||||
candidate.source_code,
|
||||
|
|
@ -12,7 +12,6 @@ from libcst.metadata import PositionProvider
|
|||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.time_utils import format_perf, format_time
|
||||
from codeflash.languages.registry import get_language_support
|
||||
from codeflash.models.models import GeneratedTests, GeneratedTestsList
|
||||
from codeflash.result.critic import performance_gain
|
||||
|
||||
|
|
@ -155,7 +154,6 @@ def _is_python_file(file_path: Path) -> bool:
|
|||
return file_path.suffix == ".py"
|
||||
|
||||
|
||||
# TODO:{self} Needs cleanup for jest logic in else block
|
||||
def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_rootdir: Path) -> dict[str, int]:
|
||||
unique_inv_ids: dict[str, int] = {}
|
||||
logger.debug(f"[unique_inv_id] Processing {len(inv_id_runtimes)} invocation IDs")
|
||||
|
|
@ -166,53 +164,11 @@ def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_
|
|||
else inv_id.test_function_name
|
||||
)
|
||||
|
||||
# Detect if test_module_path is a file path (like in js tests) or a Python module name
|
||||
# File paths contain slashes, module names use dots
|
||||
test_module_path = inv_id.test_module_path
|
||||
if "/" in test_module_path or "\\" in test_module_path:
|
||||
# Already a file path - use directly
|
||||
abs_path = tests_project_rootdir / Path(test_module_path)
|
||||
else:
|
||||
# Check for Jest test file extensions (e.g., tests.fibonacci.test.ts)
|
||||
# These need special handling to avoid converting .test.ts -> /test/ts
|
||||
jest_test_extensions = (
|
||||
".test.ts",
|
||||
".test.js",
|
||||
".test.tsx",
|
||||
".test.jsx",
|
||||
".spec.ts",
|
||||
".spec.js",
|
||||
".spec.tsx",
|
||||
".spec.jsx",
|
||||
".ts",
|
||||
".js",
|
||||
".tsx",
|
||||
".jsx",
|
||||
".mjs",
|
||||
".mts",
|
||||
)
|
||||
matched_ext = None
|
||||
for ext in jest_test_extensions:
|
||||
if test_module_path.endswith(ext):
|
||||
matched_ext = ext
|
||||
break
|
||||
|
||||
if matched_ext:
|
||||
# JavaScript/TypeScript: convert module-style path to file path
|
||||
# "tests.fibonacci__perfonlyinstrumented.test.ts" -> "tests/fibonacci__perfonlyinstrumented.test.ts"
|
||||
base_path = test_module_path[: -len(matched_ext)]
|
||||
file_path = base_path.replace(".", os.sep) + matched_ext
|
||||
# Check if the module path includes the tests directory name
|
||||
tests_dir_name = tests_project_rootdir.name
|
||||
if file_path.startswith((tests_dir_name + os.sep, tests_dir_name + "/")):
|
||||
# Module path includes "tests." - use parent directory
|
||||
abs_path = tests_project_rootdir.parent / Path(file_path)
|
||||
else:
|
||||
# Module path doesn't include tests dir - use tests root directly
|
||||
abs_path = tests_project_rootdir / Path(file_path)
|
||||
else:
|
||||
# Python module name - convert dots to path separators and add .py
|
||||
abs_path = tests_project_rootdir / Path(test_module_path.replace(".", os.sep)).with_suffix(".py")
|
||||
abs_path = tests_project_rootdir / Path(test_module_path.replace(".", os.sep)).with_suffix(".py")
|
||||
|
||||
abs_path_str = str(abs_path.resolve().with_suffix(""))
|
||||
# Include both unit test and perf test paths for runtime annotations
|
||||
|
|
@ -268,22 +224,7 @@ def add_runtime_comments_to_generated_tests(
|
|||
logger.debug(f"Failed to add runtime comments to test: {e}")
|
||||
modified_tests.append(test)
|
||||
else:
|
||||
try:
|
||||
language_support = get_language_support(test.behavior_file_path)
|
||||
modified_source = language_support.add_runtime_comments(
|
||||
test.generated_original_test_source, original_runtimes_dict, optimized_runtimes_dict
|
||||
)
|
||||
modified_test = GeneratedTests(
|
||||
generated_original_test_source=modified_source,
|
||||
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
|
||||
instrumented_perf_test_source=test.instrumented_perf_test_source,
|
||||
behavior_file_path=test.behavior_file_path,
|
||||
perf_file_path=test.perf_file_path,
|
||||
)
|
||||
modified_tests.append(modified_test)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to add runtime comments to test: {e}")
|
||||
modified_tests.append(test)
|
||||
modified_tests.append(test)
|
||||
|
||||
return GeneratedTestsList(generated_tests=modified_tests)
|
||||
|
||||
|
|
@ -329,109 +270,3 @@ def _compile_function_patterns(test_functions_to_remove: list[str]) -> list[re.P
|
|||
)
|
||||
for func in test_functions_to_remove
|
||||
]
|
||||
|
||||
|
||||
# Patterns for normalizing codeflash imports (legacy -> npm package)
|
||||
_CODEFLASH_REQUIRE_PATTERN = re.compile(
|
||||
r"(const|let|var)\s+(\w+)\s*=\s*require\s*\(\s*['\"]\.?/?codeflash-jest-helper['\"]\s*\)"
|
||||
)
|
||||
_CODEFLASH_IMPORT_PATTERN = re.compile(r"import\s+(?:\*\s+as\s+)?(\w+)\s+from\s+['\"]\.?/?codeflash-jest-helper['\"]")
|
||||
|
||||
|
||||
def normalize_codeflash_imports(source: str) -> str:
|
||||
"""Normalize codeflash imports to use the npm package.
|
||||
|
||||
Replaces legacy local file imports:
|
||||
const codeflash = require('./codeflash-jest-helper')
|
||||
import codeflash from './codeflash-jest-helper'
|
||||
|
||||
With npm package imports:
|
||||
const codeflash = require('codeflash')
|
||||
|
||||
Args:
|
||||
source: JavaScript/TypeScript source code.
|
||||
|
||||
Returns:
|
||||
Source code with normalized imports.
|
||||
|
||||
"""
|
||||
# Replace CommonJS require
|
||||
source = _CODEFLASH_REQUIRE_PATTERN.sub(r"\1 \2 = require('codeflash')", source)
|
||||
# Replace ES module import
|
||||
return _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source)
|
||||
|
||||
|
||||
def inject_test_globals(generated_tests: GeneratedTestsList, test_framework: str = "jest") -> GeneratedTestsList:
|
||||
# TODO: inside the prompt tell the llm if it should import jest functions or it's already injected in the global window
|
||||
"""Inject test globals into all generated tests.
|
||||
|
||||
Args:
|
||||
generated_tests: List of generated tests.
|
||||
test_framework: The test framework being used ("jest", "vitest", or "mocha").
|
||||
|
||||
Returns:
|
||||
Generated tests with test globals injected.
|
||||
|
||||
"""
|
||||
# we only inject test globals for esm modules
|
||||
# Use vitest imports for vitest projects, jest imports for jest projects
|
||||
if test_framework == "vitest":
|
||||
global_import = "import { vi, describe, it, expect, beforeEach, afterEach, beforeAll, test } from 'vitest'\n"
|
||||
else:
|
||||
# Default to jest imports for jest and other frameworks
|
||||
global_import = (
|
||||
"import { jest, describe, it, expect, beforeEach, afterEach, beforeAll, test } from '@jest/globals'\n"
|
||||
)
|
||||
|
||||
for test in generated_tests.generated_tests:
|
||||
test.generated_original_test_source = global_import + test.generated_original_test_source
|
||||
test.instrumented_behavior_test_source = global_import + test.instrumented_behavior_test_source
|
||||
test.instrumented_perf_test_source = global_import + test.instrumented_perf_test_source
|
||||
return generated_tests
|
||||
|
||||
|
||||
def disable_ts_check(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
|
||||
"""Disable TypeScript type checking in all generated tests.
|
||||
|
||||
Args:
|
||||
generated_tests: List of generated tests.
|
||||
|
||||
Returns:
|
||||
Generated tests with TypeScript type checking disabled.
|
||||
|
||||
"""
|
||||
# we only inject test globals for esm modules
|
||||
ts_nocheck = "// @ts-nocheck\n"
|
||||
|
||||
for test in generated_tests.generated_tests:
|
||||
test.generated_original_test_source = ts_nocheck + test.generated_original_test_source
|
||||
test.instrumented_behavior_test_source = ts_nocheck + test.instrumented_behavior_test_source
|
||||
test.instrumented_perf_test_source = ts_nocheck + test.instrumented_perf_test_source
|
||||
return generated_tests
|
||||
|
||||
|
||||
def normalize_generated_tests_imports(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
|
||||
"""Normalize codeflash imports in all generated tests.
|
||||
|
||||
Args:
|
||||
generated_tests: List of generated tests.
|
||||
|
||||
Returns:
|
||||
Generated tests with normalized imports.
|
||||
|
||||
"""
|
||||
normalized_tests = []
|
||||
for test in generated_tests.generated_tests:
|
||||
# Only normalize JS/TS files
|
||||
if test.behavior_file_path.suffix in (".js", ".ts", ".jsx", ".tsx", ".mjs", ".mts"):
|
||||
normalized_test = GeneratedTests(
|
||||
generated_original_test_source=normalize_codeflash_imports(test.generated_original_test_source),
|
||||
instrumented_behavior_test_source=normalize_codeflash_imports(test.instrumented_behavior_test_source),
|
||||
instrumented_perf_test_source=normalize_codeflash_imports(test.instrumented_perf_test_source),
|
||||
behavior_file_path=test.behavior_file_path,
|
||||
perf_file_path=test.perf_file_path,
|
||||
)
|
||||
normalized_tests.append(normalized_test)
|
||||
else:
|
||||
normalized_tests.append(test)
|
||||
return GeneratedTestsList(generated_tests=normalized_tests)
|
||||
|
|
@ -21,7 +21,8 @@ from codeflash.languages.registry import register_language
|
|||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from codeflash.models.models import FunctionSource
|
||||
from codeflash.languages.base import DependencyResolver
|
||||
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -75,6 +76,37 @@ class PythonSupport:
|
|||
def comment_prefix(self) -> str:
|
||||
return "#"
|
||||
|
||||
@property
|
||||
def dir_excludes(self) -> frozenset[str]:
|
||||
return frozenset(
|
||||
{
|
||||
"__pycache__",
|
||||
".venv",
|
||||
"venv",
|
||||
".tox",
|
||||
".nox",
|
||||
".eggs",
|
||||
".mypy_cache",
|
||||
".ruff_cache",
|
||||
".pytest_cache",
|
||||
".hypothesis",
|
||||
"htmlcov",
|
||||
".pytype",
|
||||
".pyre",
|
||||
".pybuilder",
|
||||
".ipynb_checkpoints",
|
||||
".codeflash",
|
||||
".cache",
|
||||
".complexipy_cache",
|
||||
"build",
|
||||
"dist",
|
||||
"sdist",
|
||||
".coverage*",
|
||||
".pyright*",
|
||||
"*.egg-info",
|
||||
}
|
||||
)
|
||||
|
||||
# === Discovery ===
|
||||
|
||||
def discover_functions(
|
||||
|
|
@ -347,7 +379,7 @@ class PythonSupport:
|
|||
Modified source code with function replaced.
|
||||
|
||||
"""
|
||||
from codeflash.code_utils.code_replacer import replace_functions_in_file
|
||||
from codeflash.languages.python.static_analysis.code_replacer import replace_functions_in_file
|
||||
|
||||
try:
|
||||
# Determine the function names to replace
|
||||
|
|
@ -625,6 +657,59 @@ class PythonSupport:
|
|||
except Exception:
|
||||
return test_source
|
||||
|
||||
def postprocess_generated_tests(
|
||||
self, generated_tests: GeneratedTestsList, test_framework: str, project_root: Path, source_file_path: Path
|
||||
) -> GeneratedTestsList:
|
||||
"""Apply language-specific postprocessing to generated tests."""
|
||||
_ = test_framework, project_root, source_file_path
|
||||
return generated_tests
|
||||
|
||||
def remove_test_functions_from_generated_tests(
|
||||
self, generated_tests: GeneratedTestsList, functions_to_remove: list[str]
|
||||
) -> GeneratedTestsList:
|
||||
"""Remove specific test functions from generated tests."""
|
||||
from codeflash.languages.python.static_analysis.edit_generated_tests import (
|
||||
remove_functions_from_generated_tests,
|
||||
)
|
||||
|
||||
return remove_functions_from_generated_tests(generated_tests, functions_to_remove)
|
||||
|
||||
def add_runtime_comments_to_generated_tests(
|
||||
self,
|
||||
generated_tests: GeneratedTestsList,
|
||||
original_runtimes: dict[InvocationId, list[int]],
|
||||
optimized_runtimes: dict[InvocationId, list[int]],
|
||||
tests_project_rootdir: Path | None = None,
|
||||
) -> GeneratedTestsList:
|
||||
"""Add runtime comments to generated tests."""
|
||||
from codeflash.languages.python.static_analysis.edit_generated_tests import (
|
||||
add_runtime_comments_to_generated_tests,
|
||||
)
|
||||
|
||||
return add_runtime_comments_to_generated_tests(
|
||||
generated_tests, original_runtimes, optimized_runtimes, tests_project_rootdir
|
||||
)
|
||||
|
||||
def add_global_declarations(self, optimized_code: str, original_source: str, module_abspath: Path) -> str:
|
||||
_ = optimized_code, module_abspath
|
||||
return original_source
|
||||
|
||||
def extract_calling_function_source(self, source_code: str, function_name: str, ref_line: int) -> str | None:
|
||||
"""Extract the source code of a calling function in Python."""
|
||||
try:
|
||||
import ast
|
||||
|
||||
lines = source_code.splitlines()
|
||||
tree = ast.parse(source_code)
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == function_name:
|
||||
end_line = node.end_lineno or node.lineno
|
||||
if node.lineno <= ref_line <= end_line:
|
||||
return "\n".join(lines[node.lineno - 1 : end_line])
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
# === Test Result Comparison ===
|
||||
|
||||
def compare_test_results(
|
||||
|
|
@ -720,6 +805,15 @@ class PythonSupport:
|
|||
"""
|
||||
return True
|
||||
|
||||
def create_dependency_resolver(self, project_root: Path) -> DependencyResolver | None:
|
||||
from codeflash.languages.python.reference_graph import ReferenceGraph
|
||||
|
||||
try:
|
||||
return ReferenceGraph(project_root, language=self.language.value)
|
||||
except Exception:
|
||||
logger.debug("Failed to initialize ReferenceGraph, falling back to per-function Jedi analysis")
|
||||
return None
|
||||
|
||||
def instrument_existing_test(
|
||||
self,
|
||||
test_path: Path,
|
||||
|
|
|
|||
|
|
@ -18,6 +18,11 @@ def is_LSP_enabled() -> bool:
|
|||
return os.getenv("CODEFLASH_LSP", default="false").lower() == "true"
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def is_subagent_mode() -> bool:
|
||||
return os.getenv("CODEFLASH_SUBAGENT_MODE", default="false").lower() == "true"
|
||||
|
||||
|
||||
def tree_to_markdown(tree: Tree, level: int = 0) -> str:
|
||||
"""Convert a rich Tree into a Markdown bullet list."""
|
||||
indent = " " * level
|
||||
|
|
|
|||
|
|
@ -11,6 +11,12 @@ import sys
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if "--subagent" in sys.argv:
|
||||
os.environ["CODEFLASH_SUBAGENT_MODE"] = "true"
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
|
||||
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO, ask_run_end_to_end_test
|
||||
from codeflash.cli_cmds.console import paneled_text
|
||||
|
|
|
|||
|
|
@ -25,7 +25,6 @@ from pathlib import Path
|
|||
from re import Pattern
|
||||
from typing import Any, NamedTuple, Optional, cast
|
||||
|
||||
from jedi.api.classes import Name
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
|
|
@ -136,14 +135,14 @@ class CoverReturnCode(IntEnum):
|
|||
ERROR = 2
|
||||
|
||||
|
||||
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
|
||||
@dataclass(frozen=True)
|
||||
class FunctionSource:
|
||||
file_path: Path
|
||||
qualified_name: str
|
||||
fully_qualified_name: str
|
||||
only_function_name: str
|
||||
source_code: str
|
||||
jedi_definition: Name | None = None # None for non-Python languages
|
||||
definition_type: str | None = None # e.g. "function", "class"; None for non-Python languages
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, FunctionSource):
|
||||
|
|
@ -380,6 +379,7 @@ class CodeOptimizationContext(BaseModel):
|
|||
hashing_code_context: str = ""
|
||||
hashing_code_context_hash: str = ""
|
||||
helper_functions: list[FunctionSource]
|
||||
testgen_helper_fqns: list[str] = []
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||
|
||||
import ast
|
||||
import concurrent.futures
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
|
|
@ -23,14 +24,15 @@ from rich.tree import Tree
|
|||
from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient
|
||||
from codeflash.api.cfapi import add_code_context_hash, create_staging, get_cfapi_base_urls, mark_optimization_success
|
||||
from codeflash.benchmarking.utils import process_benchmark_data
|
||||
from codeflash.cli_cmds.console import code_print, console, logger, lsp_log, progress_bar
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.code_extractor import get_opt_review_metrics, is_numerical_code
|
||||
from codeflash.code_utils.code_replacer import (
|
||||
add_custom_marker_to_all_tests,
|
||||
modify_autouse_fixture,
|
||||
replace_function_definitions_in_module,
|
||||
from codeflash.cli_cmds.console import (
|
||||
code_print,
|
||||
console,
|
||||
logger,
|
||||
lsp_log,
|
||||
progress_bar,
|
||||
subagent_log_optimization_result,
|
||||
)
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.code_utils import (
|
||||
choose_weights,
|
||||
cleanup_paths,
|
||||
|
|
@ -58,34 +60,32 @@ from codeflash.code_utils.config_consts import (
|
|||
get_effort_value,
|
||||
)
|
||||
from codeflash.code_utils.deduplicate_code import normalize_code
|
||||
from codeflash.code_utils.edit_generated_tests import (
|
||||
add_runtime_comments_to_generated_tests,
|
||||
disable_ts_check,
|
||||
inject_test_globals,
|
||||
normalize_generated_tests_imports,
|
||||
remove_functions_from_generated_tests,
|
||||
)
|
||||
from codeflash.code_utils.env_utils import get_pr_number
|
||||
from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports
|
||||
from codeflash.code_utils.git_utils import git_root_dir
|
||||
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
|
||||
from codeflash.code_utils.line_profile_utils import add_decorator_imports, contains_jit_decorator
|
||||
from codeflash.code_utils.shell_utils import make_env_with_project_root
|
||||
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
|
||||
from codeflash.code_utils.time_utils import humanize_runtime
|
||||
from codeflash.discovery.functions_to_optimize import was_function_previously_optimized
|
||||
from codeflash.either import Failure, Success, is_successful
|
||||
from codeflash.languages import is_python
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.current import current_language_support, is_typescript
|
||||
from codeflash.languages.javascript.module_system import detect_module_system
|
||||
from codeflash.languages.current import current_language_support
|
||||
from codeflash.languages.javascript.test_runner import clear_created_config_files, get_created_config_files
|
||||
from codeflash.languages.python.context import code_context_extractor
|
||||
from codeflash.languages.python.context.unused_definition_remover import (
|
||||
detect_unused_helper_functions,
|
||||
revert_unused_helper_functions,
|
||||
)
|
||||
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown
|
||||
from codeflash.languages.python.static_analysis.code_extractor import get_opt_review_metrics, is_numerical_code
|
||||
from codeflash.languages.python.static_analysis.code_replacer import (
|
||||
add_custom_marker_to_all_tests,
|
||||
modify_autouse_fixture,
|
||||
replace_function_definitions_in_module,
|
||||
)
|
||||
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
|
||||
from codeflash.languages.python.static_analysis.static_analysis import get_first_top_level_function_or_method_ast
|
||||
from codeflash.lsp.helpers import is_LSP_enabled, is_subagent_mode, report_to_markdown_table, tree_to_markdown
|
||||
from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId
|
||||
from codeflash.models.ExperimentMetadata import ExperimentMetadata
|
||||
from codeflash.models.models import (
|
||||
|
|
@ -139,6 +139,7 @@ if TYPE_CHECKING:
|
|||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.either import Result
|
||||
from codeflash.languages.base import DependencyResolver
|
||||
from codeflash.models.models import (
|
||||
BenchmarkKey,
|
||||
CodeStringsMarkdown,
|
||||
|
|
@ -442,10 +443,14 @@ class FunctionOptimizer:
|
|||
total_benchmark_timings: dict[BenchmarkKey, int] | None = None,
|
||||
args: Namespace | None = None,
|
||||
replay_tests_dir: Path | None = None,
|
||||
call_graph: DependencyResolver | None = None,
|
||||
) -> None:
|
||||
self.project_root = test_cfg.project_root_path
|
||||
self.project_root = test_cfg.project_root_path.resolve()
|
||||
self.test_cfg = test_cfg
|
||||
self.aiservice_client = aiservice_client if aiservice_client else AiServiceClient()
|
||||
resolved_file_path = function_to_optimize.file_path.resolve()
|
||||
if resolved_file_path != function_to_optimize.file_path:
|
||||
function_to_optimize = dataclasses.replace(function_to_optimize, file_path=resolved_file_path)
|
||||
self.function_to_optimize = function_to_optimize
|
||||
self.function_to_optimize_source_code = (
|
||||
function_to_optimize_source_code
|
||||
|
|
@ -484,6 +489,7 @@ class FunctionOptimizer:
|
|||
self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {}
|
||||
self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {}
|
||||
self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None
|
||||
self.call_graph = call_graph
|
||||
n_tests = get_effort_value(EffortKeys.N_GENERATED_TESTS, self.effort)
|
||||
self.executor = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=n_tests + 3 if self.experiment_id is None else n_tests + 4
|
||||
|
|
@ -579,6 +585,7 @@ class FunctionOptimizer:
|
|||
test_results = self.generate_tests(
|
||||
testgen_context=code_context.testgen_context,
|
||||
helper_functions=code_context.helper_functions,
|
||||
testgen_helper_fqns=code_context.testgen_helper_fqns,
|
||||
generated_test_paths=generated_test_paths,
|
||||
generated_perf_test_paths=generated_perf_test_paths,
|
||||
)
|
||||
|
|
@ -588,16 +595,13 @@ class FunctionOptimizer:
|
|||
|
||||
count_tests, generated_tests, function_to_concolic_tests, concolic_test_str = test_results.unwrap()
|
||||
|
||||
# Normalize codeflash imports in JS/TS tests to use npm package
|
||||
if not is_python():
|
||||
module_system = detect_module_system(self.project_root, self.function_to_optimize.file_path)
|
||||
if module_system == "esm":
|
||||
generated_tests = inject_test_globals(generated_tests, self.test_cfg.test_framework)
|
||||
if is_typescript():
|
||||
# disable ts check for typescript tests
|
||||
generated_tests = disable_ts_check(generated_tests)
|
||||
|
||||
generated_tests = normalize_generated_tests_imports(generated_tests)
|
||||
# Language-specific postprocessing for generated tests
|
||||
generated_tests = self.language_support.postprocess_generated_tests(
|
||||
generated_tests,
|
||||
test_framework=self.test_cfg.test_framework,
|
||||
project_root=self.project_root,
|
||||
source_file_path=self.function_to_optimize.file_path,
|
||||
)
|
||||
|
||||
logger.debug(f"[PIPELINE] Processing {count_tests} generated tests")
|
||||
for i, generated_test in enumerate(generated_tests.generated_tests):
|
||||
|
|
@ -1352,6 +1356,8 @@ class FunctionOptimizer:
|
|||
def log_successful_optimization(
|
||||
self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
|
||||
) -> None:
|
||||
if is_subagent_mode():
|
||||
return
|
||||
if is_LSP_enabled():
|
||||
md_lines = [
|
||||
"### ⚡️ Optimization Summary",
|
||||
|
|
@ -1450,7 +1456,7 @@ class FunctionOptimizer:
|
|||
optimized_code = ""
|
||||
if optimized_context is not None:
|
||||
file_to_code_context = optimized_context.file_to_path()
|
||||
optimized_code = file_to_code_context.get(str(path.relative_to(self.project_root)), "")
|
||||
optimized_code = file_to_code_context.get(str(path.resolve().relative_to(self.project_root)), "")
|
||||
|
||||
new_code = format_code(
|
||||
self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True, exit_on_failure=False
|
||||
|
|
@ -1487,8 +1493,8 @@ class FunctionOptimizer:
|
|||
self.function_to_optimize.qualified_name
|
||||
)
|
||||
for helper_function in code_context.helper_functions:
|
||||
# Skip class definitions (jedi_definition may be None for non-Python languages)
|
||||
if helper_function.jedi_definition is None or helper_function.jedi_definition.type != "class":
|
||||
# Skip class definitions (definition_type may be None for non-Python languages)
|
||||
if helper_function.definition_type != "class":
|
||||
read_writable_functions_by_file_path[helper_function.file_path].add(helper_function.qualified_name)
|
||||
for module_abspath, qualified_names in read_writable_functions_by_file_path.items():
|
||||
did_update |= replace_function_definitions_in_module(
|
||||
|
|
@ -1509,7 +1515,7 @@ class FunctionOptimizer:
|
|||
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
|
||||
try:
|
||||
new_code_ctx = code_context_extractor.get_code_optimization_context(
|
||||
self.function_to_optimize, self.project_root
|
||||
self.function_to_optimize, self.project_root, call_graph=self.call_graph
|
||||
)
|
||||
except ValueError as e:
|
||||
return Failure(str(e))
|
||||
|
|
@ -1521,7 +1527,8 @@ class FunctionOptimizer:
|
|||
read_only_context_code=new_code_ctx.read_only_context_code,
|
||||
hashing_code_context=new_code_ctx.hashing_code_context,
|
||||
hashing_code_context_hash=new_code_ctx.hashing_code_context_hash,
|
||||
helper_functions=new_code_ctx.helper_functions, # only functions that are read writable
|
||||
helper_functions=new_code_ctx.helper_functions,
|
||||
testgen_helper_fqns=new_code_ctx.testgen_helper_fqns,
|
||||
preexisting_objects=new_code_ctx.preexisting_objects,
|
||||
)
|
||||
)
|
||||
|
|
@ -1727,6 +1734,7 @@ class FunctionOptimizer:
|
|||
self,
|
||||
testgen_context: CodeStringsMarkdown,
|
||||
helper_functions: list[FunctionSource],
|
||||
testgen_helper_fqns: list[str],
|
||||
generated_test_paths: list[Path],
|
||||
generated_perf_test_paths: list[Path],
|
||||
) -> Result[tuple[int, GeneratedTestsList, dict[str, set[FunctionCalledInTest]], str], str]:
|
||||
|
|
@ -1735,23 +1743,29 @@ class FunctionOptimizer:
|
|||
assert len(generated_test_paths) == n_tests
|
||||
|
||||
if not self.args.no_gen_tests:
|
||||
# Submit test generation tasks
|
||||
helper_fqns = testgen_helper_fqns or [definition.fully_qualified_name for definition in helper_functions]
|
||||
future_tests = self.submit_test_generation_tasks(
|
||||
self.executor,
|
||||
testgen_context.markdown,
|
||||
[definition.fully_qualified_name for definition in helper_functions],
|
||||
generated_test_paths,
|
||||
generated_perf_test_paths,
|
||||
self.executor, testgen_context.markdown, helper_fqns, generated_test_paths, generated_perf_test_paths
|
||||
)
|
||||
|
||||
future_concolic_tests = self.executor.submit(
|
||||
generate_concolic_tests, self.test_cfg, self.args, self.function_to_optimize, self.function_to_optimize_ast
|
||||
)
|
||||
if is_subagent_mode():
|
||||
future_concolic_tests = None
|
||||
else:
|
||||
future_concolic_tests = self.executor.submit(
|
||||
generate_concolic_tests,
|
||||
self.test_cfg,
|
||||
self.args,
|
||||
self.function_to_optimize,
|
||||
self.function_to_optimize_ast,
|
||||
)
|
||||
|
||||
if not self.args.no_gen_tests:
|
||||
# Wait for test futures to complete
|
||||
concurrent.futures.wait([*future_tests, future_concolic_tests])
|
||||
else:
|
||||
futures_to_wait = [*future_tests]
|
||||
if future_concolic_tests is not None:
|
||||
futures_to_wait.append(future_concolic_tests)
|
||||
concurrent.futures.wait(futures_to_wait)
|
||||
elif future_concolic_tests is not None:
|
||||
concurrent.futures.wait([future_concolic_tests])
|
||||
# Process test generation results
|
||||
tests: list[GeneratedTests] = []
|
||||
|
|
@ -1780,7 +1794,10 @@ class FunctionOptimizer:
|
|||
logger.warning(f"Failed to generate and instrument tests for {self.function_to_optimize.function_name}")
|
||||
return Failure(f"/!\\ NO TESTS GENERATED for {self.function_to_optimize.function_name}")
|
||||
|
||||
function_to_concolic_tests, concolic_test_str = future_concolic_tests.result()
|
||||
if future_concolic_tests is not None:
|
||||
function_to_concolic_tests, concolic_test_str = future_concolic_tests.result()
|
||||
else:
|
||||
function_to_concolic_tests, concolic_test_str = {}, None
|
||||
count_tests = len(tests)
|
||||
if concolic_test_str:
|
||||
count_tests += 1
|
||||
|
|
@ -2061,8 +2078,8 @@ class FunctionOptimizer:
|
|||
else "Coverage data not available"
|
||||
)
|
||||
|
||||
generated_tests = remove_functions_from_generated_tests(
|
||||
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
|
||||
generated_tests = self.language_support.remove_test_functions_from_generated_tests(
|
||||
generated_tests, test_functions_to_remove
|
||||
)
|
||||
map_gen_test_file_to_no_of_tests = original_code_baseline.behavior_test_results.file_to_no_of_tests(
|
||||
test_functions_to_remove
|
||||
|
|
@ -2073,7 +2090,7 @@ class FunctionOptimizer:
|
|||
best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case()
|
||||
)
|
||||
|
||||
generated_tests = add_runtime_comments_to_generated_tests(
|
||||
generated_tests = self.language_support.add_runtime_comments_to_generated_tests(
|
||||
generated_tests, original_runtime_by_test, optimized_runtime_by_test, self.test_cfg.tests_project_rootdir
|
||||
)
|
||||
|
||||
|
|
@ -2203,7 +2220,20 @@ class FunctionOptimizer:
|
|||
self.optimization_review = opt_review_result.review
|
||||
|
||||
# Display the reviewer result to the user
|
||||
if opt_review_result.review:
|
||||
if is_subagent_mode():
|
||||
subagent_log_optimization_result(
|
||||
function_name=new_explanation.function_name,
|
||||
file_path=new_explanation.file_path,
|
||||
perf_improvement_line=new_explanation.perf_improvement_line,
|
||||
original_runtime_ns=new_explanation.original_runtime_ns,
|
||||
best_runtime_ns=new_explanation.best_runtime_ns,
|
||||
raw_explanation=new_explanation.raw_explanation_message,
|
||||
original_code=original_code_combined,
|
||||
new_code=new_code_combined,
|
||||
review=opt_review_result.review,
|
||||
test_results=new_explanation.winning_behavior_test_results,
|
||||
)
|
||||
elif opt_review_result.review:
|
||||
review_display = {
|
||||
"high": ("[bold green]High[/bold green]", "green", "Recommended to merge"),
|
||||
"medium": ("[bold yellow]Medium[/bold yellow]", "yellow", "Review recommended before merging"),
|
||||
|
|
|
|||
|
|
@ -11,7 +11,13 @@ from typing import TYPE_CHECKING
|
|||
|
||||
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
|
||||
from codeflash.api.cfapi import send_completion_email
|
||||
from codeflash.cli_cmds.console import console, logger, progress_bar
|
||||
from codeflash.cli_cmds.console import ( # noqa: F401
|
||||
call_graph_live_display,
|
||||
call_graph_summary,
|
||||
console,
|
||||
logger,
|
||||
progress_bar,
|
||||
)
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file
|
||||
from codeflash.code_utils.env_utils import get_pr_number, is_pr_draft
|
||||
|
|
@ -24,7 +30,8 @@ from codeflash.code_utils.git_worktree_utils import (
|
|||
)
|
||||
from codeflash.code_utils.time_utils import humanize_runtime
|
||||
from codeflash.either import is_successful
|
||||
from codeflash.languages import is_javascript, set_current_language
|
||||
from codeflash.languages import current_language_support, is_javascript, set_current_language
|
||||
from codeflash.lsp.helpers import is_subagent_mode
|
||||
from codeflash.models.models import ValidCode
|
||||
from codeflash.telemetry.posthog_cf import ph
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
|
@ -35,6 +42,7 @@ if TYPE_CHECKING:
|
|||
from codeflash.benchmarking.function_ranker import FunctionRanker
|
||||
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import DependencyResolver
|
||||
from codeflash.models.models import BenchmarkKey, FunctionCalledInTest
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
|
||||
|
|
@ -241,8 +249,11 @@ class Optimizer:
|
|||
total_benchmark_timings: dict[BenchmarkKey, float] | None = None,
|
||||
original_module_ast: ast.Module | None = None,
|
||||
original_module_path: Path | None = None,
|
||||
call_graph: DependencyResolver | None = None,
|
||||
) -> FunctionOptimizer | None:
|
||||
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
|
||||
from codeflash.languages.python.static_analysis.static_analysis import (
|
||||
get_first_top_level_function_or_method_ast,
|
||||
)
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
|
||||
if function_to_optimize_ast is None and original_module_ast is not None:
|
||||
|
|
@ -279,13 +290,14 @@ class Optimizer:
|
|||
function_benchmark_timings=function_specific_timings,
|
||||
total_benchmark_timings=total_benchmark_timings if function_specific_timings else None,
|
||||
replay_tests_dir=self.replay_tests_dir,
|
||||
call_graph=call_graph,
|
||||
)
|
||||
|
||||
def prepare_module_for_optimization(
|
||||
self, original_module_path: Path
|
||||
) -> tuple[dict[Path, ValidCode], ast.Module | None] | None:
|
||||
from codeflash.code_utils.code_replacer import normalize_code, normalize_node
|
||||
from codeflash.code_utils.static_analysis import analyze_imported_modules
|
||||
from codeflash.languages.python.static_analysis.code_replacer import normalize_code, normalize_node
|
||||
from codeflash.languages.python.static_analysis.static_analysis import analyze_imported_modules
|
||||
|
||||
logger.info(f"loading|Examining file {original_module_path!s}")
|
||||
console.rule()
|
||||
|
|
@ -422,7 +434,10 @@ class Optimizer:
|
|||
console.print(f"[dim]... and {len(globally_ranked) - display_count} more functions[/dim]")
|
||||
|
||||
def rank_all_functions_globally(
|
||||
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], trace_file_path: Path | None
|
||||
self,
|
||||
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]],
|
||||
trace_file_path: Path | None,
|
||||
call_graph: DependencyResolver | None = None,
|
||||
) -> list[tuple[Path, FunctionToOptimize]]:
|
||||
"""Rank all functions globally across all files based on trace data.
|
||||
|
||||
|
|
@ -442,8 +457,10 @@ class Optimizer:
|
|||
for file_path, functions in file_to_funcs_to_optimize.items():
|
||||
all_functions.extend((file_path, func) for func in functions)
|
||||
|
||||
# If no trace file, return in original order
|
||||
# If no trace file, rank by dependency count if call graph is available
|
||||
if not trace_file_path or not trace_file_path.exists():
|
||||
if call_graph is not None:
|
||||
return self.rank_by_dependency_count(all_functions, call_graph)
|
||||
logger.debug("No trace file available, using original function order")
|
||||
return all_functions
|
||||
|
||||
|
|
@ -494,6 +511,19 @@ class Optimizer:
|
|||
else:
|
||||
return globally_ranked
|
||||
|
||||
def rank_by_dependency_count(
|
||||
self, all_functions: list[tuple[Path, FunctionToOptimize]], call_graph: DependencyResolver
|
||||
) -> list[tuple[Path, FunctionToOptimize]]:
|
||||
file_to_qns: dict[Path, set[str]] = defaultdict(set)
|
||||
for file_path, func in all_functions:
|
||||
file_to_qns[file_path].add(func.qualified_name)
|
||||
callee_counts = call_graph.count_callees_per_function(dict(file_to_qns))
|
||||
ranked = sorted(
|
||||
enumerate(all_functions), key=lambda x: (-callee_counts.get((x[1][0], x[1][1].qualified_name), 0), x[0])
|
||||
)
|
||||
logger.debug(f"Ranked {len(ranked)} functions by dependency count (most complex first)")
|
||||
return [item for _, item in ranked]
|
||||
|
||||
def run(self) -> None:
|
||||
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
|
||||
|
||||
|
|
@ -536,16 +566,33 @@ class Optimizer:
|
|||
if self.args.all:
|
||||
three_min_in_ns = int(1.8e11)
|
||||
console.rule()
|
||||
pr_message = (
|
||||
"\nCodeflash will keep opening pull requests as it finds optimizations." if not self.args.no_pr else ""
|
||||
)
|
||||
logger.info(
|
||||
f"It might take about {humanize_runtime(num_optimizable_functions * three_min_in_ns)} to fully optimize this project.{pr_message}"
|
||||
f"It might take about {humanize_runtime(num_optimizable_functions * three_min_in_ns)} to fully optimize this project."
|
||||
)
|
||||
if not self.args.no_pr:
|
||||
logger.info("Codeflash will keep opening pull requests as it finds optimizations.")
|
||||
console.rule()
|
||||
|
||||
function_benchmark_timings, total_benchmark_timings = self.run_benchmarks(
|
||||
file_to_funcs_to_optimize, num_optimizable_functions
|
||||
)
|
||||
|
||||
# Create a language-specific dependency resolver (e.g. Jedi-based call graph for Python)
|
||||
# Skip in CI — the cache DB doesn't persist between runs on ephemeral runners
|
||||
lang_support = current_language_support()
|
||||
resolver = None
|
||||
# CURRENTLY DISABLED: The resolver is currently not used for anything until i clean up the repo structure for python
|
||||
# if lang_support and not env_utils.is_ci():
|
||||
# resolver = lang_support.create_dependency_resolver(self.args.project_root)
|
||||
|
||||
# if resolver is not None and lang_support is not None and file_to_funcs_to_optimize:
|
||||
# supported_exts = lang_support.file_extensions
|
||||
# source_files = [f for f in file_to_funcs_to_optimize if f.suffix in supported_exts]
|
||||
# with call_graph_live_display(len(source_files), project_root=self.args.project_root) as on_progress:
|
||||
# resolver.build_index(source_files, on_progress=on_progress)
|
||||
# console.rule()
|
||||
# call_graph_summary(resolver, file_to_funcs_to_optimize)
|
||||
|
||||
optimizations_found: int = 0
|
||||
self.test_cfg.concolic_test_root_dir = Path(
|
||||
tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
|
||||
|
|
@ -557,11 +604,13 @@ class Optimizer:
|
|||
return
|
||||
|
||||
function_to_tests, _ = self.discover_tests(file_to_funcs_to_optimize)
|
||||
if self.args.all:
|
||||
if self.args.all and not getattr(self.args, "agent", False):
|
||||
self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_root)
|
||||
|
||||
# GLOBAL RANKING: Rank all functions together before optimizing
|
||||
globally_ranked_functions = self.rank_all_functions_globally(file_to_funcs_to_optimize, trace_file_path)
|
||||
globally_ranked_functions = self.rank_all_functions_globally(
|
||||
file_to_funcs_to_optimize, trace_file_path, call_graph=resolver
|
||||
)
|
||||
# Cache for module preparation (avoid re-parsing same files)
|
||||
prepared_modules: dict[Path, tuple[dict[Path, ValidCode], ast.Module | None]] = {}
|
||||
|
||||
|
|
@ -593,6 +642,7 @@ class Optimizer:
|
|||
total_benchmark_timings=total_benchmark_timings,
|
||||
original_module_ast=original_module_ast,
|
||||
original_module_path=original_module_path,
|
||||
call_graph=resolver,
|
||||
)
|
||||
if function_optimizer is None:
|
||||
continue
|
||||
|
|
@ -608,7 +658,7 @@ class Optimizer:
|
|||
if is_successful(best_optimization):
|
||||
optimizations_found += 1
|
||||
# create a diff patch for successful optimization
|
||||
if self.current_worktree:
|
||||
if self.current_worktree and not is_subagent_mode():
|
||||
best_opt = best_optimization.unwrap()
|
||||
read_writable_code = best_opt.code_context.read_writable_code
|
||||
relative_file_paths = [
|
||||
|
|
@ -641,7 +691,12 @@ class Optimizer:
|
|||
self.functions_checkpoint.cleanup()
|
||||
if hasattr(self.args, "command") and self.args.command == "optimize":
|
||||
self.cleanup_replay_tests()
|
||||
if optimizations_found == 0:
|
||||
if is_subagent_mode():
|
||||
if optimizations_found == 0:
|
||||
import sys
|
||||
|
||||
sys.stdout.write("<codeflash-summary>No optimizations found.</codeflash-summary>\n")
|
||||
elif optimizations_found == 0:
|
||||
logger.info("❌ No optimizations found.")
|
||||
elif self.args.all:
|
||||
logger.info("✨ All functions have been optimized! ✨")
|
||||
|
|
@ -651,6 +706,9 @@ class Optimizer:
|
|||
else:
|
||||
logger.warning("⚠️ Failed to send completion email. Status")
|
||||
finally:
|
||||
if resolver is not None:
|
||||
resolver.close()
|
||||
|
||||
if function_optimizer:
|
||||
function_optimizer.cleanup_generated_files()
|
||||
|
||||
|
|
|
|||
|
|
@ -9,12 +9,12 @@ import git
|
|||
from codeflash.api import cfapi
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.code_replacer import is_zero_diff
|
||||
from codeflash.code_utils.git_utils import check_and_push_branch, get_current_branch, get_repo_owner_and_name
|
||||
from codeflash.code_utils.github_utils import github_pr_url
|
||||
from codeflash.code_utils.tabulate import tabulate
|
||||
from codeflash.code_utils.time_utils import format_perf, format_time
|
||||
from codeflash.github.PrComment import FileDiffContent, PrComment
|
||||
from codeflash.languages.python.static_analysis.code_replacer import is_zero_diff
|
||||
from codeflash.result.critic import performance_gain
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -126,10 +126,10 @@ def existing_tests_source_for(
|
|||
tests_dir_name = test_cfg.tests_project_rootdir.name
|
||||
if file_path.startswith((tests_dir_name + os.sep, tests_dir_name + "/")):
|
||||
# Module path includes "tests." - use project root parent
|
||||
instrumented_abs_path = (test_cfg.tests_project_rootdir.parent / file_path).resolve()
|
||||
instrumented_abs_path = test_cfg.tests_project_rootdir.parent / file_path
|
||||
else:
|
||||
# Module path doesn't include tests dir - use tests root directly
|
||||
instrumented_abs_path = (test_cfg.tests_project_rootdir / file_path).resolve()
|
||||
instrumented_abs_path = test_cfg.tests_project_rootdir / file_path
|
||||
logger.debug(f"[PR-DEBUG] Looking up: {instrumented_abs_path}")
|
||||
logger.debug(f"[PR-DEBUG] Available keys: {list(instrumented_to_original.keys())[:3]}")
|
||||
# Try to map instrumented path to original path
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import logging
|
|||
|
||||
import sentry_sdk
|
||||
from sentry_sdk.integrations.logging import LoggingIntegration
|
||||
from sentry_sdk.integrations.stdlib import StdlibIntegration
|
||||
|
||||
|
||||
def init_sentry(*, enabled: bool = False, exclude_errors: bool = False) -> None:
|
||||
|
|
@ -16,12 +17,8 @@ def init_sentry(*, enabled: bool = False, exclude_errors: bool = False) -> None:
|
|||
sentry_sdk.init(
|
||||
dsn="https://4b9a1902f9361b48c04376df6483bc96@o4506833230561280.ingest.sentry.io/4506833262477312",
|
||||
integrations=[sentry_logging],
|
||||
# Set traces_sample_rate to 1.0 to capture 100%
|
||||
# of transactions for performance monitoring.
|
||||
traces_sample_rate=1.0,
|
||||
# Set profiles_sample_rate to 1.0 to profile 100%
|
||||
# of sampled transactions.
|
||||
# We recommend adjusting this value in production.
|
||||
profiles_sample_rate=1.0,
|
||||
disabled_integrations=[StdlibIntegration],
|
||||
traces_sample_rate=0,
|
||||
profiles_sample_rate=0,
|
||||
ignore_errors=[KeyboardInterrupt],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import importlib.util
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
|
|
@ -9,15 +10,17 @@ from typing import TYPE_CHECKING
|
|||
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
|
||||
from codeflash.code_utils.concolic_utils import clean_concolic_tests, is_valid_concolic_test
|
||||
from codeflash.code_utils.shell_utils import make_env_with_project_root
|
||||
from codeflash.code_utils.static_analysis import has_typed_parameters
|
||||
from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
||||
from codeflash.languages import is_python
|
||||
from codeflash.languages.python.static_analysis.concolic_utils import clean_concolic_tests, is_valid_concolic_test
|
||||
from codeflash.languages.python.static_analysis.static_analysis import has_typed_parameters
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.telemetry.posthog_cf import ph
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
CROSSHAIR_AVAILABLE = importlib.util.find_spec("crosshair") is not None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
||||
|
|
@ -52,6 +55,10 @@ def generate_concolic_tests(
|
|||
logger.debug("Skipping concolic test generation for non-Python languages (CrossHair is Python-only)")
|
||||
return function_to_concolic_tests, concolic_test_suite_code
|
||||
|
||||
if not CROSSHAIR_AVAILABLE:
|
||||
logger.debug("Skipping concolic test generation (crosshair-tool is not installed)")
|
||||
return function_to_concolic_tests, concolic_test_suite_code
|
||||
|
||||
if is_LSP_enabled():
|
||||
logger.debug("Skipping concolic test generation in LSP mode")
|
||||
return function_to_concolic_tests, concolic_test_suite_code
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import sentry_sdk
|
|||
from coverage.exceptions import NoDataError
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.coverage_utils import (
|
||||
from codeflash.languages.python.static_analysis.coverage_utils import (
|
||||
build_fully_qualified_name,
|
||||
extract_dependent_function,
|
||||
generate_candidates,
|
||||
|
|
|
|||
|
|
@ -47,8 +47,24 @@ def parse_func(file_path: Path) -> XMLParser:
|
|||
return parse(file_path, xml_parser)
|
||||
|
||||
|
||||
matches_re_start = re.compile(r"!\$######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######\$!\n")
|
||||
matches_re_end = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!")
|
||||
matches_re_start = re.compile(
|
||||
r"!\$######([^:]*)" # group 1: module path
|
||||
r":((?:[^:.]*\.)*)" # group 2: class prefix with trailing dot, or empty
|
||||
r"([^.:]*)" # group 3: test function name
|
||||
r":([^:]*)" # group 4: function being tested
|
||||
r":([^:]*)" # group 5: loop index
|
||||
r":([^#]*)" # group 6: iteration id
|
||||
r"######\$!\n"
|
||||
)
|
||||
matches_re_end = re.compile(
|
||||
r"!######([^:]*)" # group 1: module path
|
||||
r":((?:[^:.]*\.)*)" # group 2: class prefix with trailing dot, or empty
|
||||
r"([^.:]*)" # group 3: test function name
|
||||
r":([^:]*)" # group 4: function being tested
|
||||
r":([^:]*)" # group 5: loop index
|
||||
r":([^#]*)" # group 6: iteration_id or iteration_id:runtime
|
||||
r"######!"
|
||||
)
|
||||
|
||||
|
||||
start_pattern = re.compile(r"!\$######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+)######\$!")
|
||||
|
|
@ -893,7 +909,6 @@ def merge_test_results(
|
|||
return merged_test_results
|
||||
|
||||
|
||||
FAILURES_HEADER_RE = re.compile(r"=+ FAILURES =+")
|
||||
TEST_HEADER_RE = re.compile(r"_{3,}\s*(.*?)\s*_{3,}$")
|
||||
|
||||
|
||||
|
|
@ -903,7 +918,7 @@ def parse_test_failures_from_stdout(stdout: str) -> dict[str, str]:
|
|||
start = end = None
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if FAILURES_HEADER_RE.search(line.strip()):
|
||||
if "= FAILURES =" in line:
|
||||
start = i
|
||||
break
|
||||
|
||||
|
|
|
|||
|
|
@ -13,9 +13,9 @@ from codeflash.cli_cmds.console import logger
|
|||
from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file
|
||||
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
|
||||
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
|
||||
from codeflash.code_utils.coverage_utils import prepare_coverage_files
|
||||
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
|
||||
from codeflash.languages import is_python
|
||||
from codeflash.languages.python.static_analysis.coverage_utils import prepare_coverage_files
|
||||
from codeflash.languages.registry import get_language_support, get_language_support_by_framework
|
||||
from codeflash.models.models import TestFiles, TestType
|
||||
|
||||
|
|
|
|||
|
|
@ -158,6 +158,10 @@ class TestConfig:
|
|||
_language: Optional[str] = None # Language identifier for multi-language support
|
||||
js_project_root: Optional[Path] = None # JavaScript project root (directory containing package.json)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.project_root_path = self.project_root_path.resolve()
|
||||
self.tests_project_rootdir = self.tests_project_rootdir.resolve()
|
||||
|
||||
@property
|
||||
def test_framework(self) -> str:
|
||||
"""Returns the appropriate test framework based on language.
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
# These version placeholders will be replaced by uv-dynamic-versioning during build.
|
||||
__version__ = "0.20.0"
|
||||
__version__ = "0.20.1"
|
||||
|
|
|
|||
|
|
@ -28,8 +28,8 @@ codeflash/code_utils/__init__.py
|
|||
codeflash/code_utils/time_utils.py
|
||||
codeflash/code_utils/env_utils.py
|
||||
codeflash/code_utils/config_consts.py
|
||||
codeflash/code_utils/static_analysis.py
|
||||
codeflash/code_utils/edit_generated_tests.py
|
||||
codeflash/languages/python/static_analysis/static_analysis.py
|
||||
codeflash/languages/python/static_analysis/edit_generated_tests.py
|
||||
codeflash/cli_cmds/console_constants.py
|
||||
codeflash/cli_cmds/logging_config.py
|
||||
codeflash/cli_cmds/__init__.py
|
||||
|
|
|
|||
|
|
@ -39,13 +39,14 @@ dependencies = [
|
|||
"dill>=0.3.8",
|
||||
"rich>=13.8.1",
|
||||
"lxml>=5.3.0",
|
||||
"crosshair-tool>=0.0.78",
|
||||
"crosshair-tool>=0.0.78; python_version < '3.15'",
|
||||
"coverage>=7.6.4",
|
||||
"line_profiler>=4.2.0",
|
||||
"platformdirs>=4.3.7",
|
||||
"pygls>=2.0.0,<3.0.0",
|
||||
"codeflash-benchmark",
|
||||
"filelock",
|
||||
"filelock>=3.20.3; python_version >= '3.10'",
|
||||
"filelock<3.20.3; python_version < '3.10'",
|
||||
"pytest-asyncio>=0.18.0",
|
||||
]
|
||||
|
||||
|
|
@ -94,7 +95,7 @@ tests = [
|
|||
"xarray>=2024.7.0",
|
||||
"eval_type_backport",
|
||||
"numba>=0.60.0",
|
||||
"tensorflow>=2.20.0",
|
||||
"tensorflow>=2.20.0; python_version >= '3.10'",
|
||||
]
|
||||
|
||||
[tool.hatch.build.targets.sdist]
|
||||
|
|
@ -207,6 +208,8 @@ warn_unreachable = true
|
|||
install_types = true
|
||||
plugins = ["pydantic.mypy"]
|
||||
|
||||
exclude = ["tests/", "code_to_optimize/", "pie_test_set/", "experiments/"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["jedi", "jedi.api.classes", "inquirer", "inquirer.themes", "numba"]
|
||||
ignore_missing_imports = true
|
||||
|
|
@ -310,6 +313,9 @@ split-on-trailing-comma = false
|
|||
docstring-code-format = true
|
||||
skip-magic-trailing-comma = true
|
||||
|
||||
[tool.ty.src]
|
||||
exclude = ["tests", "code_to_optimize", "pie_test_set", "experiments"]
|
||||
|
||||
[tool.hatch.version]
|
||||
source = "uv-dynamic-versioning"
|
||||
|
||||
|
|
|
|||
27
tessl.json
27
tessl.json
|
|
@ -20,7 +20,7 @@
|
|||
"version": "0.13.0"
|
||||
},
|
||||
"tessl/pypi-pydantic": {
|
||||
"version": "1.10.0"
|
||||
"version": "2.11.0"
|
||||
},
|
||||
"tessl/pypi-humanize": {
|
||||
"version": "4.13.0"
|
||||
|
|
@ -35,7 +35,7 @@
|
|||
"version": "3.4.0"
|
||||
},
|
||||
"tessl/pypi-sentry-sdk": {
|
||||
"version": "1.45.0"
|
||||
"version": "2.36.0"
|
||||
},
|
||||
"tessl/pypi-parameterized": {
|
||||
"version": "0.9.0"
|
||||
|
|
@ -44,10 +44,10 @@
|
|||
"version": "0.4.0"
|
||||
},
|
||||
"tessl/pypi-rich": {
|
||||
"version": "13.9.0"
|
||||
"version": "14.1.0"
|
||||
},
|
||||
"tessl/pypi-lxml": {
|
||||
"version": "5.4.0"
|
||||
"version": "6.0.0"
|
||||
},
|
||||
"tessl/pypi-crosshair-tool": {
|
||||
"version": "0.0.0"
|
||||
|
|
@ -64,17 +64,20 @@
|
|||
"tessl/pypi-filelock": {
|
||||
"version": "3.19.0"
|
||||
},
|
||||
"codeflash/codeflash-rules": {
|
||||
"version": "0.1.0"
|
||||
"tessl/pypi-ipython": {
|
||||
"version": "9.5.0"
|
||||
},
|
||||
"codeflash/codeflash-docs": {
|
||||
"version": "0.1.0"
|
||||
"tessl/pypi-mypy": {
|
||||
"version": "1.17.0"
|
||||
},
|
||||
"codeflash/codeflash-skills": {
|
||||
"version": "0.2.0"
|
||||
"tessl/pypi-ty": {
|
||||
"version": "0.0.0"
|
||||
},
|
||||
"tessl-labs/tessl-skill-eval-scenarios": {
|
||||
"version": "0.0.5"
|
||||
"tessl/pypi-types-jsonschema": {
|
||||
"version": "3.2.0"
|
||||
},
|
||||
"tessl/pypi-uv": {
|
||||
"version": "0.8.0"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.code_utils.concolic_utils import AssertCleanup, is_valid_concolic_test
|
||||
from codeflash.languages.python.static_analysis.concolic_utils import AssertCleanup, is_valid_concolic_test
|
||||
|
||||
|
||||
class TestFirstTopLevelArg:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
|
||||
from typing import Any
|
||||
|
||||
from codeflash.code_utils.coverage_utils import build_fully_qualified_name, extract_dependent_function
|
||||
from codeflash.languages.python.static_analysis.coverage_utils import build_fully_qualified_name, extract_dependent_function
|
||||
from codeflash.models.function_types import FunctionParent
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown
|
||||
from codeflash.verification.coverage_utils import CoverageUtils
|
||||
|
|
|
|||
|
|
@ -3,13 +3,13 @@ from pathlib import Path
|
|||
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.code_utils.code_extractor import (
|
||||
from codeflash.languages.python.static_analysis.code_extractor import (
|
||||
DottedImportCollector,
|
||||
add_needed_imports_from_module,
|
||||
find_preexisting_objects,
|
||||
resolve_star_import,
|
||||
)
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.languages.python.static_analysis.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
|
||||
|
|
@ -22,7 +22,7 @@ import jedi
|
|||
import tiktoken
|
||||
from jedi.api.classes import Name
|
||||
from pydantic.dataclasses import dataclass
|
||||
from codeflash.code_utils.code_extractor import get_code, get_code_no_skeleton
|
||||
from codeflash.languages.python.static_analysis.code_extractor import get_code, get_code_no_skeleton
|
||||
from codeflash.code_utils.code_utils import path_belongs_to_site_packages
|
||||
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
|
||||
|
||||
|
|
@ -76,7 +76,7 @@ import tiktoken
|
|||
from jedi.api.classes import Name
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from codeflash.code_utils.code_extractor import get_code, get_code_no_skeleton
|
||||
from codeflash.languages.python.static_analysis.code_extractor import get_code, get_code_no_skeleton
|
||||
from codeflash.code_utils.code_utils import path_belongs_to_site_packages
|
||||
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from unittest.mock import Mock
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.code_utils.edit_generated_tests import add_runtime_comments_to_generated_tests
|
||||
from codeflash.languages.python.static_analysis.edit_generated_tests import add_runtime_comments_to_generated_tests
|
||||
from codeflash.models.models import (
|
||||
FunctionTestInvocation,
|
||||
GeneratedTests,
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_global_assignments
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.languages.python.static_analysis.code_extractor import GlobalAssignmentCollector, add_global_assignments
|
||||
from codeflash.languages.python.static_analysis.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.python.context.code_context_extractor import (
|
||||
collect_type_names_from_annotation,
|
||||
|
|
@ -768,11 +768,204 @@ class HelperClass:
|
|||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_example_class_token_limit_3(tmp_path: Path) -> None:
|
||||
def test_example_class_token_limit_1(tmp_path: Path) -> None:
|
||||
docstring_filler = " ".join(
|
||||
["This is a long docstring that will be used to fill up the token limit." for _ in range(4000)]
|
||||
)
|
||||
code = f"""
|
||||
class MyClass:
|
||||
\"\"\"A class with a helper method.
|
||||
{docstring_filler}\"\"\"
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def target_method(self):
|
||||
\"\"\"Docstring for target method\"\"\"
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
\"\"\"A helper class for MyClass.\"\"\"
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def __repr__(self):
|
||||
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
||||
return "HelperClass" + str(self.x)
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
"""
|
||||
# Create a temporary Python file using pytest's tmp_path fixture
|
||||
file_path = tmp_path / "test_code.py"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=file_path.parent.resolve(),
|
||||
disable_telemetry=True,
|
||||
tests_root="tests",
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
test_project_root=Path().resolve(),
|
||||
)
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="target_method",
|
||||
file_path=file_path,
|
||||
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
# In this scenario, the read-only code context is too long, so the read-only docstrings are removed.
|
||||
expected_read_write_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def target_method(self):
|
||||
\"\"\"Docstring for target method\"\"\"
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
```
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
pass
|
||||
|
||||
class HelperClass:
|
||||
def __repr__(self):
|
||||
return "HelperClass" + str(self.x)
|
||||
```
|
||||
"""
|
||||
expected_hashing_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
|
||||
def target_method(self):
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
```
|
||||
"""
|
||||
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_example_class_token_limit_2(tmp_path: Path) -> None:
|
||||
string_filler = " ".join(
|
||||
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
|
||||
)
|
||||
code = f"""
|
||||
class MyClass:
|
||||
\"\"\"A class with a helper method. \"\"\"
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def target_method(self):
|
||||
\"\"\"Docstring for target method\"\"\"
|
||||
y = HelperClass().helper_method()
|
||||
x = '{string_filler}'
|
||||
|
||||
class HelperClass:
|
||||
\"\"\"A helper class for MyClass.\"\"\"
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def __repr__(self):
|
||||
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
||||
return "HelperClass" + str(self.x)
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
"""
|
||||
# Create a temporary Python file using pytest's tmp_path fixture
|
||||
file_path = tmp_path / "test_code.py"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=file_path.parent.resolve(),
|
||||
disable_telemetry=True,
|
||||
tests_root="tests",
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
test_project_root=Path().resolve(),
|
||||
)
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="target_method",
|
||||
file_path=file_path,
|
||||
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
# In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
|
||||
expected_read_write_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def target_method(self):
|
||||
\"\"\"Docstring for target method\"\"\"
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
```
|
||||
"""
|
||||
expected_read_only_context = f'''```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
"""A class with a helper method. """
|
||||
|
||||
class HelperClass:
|
||||
"""A helper class for MyClass."""
|
||||
def __repr__(self):
|
||||
"""Return a string representation of the HelperClass."""
|
||||
return "HelperClass" + str(self.x)
|
||||
```
|
||||
'''
|
||||
expected_hashing_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
|
||||
def target_method(self):
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
```
|
||||
"""
|
||||
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_example_class_token_limit_3(tmp_path: Path) -> None:
|
||||
string_filler = " ".join(
|
||||
["This is a long string that will be used to fill up the token limit." for _ in range(4000)]
|
||||
)
|
||||
code = f"""
|
||||
class MyClass:
|
||||
\"\"\"A class with a helper method. \"\"\"
|
||||
def __init__(self):
|
||||
|
|
@ -820,7 +1013,7 @@ class HelperClass:
|
|||
|
||||
def test_example_class_token_limit_4(tmp_path: Path) -> None:
|
||||
string_filler = " ".join(
|
||||
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
|
||||
["This is a long string that will be used to fill up the token limit." for _ in range(4000)]
|
||||
)
|
||||
code = f"""
|
||||
class MyClass:
|
||||
|
|
@ -979,7 +1172,12 @@ def test_repo_helper() -> None:
|
|||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
path_to_globals = project_root / "globals.py"
|
||||
expected_read_write_context = f"""
|
||||
```python:{path_to_globals.relative_to(project_root)}
|
||||
# Define a global variable
|
||||
API_URL = "https://api.example.com/data"
|
||||
```
|
||||
```python:{path_to_utils.relative_to(project_root)}
|
||||
import math
|
||||
|
||||
|
|
@ -1072,7 +1270,12 @@ def test_repo_helper_of_helper() -> None:
|
|||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
path_to_globals = project_root / "globals.py"
|
||||
expected_read_write_context = f"""
|
||||
```python:{path_to_globals.relative_to(project_root)}
|
||||
# Define a global variable
|
||||
API_URL = "https://api.example.com/data"
|
||||
```
|
||||
```python:{path_to_utils.relative_to(project_root)}
|
||||
import math
|
||||
from transform_utils import DataTransformer
|
||||
|
|
@ -1799,6 +2002,8 @@ class Calculator:
|
|||
"""
|
||||
expected_read_only_context = """
|
||||
```python:utility_module.py
|
||||
import sys
|
||||
|
||||
DEFAULT_PRECISION = "medium"
|
||||
|
||||
# Try-except block with variable definitions
|
||||
|
|
@ -1809,6 +2014,17 @@ except ImportError:
|
|||
# Used variable in except block
|
||||
CALCULATION_BACKEND = "python"
|
||||
|
||||
# Nested if-else with variable definitions
|
||||
if sys.platform.startswith('win'):
|
||||
# Used variable in outer if
|
||||
SYSTEM_TYPE = "windows"
|
||||
elif sys.platform.startswith('linux'):
|
||||
# Used variable in outer elif
|
||||
SYSTEM_TYPE = "linux"
|
||||
else:
|
||||
# Used variable in outer else
|
||||
SYSTEM_TYPE = "other"
|
||||
|
||||
# Function that will be used in the main code
|
||||
def select_precision(precision, fallback_precision):
|
||||
if precision is None:
|
||||
|
|
@ -2015,6 +2231,8 @@ def get_system_details():
|
|||
relative_path = file_path.relative_to(project_root)
|
||||
expected_read_write_context = f"""
|
||||
```python:utility_module.py
|
||||
import sys
|
||||
|
||||
DEFAULT_PRECISION = "medium"
|
||||
|
||||
# Try-except block with variable definitions
|
||||
|
|
@ -2025,6 +2243,17 @@ except ImportError:
|
|||
# Used variable in except block
|
||||
CALCULATION_BACKEND = "python"
|
||||
|
||||
# Nested if-else with variable definitions
|
||||
if sys.platform.startswith('win'):
|
||||
# Used variable in outer if
|
||||
SYSTEM_TYPE = "windows"
|
||||
elif sys.platform.startswith('linux'):
|
||||
# Used variable in outer elif
|
||||
SYSTEM_TYPE = "linux"
|
||||
else:
|
||||
# Used variable in outer else
|
||||
SYSTEM_TYPE = "other"
|
||||
|
||||
# Function that will be used in the main code
|
||||
def select_precision(precision, fallback_precision):
|
||||
if precision is None:
|
||||
|
|
@ -2065,6 +2294,8 @@ class Calculator:
|
|||
"""
|
||||
expected_read_only_context = """
|
||||
```python:utility_module.py
|
||||
import sys
|
||||
|
||||
DEFAULT_PRECISION = "medium"
|
||||
|
||||
# Try-except block with variable definitions
|
||||
|
|
@ -2074,6 +2305,17 @@ try:
|
|||
except ImportError:
|
||||
# Used variable in except block
|
||||
CALCULATION_BACKEND = "python"
|
||||
|
||||
# Nested if-else with variable definitions
|
||||
if sys.platform.startswith('win'):
|
||||
# Used variable in outer if
|
||||
SYSTEM_TYPE = "windows"
|
||||
elif sys.platform.startswith('linux'):
|
||||
# Used variable in outer elif
|
||||
SYSTEM_TYPE = "linux"
|
||||
else:
|
||||
# Used variable in outer else
|
||||
SYSTEM_TYPE = "other"
|
||||
```
|
||||
"""
|
||||
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
||||
|
|
@ -2629,7 +2871,7 @@ def test_global_function_collector():
|
|||
"""Test GlobalFunctionCollector correctly collects module-level function definitions."""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.code_utils.code_extractor import GlobalFunctionCollector
|
||||
from codeflash.languages.python.static_analysis.code_extractor import GlobalFunctionCollector
|
||||
|
||||
source_code = """
|
||||
# Module-level functions
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
|
||||
from codeflash.languages.python.static_analysis.code_extractor import add_needed_imports_from_module
|
||||
|
||||
|
||||
def test_add_needed_imports_with_none_aliases():
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
|
@ -8,39 +7,23 @@ from pathlib import Path
|
|||
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.code_utils.code_extractor import delete___future___aliased_imports, find_preexisting_objects
|
||||
from codeflash.code_utils.code_replacer import (
|
||||
from codeflash.languages.python.static_analysis.code_extractor import delete___future___aliased_imports, find_preexisting_objects
|
||||
from codeflash.languages.python.static_analysis.code_replacer import (
|
||||
AddRequestArgument,
|
||||
AutouseFixtureModifier,
|
||||
OptimFunctionCollector,
|
||||
PytestMarkAdder,
|
||||
is_zero_diff,
|
||||
replace_functions_and_add_imports,
|
||||
replace_functions_in_file,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent, FunctionSource
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class JediDefinition:
|
||||
type: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FakeFunctionSource:
|
||||
file_path: Path
|
||||
qualified_name: str
|
||||
fully_qualified_name: str
|
||||
only_function_name: str
|
||||
source_code: str
|
||||
jedi_definition: JediDefinition
|
||||
|
||||
|
||||
class Args:
|
||||
disable_imports_sorting = True
|
||||
formatter_cmds = ["disabled"]
|
||||
|
|
@ -1137,7 +1120,7 @@ class TestResults(BaseModel):
|
|||
preexisting_objects = find_preexisting_objects(original_code)
|
||||
|
||||
helper_functions = [
|
||||
FakeFunctionSource(
|
||||
FunctionSource(
|
||||
file_path=Path(
|
||||
"/Users/saurabh/Library/CloudStorage/Dropbox/codeflash/cli/codeflash/verification/test_results.py"
|
||||
),
|
||||
|
|
@ -1145,7 +1128,7 @@ class TestResults(BaseModel):
|
|||
fully_qualified_name="codeflash.verification.test_results.TestType",
|
||||
only_function_name="TestType",
|
||||
source_code="",
|
||||
jedi_definition=JediDefinition(type="class"),
|
||||
definition_type="class",
|
||||
)
|
||||
]
|
||||
|
||||
|
|
@ -1160,7 +1143,7 @@ class TestResults(BaseModel):
|
|||
|
||||
helper_functions_by_module_abspath = defaultdict(set)
|
||||
for helper_function in helper_functions:
|
||||
if helper_function.jedi_definition.type != "class":
|
||||
if helper_function.definition_type != "class":
|
||||
helper_functions_by_module_abspath[helper_function.file_path].add(helper_function.qualified_name)
|
||||
for module_abspath, qualified_names in helper_functions_by_module_abspath.items():
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
|
|
@ -1352,21 +1335,21 @@ def cosine_similarity_top_k(
|
|||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
|
||||
|
||||
helper_functions = [
|
||||
FakeFunctionSource(
|
||||
FunctionSource(
|
||||
file_path=(Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve(),
|
||||
qualified_name="Matrix",
|
||||
fully_qualified_name="code_to_optimize.math_utils.Matrix",
|
||||
only_function_name="Matrix",
|
||||
source_code="",
|
||||
jedi_definition=JediDefinition(type="class"),
|
||||
definition_type="class",
|
||||
),
|
||||
FakeFunctionSource(
|
||||
FunctionSource(
|
||||
file_path=(Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve(),
|
||||
qualified_name="cosine_similarity",
|
||||
fully_qualified_name="code_to_optimize.math_utils.cosine_similarity",
|
||||
only_function_name="cosine_similarity",
|
||||
source_code="",
|
||||
jedi_definition=JediDefinition(type="function"),
|
||||
definition_type="function",
|
||||
),
|
||||
]
|
||||
|
||||
|
|
@ -1425,7 +1408,7 @@ def cosine_similarity_top_k(
|
|||
)
|
||||
helper_functions_by_module_abspath = defaultdict(set)
|
||||
for helper_function in helper_functions:
|
||||
if helper_function.jedi_definition.type != "class":
|
||||
if helper_function.definition_type != "class":
|
||||
helper_functions_by_module_abspath[helper_function.file_path].add(helper_function.qualified_name)
|
||||
for module_abspath, qualified_names in helper_functions_by_module_abspath.items():
|
||||
new_helper_code: str = replace_functions_and_add_imports(
|
||||
|
|
@ -3492,142 +3475,6 @@ def hydrate_input_text_actions_with_field_names(
|
|||
assert new_code == expected
|
||||
|
||||
|
||||
# OptimFunctionCollector async function tests
|
||||
def test_optim_function_collector_with_async_functions():
|
||||
"""Test OptimFunctionCollector correctly collects async functions."""
|
||||
import libcst as cst
|
||||
|
||||
source_code = """
|
||||
def sync_function():
|
||||
return "sync"
|
||||
|
||||
async def async_function():
|
||||
return "async"
|
||||
|
||||
class TestClass:
|
||||
def sync_method(self):
|
||||
return "sync_method"
|
||||
|
||||
async def async_method(self):
|
||||
return "async_method"
|
||||
"""
|
||||
|
||||
tree = cst.parse_module(source_code)
|
||||
collector = OptimFunctionCollector(
|
||||
function_names={
|
||||
(None, "sync_function"),
|
||||
(None, "async_function"),
|
||||
("TestClass", "sync_method"),
|
||||
("TestClass", "async_method"),
|
||||
},
|
||||
preexisting_objects=None,
|
||||
)
|
||||
tree.visit(collector)
|
||||
|
||||
# Should collect both sync and async functions
|
||||
assert len(collector.modified_functions) == 4
|
||||
assert (None, "sync_function") in collector.modified_functions
|
||||
assert (None, "async_function") in collector.modified_functions
|
||||
assert ("TestClass", "sync_method") in collector.modified_functions
|
||||
assert ("TestClass", "async_method") in collector.modified_functions
|
||||
|
||||
|
||||
def test_optim_function_collector_new_async_functions():
|
||||
"""Test OptimFunctionCollector identifies new async functions not in preexisting objects."""
|
||||
import libcst as cst
|
||||
|
||||
source_code = """
|
||||
def existing_function():
|
||||
return "existing"
|
||||
|
||||
async def new_async_function():
|
||||
return "new_async"
|
||||
|
||||
def new_sync_function():
|
||||
return "new_sync"
|
||||
|
||||
class ExistingClass:
|
||||
async def new_class_async_method(self):
|
||||
return "new_class_async"
|
||||
"""
|
||||
|
||||
# Only existing_function is in preexisting objects
|
||||
preexisting_objects = {("existing_function", ())}
|
||||
|
||||
tree = cst.parse_module(source_code)
|
||||
collector = OptimFunctionCollector(
|
||||
function_names=set(), # Not looking for specific functions
|
||||
preexisting_objects=preexisting_objects,
|
||||
)
|
||||
tree.visit(collector)
|
||||
|
||||
# Should identify new functions (both sync and async)
|
||||
assert len(collector.new_functions) == 2
|
||||
function_names = [func.name.value for func in collector.new_functions]
|
||||
assert "new_async_function" in function_names
|
||||
assert "new_sync_function" in function_names
|
||||
|
||||
# Should identify new class methods
|
||||
assert "ExistingClass" in collector.new_class_functions
|
||||
assert len(collector.new_class_functions["ExistingClass"]) == 1
|
||||
assert collector.new_class_functions["ExistingClass"][0].name.value == "new_class_async_method"
|
||||
|
||||
|
||||
def test_optim_function_collector_mixed_scenarios():
|
||||
"""Test OptimFunctionCollector with complex mix of sync/async functions and classes."""
|
||||
import libcst as cst
|
||||
|
||||
source_code = """
|
||||
# Global functions
|
||||
def global_sync():
|
||||
pass
|
||||
|
||||
async def global_async():
|
||||
pass
|
||||
|
||||
class ParentClass:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def sync_method(self):
|
||||
pass
|
||||
|
||||
async def async_method(self):
|
||||
pass
|
||||
|
||||
class ChildClass:
|
||||
async def child_async_method(self):
|
||||
pass
|
||||
|
||||
def child_sync_method(self):
|
||||
pass
|
||||
"""
|
||||
|
||||
# Looking for specific functions
|
||||
function_names = {
|
||||
(None, "global_sync"),
|
||||
(None, "global_async"),
|
||||
("ParentClass", "sync_method"),
|
||||
("ParentClass", "async_method"),
|
||||
("ChildClass", "child_async_method"),
|
||||
}
|
||||
|
||||
tree = cst.parse_module(source_code)
|
||||
collector = OptimFunctionCollector(function_names=function_names, preexisting_objects=None)
|
||||
tree.visit(collector)
|
||||
|
||||
# Should collect all specified functions (mix of sync and async)
|
||||
assert len(collector.modified_functions) == 5
|
||||
assert (None, "global_sync") in collector.modified_functions
|
||||
assert (None, "global_async") in collector.modified_functions
|
||||
assert ("ParentClass", "sync_method") in collector.modified_functions
|
||||
assert ("ParentClass", "async_method") in collector.modified_functions
|
||||
assert ("ChildClass", "child_async_method") in collector.modified_functions
|
||||
|
||||
# Should collect __init__ method
|
||||
assert "ParentClass" in collector.modified_init_functions
|
||||
|
||||
|
||||
def test_is_zero_diff_async_sleep():
|
||||
original_code = """
|
||||
import time
|
||||
|
|
|
|||
|
|
@ -19,8 +19,8 @@ from codeflash.code_utils.code_utils import (
|
|||
path_belongs_to_site_packages,
|
||||
validate_python_code,
|
||||
)
|
||||
from codeflash.code_utils.concolic_utils import clean_concolic_tests
|
||||
from codeflash.code_utils.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files
|
||||
from codeflash.languages.python.static_analysis.concolic_utils import clean_concolic_tests
|
||||
from codeflash.languages.python.static_analysis.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files
|
||||
from codeflash.models.models import CodeStringsMarkdown
|
||||
from codeflash.verification.parse_test_output import resolve_test_file_from_class_path
|
||||
|
||||
|
|
@ -36,7 +36,7 @@ def multiple_existing_and_non_existing_files(tmp_path: Path) -> list[Path]:
|
|||
|
||||
@pytest.fixture
|
||||
def mock_get_run_tmp_file() -> Generator[MagicMock, None, None]:
|
||||
with patch("codeflash.code_utils.coverage_utils.get_run_tmp_file") as mock:
|
||||
with patch("codeflash.languages.python.static_analysis.coverage_utils.get_run_tmp_file") as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -151,10 +151,9 @@ def test_class_method_dependencies() -> None:
|
|||
# The code_context above should have the topologicalSortUtil function in it
|
||||
assert len(code_context.helper_functions) == 1
|
||||
assert (
|
||||
code_context.helper_functions[0].jedi_definition.full_name
|
||||
== "test_function_dependencies.Graph.topologicalSortUtil"
|
||||
code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.Graph.topologicalSortUtil"
|
||||
)
|
||||
assert code_context.helper_functions[0].jedi_definition.name == "topologicalSortUtil"
|
||||
assert code_context.helper_functions[0].only_function_name == "topologicalSortUtil"
|
||||
assert (
|
||||
code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.Graph.topologicalSortUtil"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.code_utils.code_extractor import get_code
|
||||
from codeflash.languages.python.static_analysis.code_extractor import get_code
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import os
|
|||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
from codeflash.code_utils.line_profile_utils import add_decorator_imports, contains_jit_decorator
|
||||
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from codeflash.code_utils.instrument_existing_tests import (
|
|||
FunctionImportedAsVisitor,
|
||||
inject_profiling_into_existing_test,
|
||||
)
|
||||
from codeflash.code_utils.line_profile_utils import add_decorator_imports
|
||||
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import (
|
||||
CodeOptimizationContext,
|
||||
|
|
|
|||
|
|
@ -2,10 +2,10 @@
|
|||
|
||||
from unittest.mock import patch
|
||||
|
||||
from codeflash.code_utils.code_extractor import is_numerical_code
|
||||
from codeflash.languages.python.static_analysis.code_extractor import is_numerical_code
|
||||
|
||||
|
||||
@patch("codeflash.code_utils.code_extractor.has_numba", True)
|
||||
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
|
||||
class TestBasicNumpyUsage:
|
||||
"""Test basic numpy library detection (with numba available)."""
|
||||
|
||||
|
|
@ -50,7 +50,7 @@ def func(x):
|
|||
assert is_numerical_code(code, "func") is True
|
||||
|
||||
|
||||
@patch("codeflash.code_utils.code_extractor.has_numba", True)
|
||||
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
|
||||
class TestNumpySubmodules:
|
||||
"""Test numpy submodule imports (with numba available)."""
|
||||
|
||||
|
|
@ -265,7 +265,7 @@ def func(x):
|
|||
assert is_numerical_code(code, "func") is True
|
||||
|
||||
|
||||
@patch("codeflash.code_utils.code_extractor.has_numba", True)
|
||||
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
|
||||
class TestScipyUsage:
|
||||
"""Test SciPy library detection (with numba available)."""
|
||||
|
||||
|
|
@ -302,7 +302,7 @@ def func(f, x0):
|
|||
assert is_numerical_code(code, "func") is True
|
||||
|
||||
|
||||
@patch("codeflash.code_utils.code_extractor.has_numba", True)
|
||||
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
|
||||
class TestMathUsage:
|
||||
"""Test math standard library detection (with numba available)."""
|
||||
|
||||
|
|
@ -331,7 +331,7 @@ def calculate(x):
|
|||
assert is_numerical_code(code, "calculate") is True
|
||||
|
||||
|
||||
@patch("codeflash.code_utils.code_extractor.has_numba", True)
|
||||
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
|
||||
class TestClassMethods:
|
||||
"""Test detection in class methods, staticmethods, and classmethods (with numba available)."""
|
||||
|
||||
|
|
@ -472,7 +472,7 @@ def func():
|
|||
assert is_numerical_code(code, "func") is False
|
||||
|
||||
|
||||
@patch("codeflash.code_utils.code_extractor.has_numba", True)
|
||||
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases and special scenarios (with numba available)."""
|
||||
|
||||
|
|
@ -535,7 +535,7 @@ async def async_process(x):
|
|||
assert is_numerical_code(code, "async_process") is False
|
||||
|
||||
|
||||
@patch("codeflash.code_utils.code_extractor.has_numba", True)
|
||||
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
|
||||
class TestStarImports:
|
||||
"""Test handling of star imports (with numba available).
|
||||
|
||||
|
|
@ -575,7 +575,7 @@ def func(x):
|
|||
assert is_numerical_code(code, "func") is False
|
||||
|
||||
|
||||
@patch("codeflash.code_utils.code_extractor.has_numba", True)
|
||||
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
|
||||
class TestNestedUsage:
|
||||
"""Test nested numerical library usage patterns (with numba available)."""
|
||||
|
||||
|
|
@ -618,7 +618,7 @@ def func(x):
|
|||
assert is_numerical_code(code, "func") is True
|
||||
|
||||
|
||||
@patch("codeflash.code_utils.code_extractor.has_numba", True)
|
||||
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
|
||||
class TestMultipleLibraries:
|
||||
"""Test code using multiple numerical libraries (with numba available)."""
|
||||
|
||||
|
|
@ -643,7 +643,7 @@ def analyze(data):
|
|||
assert is_numerical_code(code, "analyze") is True
|
||||
|
||||
|
||||
@patch("codeflash.code_utils.code_extractor.has_numba", True)
|
||||
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
|
||||
class TestQualifiedNames:
|
||||
"""Test various qualified name patterns (with numba available)."""
|
||||
|
||||
|
|
@ -689,7 +689,7 @@ class ClassB:
|
|||
assert is_numerical_code(code, "ClassB.method") is False
|
||||
|
||||
|
||||
@patch("codeflash.code_utils.code_extractor.has_numba", True)
|
||||
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", True)
|
||||
class TestEmptyFunctionName:
|
||||
"""Test behavior when function_name is empty/None.
|
||||
|
||||
|
|
@ -807,7 +807,7 @@ def broken(
|
|||
assert is_numerical_code(code, "") is False
|
||||
|
||||
|
||||
@patch("codeflash.code_utils.code_extractor.has_numba", False)
|
||||
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", False)
|
||||
class TestEmptyFunctionNameWithoutNumba:
|
||||
"""Test empty function_name behavior when numba is NOT available.
|
||||
|
||||
|
|
@ -886,7 +886,7 @@ from scipy import stats
|
|||
assert is_numerical_code(code, "") is False
|
||||
|
||||
|
||||
@patch("codeflash.code_utils.code_extractor.has_numba", False)
|
||||
@patch("codeflash.languages.python.static_analysis.code_extractor.has_numba", False)
|
||||
class TestNumbaNotAvailable:
|
||||
"""Test behavior when numba is NOT available in the environment.
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from codeflash.languages.javascript.find_references import (
|
|||
find_references,
|
||||
)
|
||||
from codeflash.languages.base import Language, FunctionInfo, ReferenceInfo
|
||||
from codeflash.code_utils.code_extractor import _format_references_as_markdown
|
||||
from codeflash.languages.python.static_analysis.code_extractor import _format_references_as_markdown
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.code_utils.code_replacer import replace_function_definitions_for_language
|
||||
from codeflash.languages.python.static_analysis.code_replacer import replace_function_definitions_for_language
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.current import set_current_language
|
||||
from codeflash.languages.javascript.module_system import (
|
||||
|
|
|
|||
189
tests/test_parse_test_output_regex.py
Normal file
189
tests/test_parse_test_output_regex.py
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
"""Tests for the regex patterns and string matching in parse_test_output.py."""
|
||||
|
||||
from codeflash.verification.parse_test_output import (
|
||||
matches_re_end,
|
||||
matches_re_start,
|
||||
parse_test_failures_from_stdout,
|
||||
)
|
||||
|
||||
|
||||
# --- matches_re_start tests ---
|
||||
|
||||
|
||||
class TestMatchesReStart:
|
||||
def test_simple_no_class(self) -> None:
|
||||
s = "!$######tests.test_foo:test_bar:target_func:1:abc######$!\n"
|
||||
m = matches_re_start.search(s)
|
||||
assert m is not None
|
||||
assert m.groups() == ("tests.test_foo", "", "test_bar", "target_func", "1", "abc")
|
||||
|
||||
def test_with_class(self) -> None:
|
||||
s = "!$######tests.test_foo:MyClass.test_bar:target_func:1:abc######$!\n"
|
||||
m = matches_re_start.search(s)
|
||||
assert m is not None
|
||||
assert m.groups() == ("tests.test_foo", "MyClass.", "test_bar", "target_func", "1", "abc")
|
||||
|
||||
def test_nested_class(self) -> None:
|
||||
s = "!$######a.b.c:A.B.test_x:func:3:id123######$!\n"
|
||||
m = matches_re_start.search(s)
|
||||
assert m is not None
|
||||
assert m.groups() == ("a.b.c", "A.B.", "test_x", "func", "3", "id123")
|
||||
|
||||
def test_empty_class_and_function(self) -> None:
|
||||
s = "!$######mod::func:0:iter######$!\n"
|
||||
m = matches_re_start.search(s)
|
||||
assert m is not None
|
||||
assert m.groups() == ("mod", "", "", "func", "0", "iter")
|
||||
|
||||
def test_embedded_in_stdout(self) -> None:
|
||||
s = "some output\n!$######mod:test_fn:f:1:x######$!\nmore output\n"
|
||||
m = matches_re_start.search(s)
|
||||
assert m is not None
|
||||
assert m.groups() == ("mod", "", "test_fn", "f", "1", "x")
|
||||
|
||||
def test_multiple_matches(self) -> None:
|
||||
s = (
|
||||
"!$######m1:C1.fn1:t1:1:a######$!\n"
|
||||
"!$######m2:fn2:t2:2:b######$!\n"
|
||||
)
|
||||
matches = list(matches_re_start.finditer(s))
|
||||
assert len(matches) == 2
|
||||
assert matches[0].groups() == ("m1", "C1.", "fn1", "t1", "1", "a")
|
||||
assert matches[1].groups() == ("m2", "", "fn2", "t2", "2", "b")
|
||||
|
||||
def test_no_match_without_newline(self) -> None:
|
||||
s = "!$######mod:test_fn:f:1:x######$!"
|
||||
m = matches_re_start.search(s)
|
||||
assert m is None
|
||||
|
||||
def test_dots_in_module_path(self) -> None:
|
||||
s = "!$######a.b.c.d.e:test_fn:f:1:x######$!\n"
|
||||
m = matches_re_start.search(s)
|
||||
assert m is not None
|
||||
assert m.group(1) == "a.b.c.d.e"
|
||||
|
||||
|
||||
# --- matches_re_end tests ---
|
||||
|
||||
|
||||
class TestMatchesReEnd:
|
||||
def test_simple_no_class_with_runtime(self) -> None:
|
||||
s = "!######tests.test_foo:test_bar:target_func:1:abc:12345######!"
|
||||
m = matches_re_end.search(s)
|
||||
assert m is not None
|
||||
assert m.groups() == ("tests.test_foo", "", "test_bar", "target_func", "1", "abc:12345")
|
||||
|
||||
def test_with_class_no_runtime(self) -> None:
|
||||
s = "!######tests.test_foo:MyClass.test_bar:target_func:1:abc######!"
|
||||
m = matches_re_end.search(s)
|
||||
assert m is not None
|
||||
assert m.groups() == ("tests.test_foo", "MyClass.", "test_bar", "target_func", "1", "abc")
|
||||
|
||||
def test_nested_class_with_runtime(self) -> None:
|
||||
s = "!######mod:A.B.test_x:func:3:id123:99999######!"
|
||||
m = matches_re_end.search(s)
|
||||
assert m is not None
|
||||
assert m.groups() == ("mod", "A.B.", "test_x", "func", "3", "id123:99999")
|
||||
|
||||
def test_runtime_colon_preserved_in_group6(self) -> None:
|
||||
"""Group 6 must capture 'iteration_id:runtime' as a single string (colon included)."""
|
||||
s = "!######m:fn:f:1:iter42:98765######!"
|
||||
m = matches_re_end.search(s)
|
||||
assert m is not None
|
||||
assert m.group(6) == "iter42:98765"
|
||||
|
||||
def test_embedded_in_stdout(self) -> None:
|
||||
s = "captured output\n!######mod:test_fn:f:1:x:500######!\nmore"
|
||||
m = matches_re_end.search(s)
|
||||
assert m is not None
|
||||
assert m.groups() == ("mod", "", "test_fn", "f", "1", "x:500")
|
||||
|
||||
|
||||
# --- Start/End pairing (simulates parse_test_xml matching logic) ---
|
||||
|
||||
|
||||
class TestStartEndPairing:
|
||||
def test_paired_markers(self) -> None:
|
||||
stdout = (
|
||||
"!$######mod:Class.test_fn:func:1:iter1######$!\n"
|
||||
"test output here\n"
|
||||
"!######mod:Class.test_fn:func:1:iter1:54321######!"
|
||||
)
|
||||
starts = list(matches_re_start.finditer(stdout))
|
||||
ends = {}
|
||||
for match in matches_re_end.finditer(stdout):
|
||||
groups = match.groups()
|
||||
g5 = groups[5]
|
||||
colon_pos = g5.find(":")
|
||||
if colon_pos != -1:
|
||||
key = groups[:5] + (g5[:colon_pos],)
|
||||
else:
|
||||
key = groups
|
||||
ends[key] = match
|
||||
|
||||
assert len(starts) == 1
|
||||
assert len(ends) == 1
|
||||
# Start and end should pair on the first 5 groups + iteration_id
|
||||
start_groups = starts[0].groups()
|
||||
assert start_groups in ends
|
||||
|
||||
|
||||
# --- parse_test_failures_from_stdout tests ---
|
||||
|
||||
|
||||
class TestParseTestFailuresHeader:
|
||||
def test_standard_pytest_header(self) -> None:
|
||||
stdout = (
|
||||
"..F.\n"
|
||||
"=================================== FAILURES ===================================\n"
|
||||
"_______ test_foo _______\n"
|
||||
"\n"
|
||||
" def test_foo():\n"
|
||||
"> assert False\n"
|
||||
"E AssertionError\n"
|
||||
"\n"
|
||||
"test.py:3: AssertionError\n"
|
||||
"=========================== short test summary info ============================\n"
|
||||
"FAILED test.py::test_foo\n"
|
||||
)
|
||||
result = parse_test_failures_from_stdout(stdout)
|
||||
assert "test_foo" in result
|
||||
|
||||
def test_minimal_equals(self) -> None:
|
||||
"""Even a short '= FAILURES =' header should be detected."""
|
||||
stdout = (
|
||||
"= FAILURES =\n"
|
||||
"_______ test_bar _______\n"
|
||||
"\n"
|
||||
" assert False\n"
|
||||
"\n"
|
||||
"test.py:1: AssertionError\n"
|
||||
"= short test summary info =\n"
|
||||
)
|
||||
result = parse_test_failures_from_stdout(stdout)
|
||||
assert "test_bar" in result
|
||||
|
||||
def test_no_failures_section(self) -> None:
|
||||
stdout = "....\n4 passed in 0.1s\n"
|
||||
result = parse_test_failures_from_stdout(stdout)
|
||||
assert result == {}
|
||||
|
||||
def test_word_failures_without_equals_is_not_matched(self) -> None:
|
||||
"""'FAILURES' without surrounding '=' signs should not trigger the header detection."""
|
||||
stdout = (
|
||||
"FAILURES detected in module\n"
|
||||
"_______ test_baz _______\n"
|
||||
"\n"
|
||||
" assert False\n"
|
||||
)
|
||||
result = parse_test_failures_from_stdout(stdout)
|
||||
assert result == {}
|
||||
|
||||
def test_failures_in_test_output_not_matched(self) -> None:
|
||||
"""A test printing 'FAILURES' (no = signs) should not trigger header detection."""
|
||||
stdout = (
|
||||
"Testing FAILURES handling\n"
|
||||
"All good\n"
|
||||
)
|
||||
result = parse_test_failures_from_stdout(stdout)
|
||||
assert result == {}
|
||||
475
tests/test_reference_graph.py
Normal file
475
tests/test_reference_graph.py
Normal file
|
|
@ -0,0 +1,475 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.languages.base import IndexResult
|
||||
from codeflash.languages.python.reference_graph import ReferenceGraph
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project(tmp_path: Path) -> Path:
|
||||
project_root = tmp_path / "project"
|
||||
project_root.mkdir()
|
||||
return project_root
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_path(tmp_path: Path) -> Path:
|
||||
return tmp_path / "cache.db"
|
||||
|
||||
|
||||
def write_file(project: Path, name: str, content: str) -> Path:
|
||||
fp = project / name
|
||||
fp.write_text(content, encoding="utf-8")
|
||||
return fp
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_simple_function_call(project: Path, db_path: Path) -> None:
|
||||
write_file(
|
||||
project,
|
||||
"mod.py",
|
||||
"""\
|
||||
def helper():
|
||||
return 1
|
||||
|
||||
def caller():
|
||||
return helper()
|
||||
""",
|
||||
)
|
||||
cg = ReferenceGraph(project, db_path=db_path)
|
||||
try:
|
||||
_, result_list = cg.get_callees({project / "mod.py": {"caller"}})
|
||||
callee_qns = {fs.qualified_name for fs in result_list}
|
||||
assert "helper" in callee_qns
|
||||
finally:
|
||||
cg.close()
|
||||
|
||||
|
||||
def test_cross_file_call(project: Path, db_path: Path) -> None:
|
||||
write_file(
|
||||
project,
|
||||
"utils.py",
|
||||
"""\
|
||||
def utility():
|
||||
return 42
|
||||
""",
|
||||
)
|
||||
write_file(
|
||||
project,
|
||||
"main.py",
|
||||
"""\
|
||||
from utils import utility
|
||||
|
||||
def caller():
|
||||
return utility()
|
||||
""",
|
||||
)
|
||||
cg = ReferenceGraph(project, db_path=db_path)
|
||||
try:
|
||||
_, result_list = cg.get_callees({project / "main.py": {"caller"}})
|
||||
callee_qns = {fs.qualified_name for fs in result_list}
|
||||
assert "utility" in callee_qns
|
||||
# Should be in the utils.py file
|
||||
callee_files = {fs.file_path.resolve() for fs in result_list if fs.qualified_name == "utility"}
|
||||
assert (project / "utils.py").resolve() in callee_files
|
||||
finally:
|
||||
cg.close()
|
||||
|
||||
|
||||
def test_class_instantiation(project: Path, db_path: Path) -> None:
|
||||
write_file(
|
||||
project,
|
||||
"mod.py",
|
||||
"""\
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def caller():
|
||||
obj = MyClass()
|
||||
return obj
|
||||
""",
|
||||
)
|
||||
cg = ReferenceGraph(project, db_path=db_path)
|
||||
try:
|
||||
_, result_list = cg.get_callees({project / "mod.py": {"caller"}})
|
||||
callee_types = {fs.definition_type for fs in result_list}
|
||||
assert "class" in callee_types
|
||||
finally:
|
||||
cg.close()
|
||||
|
||||
|
||||
def test_nested_function_excluded(project: Path, db_path: Path) -> None:
|
||||
write_file(
|
||||
project,
|
||||
"mod.py",
|
||||
"""\
|
||||
def caller():
|
||||
def inner():
|
||||
return 1
|
||||
return inner()
|
||||
""",
|
||||
)
|
||||
cg = ReferenceGraph(project, db_path=db_path)
|
||||
try:
|
||||
_, result_list = cg.get_callees({project / "mod.py": {"caller"}})
|
||||
assert len(result_list) == 0
|
||||
finally:
|
||||
cg.close()
|
||||
|
||||
|
||||
def test_module_level_not_tracked(project: Path, db_path: Path) -> None:
|
||||
write_file(
|
||||
project,
|
||||
"mod.py",
|
||||
"""\
|
||||
def helper():
|
||||
return 1
|
||||
|
||||
x = helper()
|
||||
""",
|
||||
)
|
||||
cg = ReferenceGraph(project, db_path=db_path)
|
||||
try:
|
||||
# Module level calls have no enclosing function, so no edges
|
||||
_, result_list = cg.get_callees({project / "mod.py": {"helper"}})
|
||||
# helper itself doesn't call anything
|
||||
assert len(result_list) == 0
|
||||
finally:
|
||||
cg.close()
|
||||
|
||||
|
||||
def test_site_packages_excluded(project: Path, db_path: Path) -> None:
|
||||
write_file(
|
||||
project,
|
||||
"mod.py",
|
||||
"""\
|
||||
import os
|
||||
|
||||
def caller():
|
||||
return os.path.join("a", "b")
|
||||
""",
|
||||
)
|
||||
cg = ReferenceGraph(project, db_path=db_path)
|
||||
try:
|
||||
_, result_list = cg.get_callees({project / "mod.py": {"caller"}})
|
||||
# os.path.join is stdlib, should not appear
|
||||
assert len(result_list) == 0
|
||||
finally:
|
||||
cg.close()
|
||||
|
||||
|
||||
def test_empty_file(project: Path, db_path: Path) -> None:
|
||||
write_file(project, "mod.py", "")
|
||||
cg = ReferenceGraph(project, db_path=db_path)
|
||||
try:
|
||||
_, result_list = cg.get_callees({project / "mod.py": set()})
|
||||
assert len(result_list) == 0
|
||||
finally:
|
||||
cg.close()
|
||||
|
||||
|
||||
def test_syntax_error_file(project: Path, db_path: Path) -> None:
|
||||
write_file(project, "mod.py", "def broken(\n")
|
||||
cg = ReferenceGraph(project, db_path=db_path)
|
||||
try:
|
||||
_, result_list = cg.get_callees({project / "mod.py": {"broken"}})
|
||||
assert len(result_list) == 0
|
||||
finally:
|
||||
cg.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Caching tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_caching_no_reindex(project: Path, db_path: Path) -> None:
|
||||
write_file(
|
||||
project,
|
||||
"mod.py",
|
||||
"""\
|
||||
def helper():
|
||||
return 1
|
||||
|
||||
def caller():
|
||||
return helper()
|
||||
""",
|
||||
)
|
||||
cg = ReferenceGraph(project, db_path=db_path)
|
||||
try:
|
||||
cg.get_callees({project / "mod.py": {"caller"}})
|
||||
# Second call should use in-memory cache (hash unchanged)
|
||||
resolved = str((project / "mod.py").resolve())
|
||||
assert resolved in cg.indexed_file_hashes
|
||||
old_hash = cg.indexed_file_hashes[resolved]
|
||||
cg.get_callees({project / "mod.py": {"caller"}})
|
||||
assert cg.indexed_file_hashes[resolved] == old_hash
|
||||
finally:
|
||||
cg.close()
|
||||
|
||||
|
||||
def test_incremental_update_on_change(project: Path, db_path: Path) -> None:
|
||||
fp = write_file(
|
||||
project,
|
||||
"mod.py",
|
||||
"""\
|
||||
def helper():
|
||||
return 1
|
||||
|
||||
def caller():
|
||||
return helper()
|
||||
""",
|
||||
)
|
||||
cg = ReferenceGraph(project, db_path=db_path)
|
||||
try:
|
||||
_, result_list = cg.get_callees({project / "mod.py": {"caller"}})
|
||||
assert any(fs.qualified_name == "helper" for fs in result_list)
|
||||
|
||||
# Modify the file — caller no longer calls helper
|
||||
fp.write_text(
|
||||
"""\
|
||||
def helper():
|
||||
return 1
|
||||
|
||||
def new_helper():
|
||||
return 2
|
||||
|
||||
def caller():
|
||||
return new_helper()
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
_, result_list = cg.get_callees({project / "mod.py": {"caller"}})
|
||||
callee_qns = {fs.qualified_name for fs in result_list}
|
||||
assert "new_helper" in callee_qns
|
||||
finally:
|
||||
cg.close()
|
||||
|
||||
|
||||
def test_persistence_across_sessions(project: Path, db_path: Path) -> None:
|
||||
write_file(
|
||||
project,
|
||||
"mod.py",
|
||||
"""\
|
||||
def helper():
|
||||
return 1
|
||||
|
||||
def caller():
|
||||
return helper()
|
||||
""",
|
||||
)
|
||||
# First session: index the file
|
||||
cg1 = ReferenceGraph(project, db_path=db_path)
|
||||
try:
|
||||
_, result_list = cg1.get_callees({project / "mod.py": {"caller"}})
|
||||
assert any(fs.qualified_name == "helper" for fs in result_list)
|
||||
finally:
|
||||
cg1.close()
|
||||
|
||||
# Second session: should read from DB without re-indexing
|
||||
cg2 = ReferenceGraph(project, db_path=db_path)
|
||||
try:
|
||||
assert len(cg2.indexed_file_hashes) == 0 # in-memory cache is empty
|
||||
_, result_list = cg2.get_callees({project / "mod.py": {"caller"}})
|
||||
assert any(fs.qualified_name == "helper" for fs in result_list)
|
||||
finally:
|
||||
cg2.close()
|
||||
|
||||
|
||||
def test_build_index_with_progress(project: Path, db_path: Path) -> None:
|
||||
write_file(
|
||||
project,
|
||||
"a.py",
|
||||
"""\
|
||||
def helper_a():
|
||||
return 1
|
||||
|
||||
def caller_a():
|
||||
return helper_a()
|
||||
""",
|
||||
)
|
||||
write_file(
|
||||
project,
|
||||
"b.py",
|
||||
"""\
|
||||
from a import helper_a
|
||||
|
||||
def caller_b():
|
||||
return helper_a()
|
||||
""",
|
||||
)
|
||||
|
||||
cg = ReferenceGraph(project, db_path=db_path)
|
||||
try:
|
||||
progress_calls: list[IndexResult] = []
|
||||
files = [project / "a.py", project / "b.py"]
|
||||
cg.build_index(files, on_progress=progress_calls.append)
|
||||
|
||||
# Callback fired once per file
|
||||
assert len(progress_calls) == 2
|
||||
|
||||
# Verify IndexResult fields for freshly indexed files
|
||||
for result in progress_calls:
|
||||
assert isinstance(result, IndexResult)
|
||||
assert not result.error
|
||||
assert not result.cached
|
||||
assert result.num_edges > 0
|
||||
assert len(result.edges) == result.num_edges
|
||||
assert result.cross_file_edges >= 0
|
||||
|
||||
# Files are now indexed — get_callees should return correct results
|
||||
_, result_list = cg.get_callees({project / "a.py": {"caller_a"}})
|
||||
callee_qns = {fs.qualified_name for fs in result_list}
|
||||
assert "helper_a" in callee_qns
|
||||
finally:
|
||||
cg.close()
|
||||
|
||||
|
||||
def test_build_index_cached_results(project: Path, db_path: Path) -> None:
|
||||
write_file(
|
||||
project,
|
||||
"a.py",
|
||||
"""\
|
||||
def helper_a():
|
||||
return 1
|
||||
|
||||
def caller_a():
|
||||
return helper_a()
|
||||
""",
|
||||
)
|
||||
write_file(
|
||||
project,
|
||||
"b.py",
|
||||
"""\
|
||||
from a import helper_a
|
||||
|
||||
def caller_b():
|
||||
return helper_a()
|
||||
""",
|
||||
)
|
||||
|
||||
cg = ReferenceGraph(project, db_path=db_path)
|
||||
try:
|
||||
files = [project / "a.py", project / "b.py"]
|
||||
# First pass — fresh indexing
|
||||
cg.build_index(files)
|
||||
|
||||
# Second pass — should all be cached
|
||||
cached_results: list[IndexResult] = []
|
||||
cg.build_index(files, on_progress=cached_results.append)
|
||||
|
||||
assert len(cached_results) == 2
|
||||
for result in cached_results:
|
||||
assert result.cached
|
||||
assert not result.error
|
||||
assert result.num_edges == 0
|
||||
assert result.edges == ()
|
||||
assert result.cross_file_edges == 0
|
||||
finally:
|
||||
cg.close()
|
||||
|
||||
|
||||
def test_cross_file_edges_tracked(project: Path, db_path: Path) -> None:
|
||||
write_file(
|
||||
project,
|
||||
"utils.py",
|
||||
"""\
|
||||
def utility():
|
||||
return 42
|
||||
""",
|
||||
)
|
||||
write_file(
|
||||
project,
|
||||
"main.py",
|
||||
"""\
|
||||
from utils import utility
|
||||
|
||||
def caller():
|
||||
return utility()
|
||||
""",
|
||||
)
|
||||
|
||||
cg = ReferenceGraph(project, db_path=db_path)
|
||||
try:
|
||||
progress_calls: list[IndexResult] = []
|
||||
cg.build_index([project / "utils.py", project / "main.py"], on_progress=progress_calls.append)
|
||||
|
||||
# main.py should have cross-file edges (calls into utils.py)
|
||||
main_result = next(r for r in progress_calls if r.file_path.name == "main.py")
|
||||
assert main_result.cross_file_edges > 0
|
||||
# At least one edge tuple should have is_cross_file=True
|
||||
assert any(is_cross_file for _, _, is_cross_file in main_result.edges)
|
||||
finally:
|
||||
cg.close()
|
||||
|
||||
|
||||
def test_count_callees_per_function(project: Path, db_path: Path) -> None:
|
||||
write_file(
|
||||
project,
|
||||
"mod.py",
|
||||
"""\
|
||||
def helper_a():
|
||||
return 1
|
||||
|
||||
def helper_b():
|
||||
return 2
|
||||
|
||||
def caller_one():
|
||||
return helper_a() + helper_b()
|
||||
|
||||
def caller_two():
|
||||
return helper_a()
|
||||
|
||||
def leaf():
|
||||
return 42
|
||||
""",
|
||||
)
|
||||
|
||||
cg = ReferenceGraph(project, db_path=db_path)
|
||||
try:
|
||||
cg.build_index([project / "mod.py"])
|
||||
mod_path = project / "mod.py"
|
||||
counts = cg.count_callees_per_function({mod_path: {"caller_one", "caller_two", "leaf"}})
|
||||
assert counts[(mod_path, "caller_one")] == 2
|
||||
assert counts[(mod_path, "caller_two")] == 1
|
||||
assert counts[(mod_path, "leaf")] == 0
|
||||
finally:
|
||||
cg.close()
|
||||
|
||||
|
||||
def test_same_file_edges_not_cross_file(project: Path, db_path: Path) -> None:
|
||||
write_file(
|
||||
project,
|
||||
"mod.py",
|
||||
"""\
|
||||
def helper():
|
||||
return 1
|
||||
|
||||
def caller():
|
||||
return helper()
|
||||
""",
|
||||
)
|
||||
|
||||
cg = ReferenceGraph(project, db_path=db_path)
|
||||
try:
|
||||
progress_calls: list[IndexResult] = []
|
||||
cg.build_index([project / "mod.py"], on_progress=progress_calls.append)
|
||||
|
||||
assert len(progress_calls) == 1
|
||||
result = progress_calls[0]
|
||||
assert result.cross_file_edges == 0
|
||||
# All edges should have is_cross_file=False
|
||||
assert all(not is_cross_file for _, _, is_cross_file in result.edges)
|
||||
finally:
|
||||
cg.close()
|
||||
|
|
@ -2,7 +2,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.code_utils.edit_generated_tests import remove_functions_from_generated_tests
|
||||
from codeflash.languages.python.static_analysis.edit_generated_tests import remove_functions_from_generated_tests
|
||||
from codeflash.models.models import GeneratedTests, GeneratedTestsList
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.code_utils.static_analysis import (
|
||||
from codeflash.languages.python.static_analysis.static_analysis import (
|
||||
FunctionKind,
|
||||
ImportedInternalModuleAnalysis,
|
||||
analyze_imported_modules,
|
||||
|
|
@ -23,10 +23,10 @@ from pathlib import *
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.code_utils.static_analysis import ImportedInternalModuleAnalysis
|
||||
from codeflash.languages.python.static_analysis.static_analysis import ImportedInternalModuleAnalysis
|
||||
|
||||
def a_function():
|
||||
from codeflash.code_utils.static_analysis import analyze_imported_modules
|
||||
from codeflash.languages.python.static_analysis.static_analysis import analyze_imported_modules
|
||||
from returns.result import Failure, Success
|
||||
pass
|
||||
"""
|
||||
|
|
@ -37,8 +37,8 @@ def a_function():
|
|||
expected_imported_module_analysis = [
|
||||
ImportedInternalModuleAnalysis(
|
||||
name="static_analysis",
|
||||
full_name="codeflash.code_utils.static_analysis",
|
||||
file_path=project_root / Path("codeflash/code_utils/static_analysis.py"),
|
||||
full_name="codeflash.languages.python.static_analysis.static_analysis",
|
||||
file_path=project_root / Path("codeflash/languages/python/static_analysis/static_analysis.py"),
|
||||
),
|
||||
ImportedInternalModuleAnalysis(
|
||||
name="mymodule", full_name="tests.mymodule", file_path=project_root / Path("tests/mymodule.py")
|
||||
|
|
|
|||
|
|
@ -918,7 +918,7 @@ class OuterClass:
|
|||
"only_function_name": "global_helper_1",
|
||||
"fully_qualified_name": "main.global_helper_1",
|
||||
"file_path": main_file,
|
||||
"jedi_definition": type("MockJedi", (), {"type": "function"})(),
|
||||
"definition_type": "function",
|
||||
},
|
||||
)(),
|
||||
type(
|
||||
|
|
@ -929,7 +929,7 @@ class OuterClass:
|
|||
"only_function_name": "global_helper_2",
|
||||
"fully_qualified_name": "main.global_helper_2",
|
||||
"file_path": main_file,
|
||||
"jedi_definition": type("MockJedi", (), {"type": "function"})(),
|
||||
"definition_type": "function",
|
||||
},
|
||||
)(),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,108 +0,0 @@
|
|||
# AI Service
|
||||
|
||||
How codeflash communicates with the AI optimization backend.
|
||||
|
||||
## `AiServiceClient` (`api/aiservice.py`)
|
||||
|
||||
The client connects to the AI service at `https://app.codeflash.ai` (or `http://localhost:8000` when `CODEFLASH_AIS_SERVER=local`).
|
||||
|
||||
Authentication uses Bearer token from `get_codeflash_api_key()`. All requests go through `make_ai_service_request()` which handles JSON serialization via Pydantic encoder.
|
||||
|
||||
Timeout: 90s for production, 300s for local.
|
||||
|
||||
## Endpoints
|
||||
|
||||
### `/ai/optimize` — Generate Candidates
|
||||
|
||||
Method: `optimize_code()`
|
||||
|
||||
Sends source code + dependency context to generate optimization candidates.
|
||||
|
||||
Payload:
|
||||
- `source_code` — The read-writable code (markdown format)
|
||||
- `dependency_code` — Read-only context code
|
||||
- `trace_id` — Unique trace ID for the optimization run
|
||||
- `language` — `"python"`, `"javascript"`, or `"typescript"`
|
||||
- `n_candidates` — Number of candidates to generate (controlled by effort level)
|
||||
- `is_async` — Whether the function is async
|
||||
- `is_numerical_code` — Whether the code is numerical (affects optimization strategy)
|
||||
|
||||
Returns: `list[OptimizedCandidate]` with `source=OptimizedCandidateSource.OPTIMIZE`
|
||||
|
||||
### `/ai/optimize_line_profiler` — Line-Profiler-Guided Candidates
|
||||
|
||||
Method: `optimize_python_code_line_profiler()`
|
||||
|
||||
Like `/optimize` but includes `line_profiler_results` to guide the LLM toward hot lines.
|
||||
|
||||
Returns: candidates with `source=OptimizedCandidateSource.OPTIMIZE_LP`
|
||||
|
||||
### `/ai/refine` — Refine Existing Candidate
|
||||
|
||||
Method: `refine_code()`
|
||||
|
||||
Request type: `AIServiceRefinerRequest`
|
||||
|
||||
Sends an existing candidate with runtime data and line profiler results to generate an improved version.
|
||||
|
||||
Key fields:
|
||||
- `original_source_code` / `optimized_source_code` — Before and after
|
||||
- `original_code_runtime` / `optimized_code_runtime` — Timing data
|
||||
- `speedup` — Current speedup ratio
|
||||
- `original_line_profiler_results` / `optimized_line_profiler_results`
|
||||
|
||||
Returns: candidates with `source=OptimizedCandidateSource.REFINE` and `parent_id` set to the refined candidate's ID
|
||||
|
||||
### `/ai/repair` — Fix Failed Candidate
|
||||
|
||||
Method: `repair_code()`
|
||||
|
||||
Request type: `AIServiceCodeRepairRequest`
|
||||
|
||||
Sends a failed candidate with test diffs showing what went wrong.
|
||||
|
||||
Key fields:
|
||||
- `original_source_code` / `modified_source_code`
|
||||
- `test_diffs: list[TestDiff]` — Each with `scope` (return_value/stdout/did_pass), original vs candidate values, and test source code
|
||||
|
||||
Returns: candidates with `source=OptimizedCandidateSource.REPAIR` and `parent_id` set
|
||||
|
||||
### `/ai/adaptive_optimize` — Multi-Candidate Adaptive
|
||||
|
||||
Method: `adaptive_optimize()`
|
||||
|
||||
Request type: `AIServiceAdaptiveOptimizeRequest`
|
||||
|
||||
Sends multiple previous candidates with their speedups for the LLM to learn from and generate better candidates.
|
||||
|
||||
Key fields:
|
||||
- `candidates: list[AdaptiveOptimizedCandidate]` — Previous candidates with source code, explanation, source type, and speedup
|
||||
|
||||
Returns: candidates with `source=OptimizedCandidateSource.ADAPTIVE`
|
||||
|
||||
### `/ai/rewrite_jit` — JIT Rewrite
|
||||
|
||||
Method: `get_jit_rewritten_code()`
|
||||
|
||||
Rewrites code to use JIT compilation (e.g., Numba).
|
||||
|
||||
Returns: candidates with `source=OptimizedCandidateSource.JIT_REWRITE`
|
||||
|
||||
## Candidate Parsing
|
||||
|
||||
All endpoints return JSON with an `optimizations` array. Each entry has:
|
||||
- `source_code` — Markdown-formatted code blocks
|
||||
- `explanation` — LLM explanation
|
||||
- `optimization_id` — Unique ID
|
||||
- `parent_id` — Optional parent reference
|
||||
- `model` — Which LLM model was used
|
||||
|
||||
`_get_valid_candidates()` parses the markdown code via `CodeStringsMarkdown.parse_markdown_code()` and filters out entries with empty code blocks.
|
||||
|
||||
## `LocalAiServiceClient`
|
||||
|
||||
Used when `CODEFLASH_EXPERIMENT_ID` is set. Mirrors `AiServiceClient` but sends to a separate experimental endpoint for A/B testing optimization strategies.
|
||||
|
||||
## LLM Call Sequencing
|
||||
|
||||
`AiServiceClient` tracks call sequence via `llm_call_counter` (itertools.count). Each request includes a `call_sequence` number, used by the backend to maintain conversation context across multiple calls for the same function.
|
||||
|
|
@ -1,79 +0,0 @@
|
|||
# Configuration
|
||||
|
||||
Key configuration constants, effort levels, and thresholds.
|
||||
|
||||
## Constants (`code_utils/config_consts.py`)
|
||||
|
||||
### Test Execution
|
||||
|
||||
| Constant | Value | Description |
|
||||
|----------|-------|-------------|
|
||||
| `MAX_TEST_RUN_ITERATIONS` | 5 | Maximum test loop iterations |
|
||||
| `INDIVIDUAL_TESTCASE_TIMEOUT` | 15s | Timeout per individual test case |
|
||||
| `MAX_FUNCTION_TEST_SECONDS` | 60s | Max total time for function testing |
|
||||
| `MAX_TEST_FUNCTION_RUNS` | 50 | Max test function executions |
|
||||
| `MAX_CUMULATIVE_TEST_RUNTIME_NANOSECONDS` | 100ms | Max cumulative test runtime |
|
||||
| `TOTAL_LOOPING_TIME` | 10s | Candidate benchmarking budget |
|
||||
| `MIN_TESTCASE_PASSED_THRESHOLD` | 6 | Minimum test cases that must pass |
|
||||
|
||||
### Performance Thresholds
|
||||
|
||||
| Constant | Value | Description |
|
||||
|----------|-------|-------------|
|
||||
| `MIN_IMPROVEMENT_THRESHOLD` | 0.05 (5%) | Minimum speedup to accept a candidate |
|
||||
| `MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD` | 0.10 (10%) | Minimum async throughput improvement |
|
||||
| `MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD` | 0.20 (20%) | Minimum concurrency ratio improvement |
|
||||
| `COVERAGE_THRESHOLD` | 60.0% | Minimum test coverage |
|
||||
|
||||
### Stability Thresholds
|
||||
|
||||
| Constant | Value | Description |
|
||||
|----------|-------|-------------|
|
||||
| `STABILITY_WINDOW_SIZE` | 0.35 | 35% of total iteration window |
|
||||
| `STABILITY_CENTER_TOLERANCE` | 0.0025 | ±0.25% around median |
|
||||
| `STABILITY_SPREAD_TOLERANCE` | 0.0025 | 0.25% window spread |
|
||||
|
||||
### Context Limits
|
||||
|
||||
| Constant | Value | Description |
|
||||
|----------|-------|-------------|
|
||||
| `OPTIMIZATION_CONTEXT_TOKEN_LIMIT` | 16000 | Max tokens for optimization context |
|
||||
| `TESTGEN_CONTEXT_TOKEN_LIMIT` | 16000 | Max tokens for test generation context |
|
||||
| `MAX_CONTEXT_LEN_REVIEW` | 1000 | Max context length for optimization review |
|
||||
|
||||
### Other
|
||||
|
||||
| Constant | Value | Description |
|
||||
|----------|-------|-------------|
|
||||
| `MIN_CORRECT_CANDIDATES` | 2 | Min correct candidates before skipping repair |
|
||||
| `REPEAT_OPTIMIZATION_PROBABILITY` | 0.1 | Probability of re-optimizing a function |
|
||||
| `DEFAULT_IMPORTANCE_THRESHOLD` | 0.001 | Minimum addressable time to consider a function |
|
||||
| `CONCURRENCY_FACTOR` | 10 | Number of concurrent executions for concurrency benchmark |
|
||||
| `REFINED_CANDIDATE_RANKING_WEIGHTS` | (2, 1) | (runtime, diff) weights — runtime 2x more important |
|
||||
|
||||
## Effort Levels
|
||||
|
||||
`EffortLevel` enum: `LOW`, `MEDIUM`, `HIGH`
|
||||
|
||||
Effort controls the number of candidates, repairs, and refinements:
|
||||
|
||||
| Key | LOW | MEDIUM | HIGH |
|
||||
|-----|-----|--------|------|
|
||||
| `N_OPTIMIZER_CANDIDATES` | 3 | 5 | 6 |
|
||||
| `N_OPTIMIZER_LP_CANDIDATES` | 4 | 6 | 7 |
|
||||
| `N_GENERATED_TESTS` | 2 | 2 | 2 |
|
||||
| `MAX_CODE_REPAIRS_PER_TRACE` | 2 | 3 | 5 |
|
||||
| `REPAIR_UNMATCHED_PERCENTAGE_LIMIT` | 0.2 | 0.3 | 0.4 |
|
||||
| `TOP_VALID_CANDIDATES_FOR_REFINEMENT` | 2 | 3 | 4 |
|
||||
| `ADAPTIVE_OPTIMIZATION_THRESHOLD` | 0 | 0 | 2 |
|
||||
| `MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE` | 0 | 0 | 4 |
|
||||
|
||||
Use `get_effort_value(EffortKeys.KEY, effort_level)` to retrieve values.
|
||||
|
||||
## Project Configuration
|
||||
|
||||
Configuration is read from `pyproject.toml` under `[tool.codeflash]`. Key settings are auto-detected by `setup/detector.py`:
|
||||
- `module-root` — Root of the module to optimize
|
||||
- `tests-root` — Root of test files
|
||||
- `test-framework` — pytest, unittest, jest, etc.
|
||||
- `formatter-cmds` — Code formatting commands
|
||||
|
|
@ -1,60 +0,0 @@
|
|||
# Context Extraction
|
||||
|
||||
How codeflash extracts and limits code context for optimization and test generation.
|
||||
|
||||
## Overview
|
||||
|
||||
Context extraction (`context/code_context_extractor.py`) builds a `CodeOptimizationContext` containing all code needed for the LLM to understand and optimize a function, split into:
|
||||
|
||||
- **Read-writable code** (`CodeContextType.READ_WRITABLE`): The function being optimized plus its helper functions — code the LLM is allowed to modify
|
||||
- **Read-only context** (`CodeContextType.READ_ONLY`): Dependency code for reference — imports, type definitions, base classes
|
||||
- **Testgen context** (`CodeContextType.TESTGEN`): Context for test generation, may include imported class definitions and external base class inits
|
||||
- **Hashing context** (`CodeContextType.HASHING`): Used for deduplication of optimization runs
|
||||
|
||||
## Token Limits
|
||||
|
||||
Both optimization and test generation contexts are token-limited:
|
||||
- `OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 16000` tokens
|
||||
- `TESTGEN_CONTEXT_TOKEN_LIMIT = 16000` tokens
|
||||
|
||||
Token counting uses `encoded_tokens_len()` from `code_utils/code_utils.py`. Functions whose context exceeds these limits are skipped.
|
||||
|
||||
## Context Building Process
|
||||
|
||||
### 1. Helper Discovery
|
||||
|
||||
For the target function (`FunctionToOptimize`), the extractor finds:
|
||||
- **Helpers of the function**: Functions/classes in the same file that the target function calls
|
||||
- **Helpers of helpers**: Transitive dependencies of the helper functions
|
||||
|
||||
These are organized as `dict[Path, set[FunctionSource]]` — mapping file paths to the set of helper functions found in each file.
|
||||
|
||||
### 2. Code Extraction
|
||||
|
||||
`extract_code_markdown_context_from_files()` builds `CodeStringsMarkdown` from the helper dictionaries. Each file's relevant code is extracted as a `CodeString` with its file path.
|
||||
|
||||
### 3. Testgen Context Enrichment
|
||||
|
||||
`build_testgen_context()` extends the basic context with:
|
||||
- Imported class definitions (resolved from imports)
|
||||
- External base class `__init__` methods
|
||||
- External class `__init__` methods referenced in the context
|
||||
|
||||
### 4. Unused Definition Removal
|
||||
|
||||
`detect_unused_helper_functions()` and `remove_unused_definitions_by_function_names()` from `context/unused_definition_remover.py` prune definitions that are not transitively reachable from the target function, reducing token usage.
|
||||
|
||||
### 5. Deduplication
|
||||
|
||||
The hashing context (`hashing_code_context`) generates a hash (`hashing_code_context_hash`) used to detect when the same function context has already been optimized in a previous run, avoiding redundant work.
|
||||
|
||||
## Key Functions
|
||||
|
||||
| Function | Location | Purpose |
|
||||
|----------|----------|---------|
|
||||
| `build_testgen_context()` | `context/code_context_extractor.py` | Build enriched testgen context |
|
||||
| `extract_code_markdown_context_from_files()` | `context/code_context_extractor.py` | Convert helper dicts to `CodeStringsMarkdown` |
|
||||
| `detect_unused_helper_functions()` | `context/unused_definition_remover.py` | Find unused definitions |
|
||||
| `remove_unused_definitions_by_function_names()` | `context/unused_definition_remover.py` | Remove unused definitions |
|
||||
| `collect_top_level_defs_with_usages()` | `context/unused_definition_remover.py` | Analyze definition usage |
|
||||
| `encoded_tokens_len()` | `code_utils/code_utils.py` | Count tokens in code |
|
||||
|
|
@ -1,153 +0,0 @@
|
|||
# Domain Types
|
||||
|
||||
Core data types used throughout the codeflash optimization pipeline.
|
||||
|
||||
## Function Representation
|
||||
|
||||
### `FunctionToOptimize` (`models/function_types.py`)
|
||||
|
||||
The canonical dataclass representing a function candidate for optimization. Works across Python, JavaScript, and TypeScript.
|
||||
|
||||
Key fields:
|
||||
- `function_name: str` — The function name
|
||||
- `file_path: Path` — Absolute file path where the function is located
|
||||
- `parents: list[FunctionParent]` — Parent scopes (classes/functions), each with `name` and `type`
|
||||
- `starting_line / ending_line: Optional[int]` — Line range (1-indexed)
|
||||
- `is_async: bool` — Whether the function is async
|
||||
- `is_method: bool` — Whether it belongs to a class
|
||||
- `language: str` — Programming language (default: `"python"`)
|
||||
|
||||
Key properties:
|
||||
- `qualified_name` — Full dotted name including parent classes (e.g., `MyClass.my_method`)
|
||||
- `top_level_parent_name` — Name of outermost parent, or function name if no parents
|
||||
- `class_name` — Immediate parent class name, or `None`
|
||||
|
||||
### `FunctionParent` (`models/function_types.py`)
|
||||
|
||||
Represents a parent scope: `name: str` (e.g., `"MyClass"`) and `type: str` (e.g., `"ClassDef"`).
|
||||
|
||||
### `FunctionSource` (`models/models.py`)
|
||||
|
||||
Represents a resolved function with source code. Used for helper functions in context extraction.
|
||||
|
||||
Fields: `file_path`, `qualified_name`, `fully_qualified_name`, `only_function_name`, `source_code`, `jedi_definition`.
|
||||
|
||||
## Code Representation
|
||||
|
||||
### `CodeString` (`models/models.py`)
|
||||
|
||||
A single code block with validated syntax:
|
||||
- `code: str` — The source code
|
||||
- `file_path: Optional[Path]` — Origin file path
|
||||
- `language: str` — Language for validation (default: `"python"`)
|
||||
|
||||
Validates syntax on construction via `model_validator`.
|
||||
|
||||
### `CodeStringsMarkdown` (`models/models.py`)
|
||||
|
||||
A collection of `CodeString` blocks — the primary format for passing code through the pipeline.
|
||||
|
||||
Key properties:
|
||||
- `.flat` — Combined source code with file-path comment prefixes (e.g., `# file: path/to/file.py`)
|
||||
- `.markdown` — Markdown-formatted with fenced code blocks: `` ```python:filepath\ncode\n``` ``
|
||||
- `.file_to_path()` — Dict mapping file path strings to code
|
||||
|
||||
Static method:
|
||||
- `parse_markdown_code(markdown_code, expected_language)` — Parses markdown code blocks back into `CodeStringsMarkdown`
|
||||
|
||||
## Optimization Context
|
||||
|
||||
### `CodeOptimizationContext` (`models/models.py`)
|
||||
|
||||
Holds all code context needed for optimization:
|
||||
- `read_writable_code: CodeStringsMarkdown` — Code the LLM can modify
|
||||
- `read_only_context_code: str` — Reference-only dependency code
|
||||
- `testgen_context: CodeStringsMarkdown` — Context for test generation
|
||||
- `hashing_code_context: str` / `hashing_code_context_hash: str` — For deduplication
|
||||
- `helper_functions: list[FunctionSource]` — Helper functions in the writable code
|
||||
- `preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]]` — Objects that already exist in the code
|
||||
|
||||
### `CodeContextType` enum (`models/models.py`)
|
||||
|
||||
Defines context categories: `READ_WRITABLE`, `READ_ONLY`, `TESTGEN`, `HASHING`.
|
||||
|
||||
## Candidates
|
||||
|
||||
### `OptimizedCandidate` (`models/models.py`)
|
||||
|
||||
A generated code variant:
|
||||
- `source_code: CodeStringsMarkdown` — The optimized code
|
||||
- `explanation: str` — LLM explanation of the optimization
|
||||
- `optimization_id: str` — Unique identifier
|
||||
- `source: OptimizedCandidateSource` — How it was generated
|
||||
- `parent_id: str | None` — ID of parent candidate (for refinements/repairs)
|
||||
- `model: str | None` — Which LLM model generated it
|
||||
|
||||
### `OptimizedCandidateSource` enum (`models/models.py`)
|
||||
|
||||
How a candidate was generated: `OPTIMIZE`, `OPTIMIZE_LP` (line profiler), `REFINE`, `REPAIR`, `ADAPTIVE`, `JIT_REWRITE`.
|
||||
|
||||
### `CandidateEvaluationContext` (`models/models.py`)
|
||||
|
||||
Tracks state during candidate evaluation:
|
||||
- `speedup_ratios` / `optimized_runtimes` / `is_correct` — Per-candidate results
|
||||
- `ast_code_to_id` — Deduplication map (normalized AST → first seen candidate)
|
||||
- `valid_optimizations` — Candidates that passed all checks
|
||||
|
||||
Key methods: `record_failed_candidate()`, `record_successful_candidate()`, `handle_duplicate_candidate()`, `register_new_candidate()`.
|
||||
|
||||
## Baseline & Results
|
||||
|
||||
### `OriginalCodeBaseline` (`models/models.py`)
|
||||
|
||||
Baseline measurements for the original code:
|
||||
- `behavior_test_results: TestResults` / `benchmarking_test_results: TestResults`
|
||||
- `line_profile_results: dict`
|
||||
- `runtime: int` — Total runtime in nanoseconds
|
||||
- `coverage_results: Optional[CoverageData]`
|
||||
|
||||
### `BestOptimization` (`models/models.py`)
|
||||
|
||||
The winning candidate after evaluation:
|
||||
- `candidate: OptimizedCandidate`
|
||||
- `helper_functions: list[FunctionSource]`
|
||||
- `code_context: CodeOptimizationContext`
|
||||
- `runtime: int`
|
||||
- `winning_behavior_test_results` / `winning_benchmarking_test_results: TestResults`
|
||||
|
||||
## Test Types
|
||||
|
||||
### `TestType` enum (`models/test_type.py`)
|
||||
|
||||
- `EXISTING_UNIT_TEST` (1) — Pre-existing tests from the codebase
|
||||
- `INSPIRED_REGRESSION` (2) — Tests inspired by existing tests
|
||||
- `GENERATED_REGRESSION` (3) — AI-generated regression tests
|
||||
- `REPLAY_TEST` (4) — Tests from recorded benchmark data
|
||||
- `CONCOLIC_COVERAGE_TEST` (5) — Coverage-guided tests
|
||||
- `INIT_STATE_TEST` (6) — Class init state verification
|
||||
|
||||
### `TestFile` / `TestFiles` (`models/models.py`)
|
||||
|
||||
`TestFile` represents a single test file with `instrumented_behavior_file_path`, optional `benchmarking_file_path`, `original_file_path`, `test_type`, and `tests_in_file`.
|
||||
|
||||
`TestFiles` is a collection with lookup methods: `get_by_type()`, `get_by_original_file_path()`, `get_test_type_by_instrumented_file_path()`.
|
||||
|
||||
### `TestResults` (`models/models.py`)
|
||||
|
||||
Collection of `FunctionTestInvocation` results with indexed lookup. Key methods:
|
||||
- `add(invocation)` — Deduplicated insert
|
||||
- `total_passed_runtime()` — Sum of minimum runtimes per test case (nanoseconds)
|
||||
- `number_of_loops()` — Max loop index across all results
|
||||
- `usable_runtime_data_by_test_case()` — Dict of invocation ID → list of runtimes
|
||||
|
||||
## Result Type
|
||||
|
||||
### `Result[L, R]` / `Success` / `Failure` (`either.py`)
|
||||
|
||||
Functional error handling type:
|
||||
- `Success(value)` — Wraps a successful result
|
||||
- `Failure(error)` — Wraps an error
|
||||
- `result.is_successful()` / `result.is_failure()` — Check type
|
||||
- `result.unwrap()` — Get success value (raises if Failure)
|
||||
- `result.failure()` — Get failure value (raises if Success)
|
||||
- `is_successful(result)` — Module-level helper function
|
||||
|
|
@ -1,41 +0,0 @@
|
|||
# Codeflash Internal Documentation
|
||||
|
||||
CodeFlash is an AI-powered Python code optimizer that automatically improves code performance while maintaining correctness. It uses LLMs to generate optimization candidates, verifies correctness through test execution, and benchmarks performance improvements.
|
||||
|
||||
## Pipeline Overview
|
||||
|
||||
```
|
||||
Discovery → Ranking → Context Extraction → Test Gen + Optimization → Baseline → Candidate Evaluation → PR
|
||||
```
|
||||
|
||||
1. **Discovery** (`discovery/`): Find optimizable functions across the codebase using `FunctionVisitor`
|
||||
2. **Ranking** (`benchmarking/function_ranker.py`): Rank functions by addressable time using trace data
|
||||
3. **Context** (`context/`): Extract code dependencies — split into read-writable (modifiable) and read-only (reference)
|
||||
4. **Optimization** (`optimization/`, `api/`): Generate candidates via AI service, runs concurrently with test generation
|
||||
5. **Verification** (`verification/`): Run candidates against tests via custom pytest plugin, compare outputs
|
||||
6. **Benchmarking** (`benchmarking/`): Measure performance, select best candidate by speedup
|
||||
7. **Result** (`result/`, `github/`): Create PR with winning optimization
|
||||
|
||||
## Key Entry Points
|
||||
|
||||
| Task | File |
|
||||
|------|------|
|
||||
| CLI arguments & commands | `cli_cmds/cli.py` |
|
||||
| Optimization orchestration | `optimization/optimizer.py` → `Optimizer.run()` |
|
||||
| Per-function optimization | `optimization/function_optimizer.py` → `FunctionOptimizer` |
|
||||
| Function discovery | `discovery/functions_to_optimize.py` |
|
||||
| Context extraction | `context/code_context_extractor.py` |
|
||||
| Test execution | `verification/test_runner.py`, `verification/pytest_plugin.py` |
|
||||
| Performance ranking | `benchmarking/function_ranker.py` |
|
||||
| Domain types | `models/models.py`, `models/function_types.py` |
|
||||
| AI service | `api/aiservice.py` → `AiServiceClient` |
|
||||
| Configuration | `code_utils/config_consts.py` |
|
||||
|
||||
## Documentation Pages
|
||||
|
||||
- [Domain Types](domain-types.md) — Core data types and their relationships
|
||||
- [Optimization Pipeline](optimization-pipeline.md) — Step-by-step data flow through the pipeline
|
||||
- [Context Extraction](context-extraction.md) — How code context is extracted and token-limited
|
||||
- [Verification](verification.md) — Test execution, pytest plugin, deterministic patches
|
||||
- [AI Service](ai-service.md) — AI service client endpoints and request types
|
||||
- [Configuration](configuration.md) — Config schema, effort levels, thresholds
|
||||
|
|
@ -1,84 +0,0 @@
|
|||
# Optimization Pipeline
|
||||
|
||||
Step-by-step data flow from function discovery to PR creation.
|
||||
|
||||
## 1. Entry Point: `Optimizer.run()` (`optimization/optimizer.py`)
|
||||
|
||||
The `Optimizer` class is initialized with CLI args and creates:
|
||||
- `TestConfig` with test roots, project root, pytest command
|
||||
- `AiServiceClient` for AI service communication
|
||||
- Optional `LocalAiServiceClient` for experiments
|
||||
|
||||
`run()` orchestrates the full pipeline: discovers functions, optionally ranks them, then optimizes each in turn.
|
||||
|
||||
## 2. Function Discovery (`discovery/functions_to_optimize.py`)
|
||||
|
||||
`FunctionVisitor` traverses source files to find optimizable functions, producing `FunctionToOptimize` instances. Filters include:
|
||||
- Skipping functions that are too small or trivial
|
||||
- Skipping previously optimized functions (via `was_function_previously_optimized()`)
|
||||
- Applying user-configured include/exclude patterns
|
||||
|
||||
## 3. Function Ranking (`benchmarking/function_ranker.py`)
|
||||
|
||||
When trace data is available, `FunctionRanker` ranks functions by **addressable time** — the time a function spends that could be optimized (own time + callee time / call count). Functions below `DEFAULT_IMPORTANCE_THRESHOLD=0.001` are skipped.
|
||||
|
||||
## 4. Per-Function Optimization: `FunctionOptimizer` (`optimization/function_optimizer.py`)
|
||||
|
||||
For each function, `FunctionOptimizer.optimize_function()` runs the full optimization loop:
|
||||
|
||||
### 4a. Context Extraction (`context/code_context_extractor.py`)
|
||||
|
||||
Extracts `CodeOptimizationContext` containing:
|
||||
- `read_writable_code` — Code the LLM can modify (the function + helpers)
|
||||
- `read_only_context_code` — Dependency code for reference only
|
||||
- `testgen_context` — Context for test generation (may include imported class definitions)
|
||||
|
||||
Token limits are enforced: `OPTIMIZATION_CONTEXT_TOKEN_LIMIT=16000` and `TESTGEN_CONTEXT_TOKEN_LIMIT=16000`. Functions exceeding these are rejected.
|
||||
|
||||
### 4b. Concurrent Test Generation + LLM Optimization
|
||||
|
||||
These run in parallel using `concurrent.futures`:
|
||||
- **Test generation**: Generates regression tests from the function context
|
||||
- **LLM optimization**: Sends `read_writable_code.markdown` + `read_only_context_code` to the AI service
|
||||
|
||||
The number of candidates depends on effort level (see Configuration docs).
|
||||
|
||||
### 4c. Candidate Evaluation
|
||||
|
||||
For each `OptimizedCandidate`:
|
||||
|
||||
1. **Deduplication**: Normalize code AST and check against `CandidateEvaluationContext.ast_code_to_id`. If duplicate, copy results from previous evaluation.
|
||||
|
||||
2. **Code replacement**: Replace the original function with the candidate using `replace_function_definitions_in_module()`.
|
||||
|
||||
3. **Behavioral testing**: Run instrumented tests in subprocess. The custom pytest plugin applies deterministic patches. Compare return values, stdout, and pass/fail status against the original baseline.
|
||||
|
||||
4. **Benchmarking**: If behavior matches, run performance tests with looping (`TOTAL_LOOPING_TIME=10s`). Calculate speedup ratio.
|
||||
|
||||
5. **Validation**: Candidate must beat `MIN_IMPROVEMENT_THRESHOLD=0.05` (5% speedup) and pass stability checks.
|
||||
|
||||
### 4d. Refinement & Repair
|
||||
|
||||
- **Repair**: If fewer than `MIN_CORRECT_CANDIDATES=2` pass, failed candidates can be repaired via `AIServiceCodeRepairRequest` (sends test diffs to LLM).
|
||||
- **Refinement**: Top valid candidates are refined via `AIServiceRefinerRequest` (sends runtime data, line profiler results).
|
||||
- **Adaptive**: At HIGH effort, additional adaptive optimization rounds via `AIServiceAdaptiveOptimizeRequest`.
|
||||
|
||||
### 4e. Best Candidate Selection
|
||||
|
||||
The winning candidate is selected by:
|
||||
1. Highest speedup ratio
|
||||
2. For tied speedups, shortest diff length from original
|
||||
3. Refinement candidates use weighted ranking: `(2 * runtime_rank + 1 * diff_rank)`
|
||||
|
||||
Result is a `BestOptimization` with the candidate, context, test results, and runtime.
|
||||
|
||||
## 5. PR Creation (`github/`)
|
||||
|
||||
If a winning candidate is found, a PR is created with:
|
||||
- The optimized code diff
|
||||
- Performance benchmark details
|
||||
- Explanation from the LLM
|
||||
|
||||
## Worktree Mode
|
||||
|
||||
When `--worktree` is enabled, optimization runs in an isolated git worktree (`code_utils/git_worktree_utils.py`). This allows parallel optimization without affecting the working tree. Changes are captured as patch files.
|
||||
|
|
@ -1,93 +0,0 @@
|
|||
# Verification
|
||||
|
||||
How codeflash verifies candidate correctness and measures performance.
|
||||
|
||||
## Test Execution Architecture
|
||||
|
||||
Tests are executed in a **subprocess** to isolate the test environment from the main codeflash process. The test runner (`verification/test_runner.py`) invokes pytest (or Jest for JS/TS) with specific plugin configurations.
|
||||
|
||||
### Plugin Blocklists
|
||||
|
||||
- **Behavioral tests**: Block `benchmark`, `codspeed`, `xdist`, `sugar`
|
||||
- **Benchmarking tests**: Block `codspeed`, `cov`, `benchmark`, `profiling`, `xdist`, `sugar`
|
||||
|
||||
These are defined as `BEHAVIORAL_BLOCKLISTED_PLUGINS` and `BENCHMARKING_BLOCKLISTED_PLUGINS` in `verification/test_runner.py`.
|
||||
|
||||
## Custom Pytest Plugin (`verification/pytest_plugin.py`)
|
||||
|
||||
The plugin is loaded into the test subprocess and provides:
|
||||
|
||||
### Deterministic Patches
|
||||
|
||||
`_apply_deterministic_patches()` replaces non-deterministic functions with fixed values to ensure reproducible test output:
|
||||
|
||||
| Module | Function | Fixed Value |
|
||||
|--------|----------|-------------|
|
||||
| `time` | `time()` | `1761717605.108106` |
|
||||
| `time` | `perf_counter()` | Incrementing by 1ms per call |
|
||||
| `datetime` | `datetime.now()` | `2021-01-01 02:05:10 UTC` |
|
||||
| `datetime` | `datetime.utcnow()` | `2021-01-01 02:05:10 UTC` |
|
||||
| `uuid` | `uuid4()` / `uuid1()` | `12345678-1234-5678-9abc-123456789012` |
|
||||
| `random` | `random()` | `0.123456789` (seeded with 42) |
|
||||
| `os` | `urandom(n)` | `b"\x42" * n` |
|
||||
| `numpy.random` | seed | `42` |
|
||||
|
||||
Patches call the original function first to maintain performance characteristics (same call overhead).
|
||||
|
||||
### Timing Markers
|
||||
|
||||
Test results include timing markers in stdout: `!######<id>:<duration_ns>######!`
|
||||
|
||||
The pattern `_TIMING_MARKER_PATTERN` extracts timing data for calculating function utilization fraction.
|
||||
|
||||
### Loop Stability
|
||||
|
||||
Performance benchmarking uses configurable stability thresholds:
|
||||
- `STABILITY_WINDOW_SIZE = 0.35` (35% of total iterations)
|
||||
- `STABILITY_CENTER_TOLERANCE = 0.0025` (±0.25% around median)
|
||||
- `STABILITY_SPREAD_TOLERANCE = 0.0025` (0.25% window spread)
|
||||
|
||||
### Memory Limits (Linux)
|
||||
|
||||
On Linux, the plugin sets `RLIMIT_AS` to 85% of total system memory (RAM + swap) to prevent OOM kills.
|
||||
|
||||
## Test Result Processing
|
||||
|
||||
### `TestResults` (`models/models.py`)
|
||||
|
||||
Collects `FunctionTestInvocation` results with:
|
||||
- Deduplicated insertion via `unique_invocation_loop_id`
|
||||
- `total_passed_runtime()` — Sum of minimum runtimes per test case (nanoseconds)
|
||||
- `number_of_loops()` — Max loop index
|
||||
- `usable_runtime_data_by_test_case()` — Grouped timing data
|
||||
|
||||
### `FunctionTestInvocation`
|
||||
|
||||
Each invocation records:
|
||||
- `loop_index` — Iteration number (starts at 1)
|
||||
- `id: InvocationId` — Fully qualified test identifier
|
||||
- `did_pass: bool` — Pass/fail status
|
||||
- `runtime: Optional[int]` — Time in nanoseconds
|
||||
- `return_value: Optional[object]` — Captured return value
|
||||
- `test_type: TestType` — Which test category
|
||||
|
||||
### Behavioral vs Performance Testing
|
||||
|
||||
1. **Behavioral**: Runs with `TestingMode.BEHAVIOR`. Compares return values and stdout between original and candidate. Any difference = candidate rejected.
|
||||
2. **Performance**: Runs with `TestingMode.PERFORMANCE`. Loops for `TOTAL_LOOPING_TIME=10s` to get stable timing. Calculates speedup ratio.
|
||||
3. **Line Profile**: Runs with `TestingMode.LINE_PROFILE`. Collects per-line timing data for refinement.
|
||||
|
||||
## Test Types
|
||||
|
||||
| TestType | Value | Description |
|
||||
|----------|-------|-------------|
|
||||
| `EXISTING_UNIT_TEST` | 1 | Pre-existing tests from the codebase |
|
||||
| `INSPIRED_REGRESSION` | 2 | Tests inspired by existing tests |
|
||||
| `GENERATED_REGRESSION` | 3 | AI-generated regression tests |
|
||||
| `REPLAY_TEST` | 4 | Tests from recorded benchmark data |
|
||||
| `CONCOLIC_COVERAGE_TEST` | 5 | Coverage-guided tests |
|
||||
| `INIT_STATE_TEST` | 6 | Class init state verification |
|
||||
|
||||
## Coverage
|
||||
|
||||
Coverage is measured via `CoverageData` with a threshold of `COVERAGE_THRESHOLD=60.0%`. Low coverage may affect confidence in the optimization's correctness.
|
||||
|
|
@ -1,118 +0,0 @@
|
|||
{
|
||||
"package_name": "codeflash-docs",
|
||||
"total_capabilities": 16,
|
||||
"capabilities": [
|
||||
{
|
||||
"id": 0,
|
||||
"name": "pipeline-stage-ordering",
|
||||
"description": "Know the correct ordering of codeflash pipeline stages: Discovery → Ranking → Context Extraction → Test Gen + Optimization (concurrent) → Baseline → Candidate Evaluation → PR",
|
||||
"complexity": "basic",
|
||||
"api_elements": ["Optimizer.run()", "FunctionOptimizer.optimize_function()"]
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"name": "function-to-optimize-fields",
|
||||
"description": "Know FunctionToOptimize key fields (function_name, file_path, parents, starting_line/ending_line, is_async, is_method, language) and properties (qualified_name, top_level_parent_name, class_name)",
|
||||
"complexity": "intermediate",
|
||||
"api_elements": ["FunctionToOptimize", "FunctionParent", "models/function_types.py"]
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"name": "code-strings-markdown-format",
|
||||
"description": "Know that code is serialized as markdown fenced blocks with language:filepath syntax (```python:filepath\\ncode\\n```) and parsed via CodeStringsMarkdown.parse_markdown_code()",
|
||||
"complexity": "intermediate",
|
||||
"api_elements": ["CodeStringsMarkdown", "CodeString", ".markdown", ".flat", "parse_markdown_code()"]
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"name": "read-writable-vs-read-only",
|
||||
"description": "Distinguish read_writable_code (LLM can modify) from read_only_context_code (reference only) in CodeOptimizationContext",
|
||||
"complexity": "basic",
|
||||
"api_elements": ["CodeOptimizationContext", "read_writable_code", "read_only_context_code"]
|
||||
},
|
||||
{
|
||||
"id": 4,
|
||||
"name": "candidate-source-types",
|
||||
"description": "Know OptimizedCandidateSource variants: OPTIMIZE, OPTIMIZE_LP, REFINE, REPAIR, ADAPTIVE, JIT_REWRITE and when each is used",
|
||||
"complexity": "intermediate",
|
||||
"api_elements": ["OptimizedCandidateSource", "OptimizedCandidate"]
|
||||
},
|
||||
{
|
||||
"id": 5,
|
||||
"name": "candidate-forest-dag",
|
||||
"description": "Know that candidates form a forest/DAG via parent_id references where refinements and repairs build on previous candidates",
|
||||
"complexity": "intermediate",
|
||||
"api_elements": ["parent_id", "OptimizedCandidate", "CandidateForest"]
|
||||
},
|
||||
{
|
||||
"id": 6,
|
||||
"name": "concurrent-testgen-optimization",
|
||||
"description": "Know that test generation and LLM optimization run concurrently using concurrent.futures, not sequentially",
|
||||
"complexity": "intermediate",
|
||||
"api_elements": ["concurrent.futures", "FunctionOptimizer.optimize_function()"]
|
||||
},
|
||||
{
|
||||
"id": 7,
|
||||
"name": "deterministic-patch-values",
|
||||
"description": "Know the specific fixed values used by deterministic patches: time=1761717605.108106, datetime=2021-01-01 02:05:10 UTC, uuid=12345678-1234-5678-9abc-123456789012, random seeded with 42",
|
||||
"complexity": "advanced",
|
||||
"api_elements": ["_apply_deterministic_patches()", "pytest_plugin.py"]
|
||||
},
|
||||
{
|
||||
"id": 8,
|
||||
"name": "test-type-enum",
|
||||
"description": "Know the 6 TestType variants: EXISTING_UNIT_TEST, INSPIRED_REGRESSION, GENERATED_REGRESSION, REPLAY_TEST, CONCOLIC_COVERAGE_TEST, INIT_STATE_TEST",
|
||||
"complexity": "basic",
|
||||
"api_elements": ["TestType", "models/test_type.py"]
|
||||
},
|
||||
{
|
||||
"id": 9,
|
||||
"name": "ai-service-endpoints",
|
||||
"description": "Know the AI service endpoints: /ai/optimize, /ai/optimize_line_profiler, /ai/refine, /ai/repair, /ai/adaptive_optimize, /ai/rewrite_jit",
|
||||
"complexity": "intermediate",
|
||||
"api_elements": ["AiServiceClient", "api/aiservice.py"]
|
||||
},
|
||||
{
|
||||
"id": 10,
|
||||
"name": "repair-request-structure",
|
||||
"description": "Know that AIServiceCodeRepairRequest includes TestDiff objects with scope (RETURN_VALUE/STDOUT/DID_PASS), original vs candidate values, and test source code",
|
||||
"complexity": "advanced",
|
||||
"api_elements": ["AIServiceCodeRepairRequest", "TestDiff", "TestDiffScope"]
|
||||
},
|
||||
{
|
||||
"id": 11,
|
||||
"name": "effort-level-values",
|
||||
"description": "Know specific effort level values: LOW gets 3 candidates, MEDIUM gets 5, HIGH gets 6 (N_OPTIMIZER_CANDIDATES)",
|
||||
"complexity": "intermediate",
|
||||
"api_elements": ["EffortLevel", "N_OPTIMIZER_CANDIDATES", "EFFORT_VALUES"]
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"name": "context-token-limits",
|
||||
"description": "Know OPTIMIZATION_CONTEXT_TOKEN_LIMIT=16000 and TESTGEN_CONTEXT_TOKEN_LIMIT=16000 and that encoded_tokens_len() is used for counting",
|
||||
"complexity": "basic",
|
||||
"api_elements": ["OPTIMIZATION_CONTEXT_TOKEN_LIMIT", "TESTGEN_CONTEXT_TOKEN_LIMIT", "encoded_tokens_len()"]
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"name": "best-candidate-selection",
|
||||
"description": "Know the selection criteria: highest speedup, then shortest diff for ties, and refinement weighted ranking (2*runtime + 1*diff)",
|
||||
"complexity": "advanced",
|
||||
"api_elements": ["BestOptimization", "REFINED_CANDIDATE_RANKING_WEIGHTS"]
|
||||
},
|
||||
{
|
||||
"id": 14,
|
||||
"name": "plugin-blocklists",
|
||||
"description": "Know behavioral test blocklisted plugins (benchmark, codspeed, xdist, sugar) and benchmarking blocklist (adds cov, profiling)",
|
||||
"complexity": "intermediate",
|
||||
"api_elements": ["BEHAVIORAL_BLOCKLISTED_PLUGINS", "BENCHMARKING_BLOCKLISTED_PLUGINS"]
|
||||
},
|
||||
{
|
||||
"id": 15,
|
||||
"name": "result-type-usage",
|
||||
"description": "Know that Result[L,R] from either.py uses Success(value)/Failure(error) with is_successful() check before unwrap()",
|
||||
"complexity": "basic",
|
||||
"api_elements": ["Result", "Success", "Failure", "is_successful", "either.py"]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -1 +0,0 @@
|
|||
Code serialization format and context splitting
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
{
|
||||
"context": "Tests whether the agent knows the CodeStringsMarkdown serialization format and the distinction between read-writable and read-only code context in the codeflash pipeline.",
|
||||
"type": "weighted_checklist",
|
||||
"checklist": [
|
||||
{
|
||||
"name": "Markdown code block format",
|
||||
"description": "Uses the correct fenced code block format with language:filepath syntax (```python:path/to/file.py) when constructing code for the AI service, NOT plain code blocks without file paths",
|
||||
"max_score": 30
|
||||
},
|
||||
{
|
||||
"name": "Read-writable vs read-only split",
|
||||
"description": "Correctly separates code into read_writable_code (code the LLM can modify) and read_only_context_code (reference-only dependency code), NOT treating all code as modifiable",
|
||||
"max_score": 35
|
||||
},
|
||||
{
|
||||
"name": "parse_markdown_code usage",
|
||||
"description": "Uses CodeStringsMarkdown.parse_markdown_code() to parse AI service responses back into structured code, NOT manual string splitting or regex",
|
||||
"max_score": 35
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -1,35 +0,0 @@
|
|||
# Format Code for AI Service Request
|
||||
|
||||
## Context
|
||||
|
||||
You are working on the codeflash optimization engine. The AI service accepts optimization requests with source code and dependency context. A function `calculate_total` in `analytics/metrics.py` needs to be optimized. It calls a helper `normalize_values` in the same file (both modifiable), and imports `BaseMetric` from `analytics/base.py` (not modifiable, just for reference).
|
||||
|
||||
```python
|
||||
# analytics/metrics.py
|
||||
from analytics.base import BaseMetric
|
||||
|
||||
def normalize_values(data: list[float]) -> list[float]:
|
||||
max_val = max(data)
|
||||
return [x / max_val for x in data]
|
||||
|
||||
def calculate_total(metrics: list[BaseMetric]) -> float:
|
||||
values = [m.value for m in metrics]
|
||||
normalized = normalize_values(values)
|
||||
return sum(normalized)
|
||||
```
|
||||
|
||||
```python
|
||||
# analytics/base.py
|
||||
class BaseMetric:
|
||||
def __init__(self, name: str, value: float):
|
||||
self.name = name
|
||||
self.value = value
|
||||
```
|
||||
|
||||
## Task
|
||||
|
||||
Write a Python function `prepare_optimization_payload` that constructs the code payload for an AI service optimization request for `calculate_total`. It should properly format the source code and dependency code, and include a function to parse the AI service response back into structured code objects.
|
||||
|
||||
## Expected Outputs
|
||||
|
||||
- A Python file `payload_builder.py` with the payload construction and response parsing logic
|
||||
|
|
@ -1 +0,0 @@
|
|||
Candidate source types and DAG relationships
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
{
|
||||
"context": "Tests whether the agent knows the different OptimizedCandidateSource types and how candidates form a DAG via parent_id references in the codeflash pipeline.",
|
||||
"type": "weighted_checklist",
|
||||
"checklist": [
|
||||
{
|
||||
"name": "Lists source types",
|
||||
"description": "Identifies at least 4 of the 6 OptimizedCandidateSource variants: OPTIMIZE, OPTIMIZE_LP, REFINE, REPAIR, ADAPTIVE, JIT_REWRITE",
|
||||
"max_score": 25
|
||||
},
|
||||
{
|
||||
"name": "Parent ID linkage",
|
||||
"description": "Explains that REFINE and REPAIR candidates reference their parent via parent_id, creating a DAG/forest structure, NOT independent candidates",
|
||||
"max_score": 25
|
||||
},
|
||||
{
|
||||
"name": "Refinement uses runtime data",
|
||||
"description": "States that refinement sends runtime data and line profiler results to the AI service (AIServiceRefinerRequest), NOT just the source code",
|
||||
"max_score": 25
|
||||
},
|
||||
{
|
||||
"name": "Repair uses test diffs",
|
||||
"description": "States that repair sends test failure diffs (TestDiff with scope: RETURN_VALUE/STDOUT/DID_PASS) to the AI service, NOT just error messages",
|
||||
"max_score": 25
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
# Document the Candidate Lifecycle
|
||||
|
||||
## Context
|
||||
|
||||
A new engineer is joining the codeflash team and needs to understand how optimization candidates are generated, improved, and related to each other throughout the pipeline. They've asked for a clear explanation of the different ways candidates are produced and how the system iterates on them.
|
||||
|
||||
## Task
|
||||
|
||||
Write a technical document explaining the full lifecycle of an optimization candidate in codeflash — from initial generation through improvement iterations. Cover all the different ways candidates can be created, what data is sent to the AI service for each type, and how candidates relate to each other structurally.
|
||||
|
||||
## Expected Outputs
|
||||
|
||||
- A markdown file `candidate-lifecycle.md`
|
||||
|
|
@ -1 +0,0 @@
|
|||
Deterministic patch values and test execution architecture
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
{
|
||||
"context": "Tests whether the agent knows the specific deterministic patch values used in codeflash's pytest plugin and the subprocess-based test execution architecture.",
|
||||
"type": "weighted_checklist",
|
||||
"checklist": [
|
||||
{
|
||||
"name": "Subprocess isolation",
|
||||
"description": "States that tests run in a subprocess to isolate the test environment from the main codeflash process, NOT in the same process",
|
||||
"max_score": 20
|
||||
},
|
||||
{
|
||||
"name": "Fixed time value",
|
||||
"description": "References the specific fixed timestamp 1761717605.108106 for time.time() or the fixed datetime 2021-01-01 02:05:10 UTC for datetime.now()",
|
||||
"max_score": 20
|
||||
},
|
||||
{
|
||||
"name": "Fixed UUID value",
|
||||
"description": "References the specific fixed UUID 12345678-1234-5678-9abc-123456789012 for uuid4/uuid1",
|
||||
"max_score": 20
|
||||
},
|
||||
{
|
||||
"name": "Random seed",
|
||||
"description": "States that random is seeded with 42 (NOT a different seed value)",
|
||||
"max_score": 20
|
||||
},
|
||||
{
|
||||
"name": "Plugin blocklists",
|
||||
"description": "Mentions that behavioral tests block specific pytest plugins (at least 2 of: benchmark, codspeed, xdist, sugar) to ensure deterministic execution",
|
||||
"max_score": 20
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
# Explain Test Reproducibility Guarantees
|
||||
|
||||
## Context
|
||||
|
||||
A codeflash user notices that their optimization candidate passes behavioral tests on one run but fails on the next. They suspect non-determinism in the test execution. They want to understand what guarantees codeflash provides for test reproducibility and how the system ensures consistent results.
|
||||
|
||||
## Task
|
||||
|
||||
Write a technical explanation of how codeflash ensures deterministic test execution. Cover the execution environment setup, what sources of non-determinism are controlled, and any specific values or configurations used. Also explain the test execution architecture.
|
||||
|
||||
## Expected Outputs
|
||||
|
||||
- A markdown file `test-reproducibility.md`
|
||||
|
|
@ -1 +0,0 @@
|
|||
Effort level configuration and candidate selection criteria
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
{
|
||||
"context": "Tests whether the agent knows the specific effort level values for candidate generation and the criteria used to select the best optimization candidate.",
|
||||
"type": "weighted_checklist",
|
||||
"checklist": [
|
||||
{
|
||||
"name": "Candidate counts by effort",
|
||||
"description": "States correct N_OPTIMIZER_CANDIDATES values: LOW=3, MEDIUM=5, HIGH=6 (at least 2 of 3 correct)",
|
||||
"max_score": 25
|
||||
},
|
||||
{
|
||||
"name": "Speedup as primary selector",
|
||||
"description": "States that the winning candidate is selected primarily by highest speedup ratio",
|
||||
"max_score": 25
|
||||
},
|
||||
{
|
||||
"name": "Diff length as tiebreaker",
|
||||
"description": "States that for tied speedups, shortest diff length from original is used as tiebreaker",
|
||||
"max_score": 25
|
||||
},
|
||||
{
|
||||
"name": "Refinement ranking weights",
|
||||
"description": "States that refinement candidates use weighted ranking with runtime weighted more heavily than diff (2:1 ratio or REFINED_CANDIDATE_RANKING_WEIGHTS=(2,1))",
|
||||
"max_score": 25
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
# Design a Candidate Selection Dashboard
|
||||
|
||||
## Context
|
||||
|
||||
The codeflash team wants to build a dashboard that shows users how optimization candidates were evaluated and why a particular candidate won. The dashboard needs to display the selection process at each stage, from initial candidate pool through to the final winner.
|
||||
|
||||
## Task
|
||||
|
||||
Write a specification document for the dashboard that explains:
|
||||
1. How many candidates are generated at each effort level
|
||||
2. The exact criteria and order of operations used to pick the winning candidate
|
||||
3. How refinement candidates are ranked differently from initial candidates
|
||||
|
||||
Include concrete examples showing how two hypothetical candidates would be compared.
|
||||
|
||||
## Expected Outputs
|
||||
|
||||
- A markdown file `selection-dashboard-spec.md`
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue