Merge branch 'main' into part-1-windows-fixes
31
.github/workflows/deploy-docs-to-azure.yaml
vendored
|
|
@ -1,31 +0,0 @@
|
|||
name: Codeflash Docs Publish to Azure Static Web Apps
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- 'docs/**'
|
||||
- '.github/workflows/deploy-docs-to-azure.yaml'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build_and_deploy_job:
|
||||
if: github.event_name == 'push' || (github.event_name == 'pull_request' && github.event.action != 'closed')
|
||||
runs-on: ubuntu-latest
|
||||
name: Build and Deploy Job
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: true
|
||||
- name: Build And Deploy
|
||||
id: builddeploy
|
||||
uses: Azure/static-web-apps-deploy@v1
|
||||
with:
|
||||
azure_static_web_apps_api_token: ${{ secrets.AZURE_STATIC_WEB_APPS_API_TOKEN }}
|
||||
repo_token: ${{ secrets.GITHUB_TOKEN }} # Used for GitHub integrations (i.e. PR comments)
|
||||
action: "upload"
|
||||
###### Repository/Build Configurations ######
|
||||
app_location: "docs" # App source code path relative to repository root
|
||||
output_location: "build" # Built app content directory, relative to app_location - optional
|
||||
###### End of Repository/Build Configurations ######
|
||||
1
.gitignore
vendored
|
|
@ -254,3 +254,4 @@ fabric.properties
|
|||
|
||||
# Mac
|
||||
.DS_Store
|
||||
WARP.MD
|
||||
|
|
@ -1,7 +1,8 @@
|
|||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: "v0.11.0"
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix, --config=pyproject.toml]
|
||||
- id: ruff-format
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.12.7
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff-check
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
318
AGENTS.md
Normal file
|
|
@ -0,0 +1,318 @@
|
|||
# CodeFlash AI Agent Instructions
|
||||
|
||||
This file provides comprehensive guidance to any coding agent (Warp, GitHub Copilot, Claude, Gemini, etc.) when working with the CodeFlash repository.
|
||||
|
||||
## Project Overview
|
||||
|
||||
CodeFlash is an AI-powered Python code optimizer that automatically improves code performance while maintaining correctness. It uses LLMs to analyze code, generate optimization ideas, validate correctness through comprehensive testing, benchmark performance improvements, and create merge-ready pull requests.
|
||||
|
||||
**Key Capabilities:**
|
||||
- Optimize entire codebases with `codeflash --all`
|
||||
- Optimize specific files or functions with targeted commands
|
||||
- End-to-end workflow optimization with `codeflash optimize script.py`
|
||||
- Automated GitHub Actions integration for CI/CD pipelines
|
||||
- Comprehensive benchmarking and performance analysis
|
||||
- Git worktree isolation for safe optimization
|
||||
|
||||
## Core Architecture
|
||||
|
||||
### Data Flow Pipeline
|
||||
Discovery → Context → Optimization → Verification → Benchmarking → PR
|
||||
|
||||
1. **Discovery** (`codeflash/discovery/`) - Find optimizable functions via static analysis or execution tracing
|
||||
2. **Context Extraction** (`codeflash/context/`) - Extract dependencies, imports, and related code
|
||||
3. **Optimization** (`codeflash/optimization/`) - Generate optimized code via AI service calls
|
||||
4. **Verification** (`codeflash/verification/`) - Run deterministic tests with custom pytest plugin
|
||||
5. **Benchmarking** (`codeflash/benchmarking/`) - Performance measurement and comparison
|
||||
6. **GitHub Integration** (`codeflash/github/`) - Automated PR creation with detailed analysis
|
||||
|
||||
### Key Components
|
||||
|
||||
**Main Entry Points:**
|
||||
- `codeflash/main.py` - CLI entry point and main orchestration
|
||||
- `codeflash/cli_cmds/cli.py` - Command-line argument parsing and validation
|
||||
|
||||
**Core Optimization Pipeline:**
|
||||
- `codeflash/optimization/optimizer.py` - Main optimization orchestrator
|
||||
- `codeflash/optimization/function_optimizer.py` - Individual function optimization
|
||||
- `codeflash/tracing/` - Function call tracing and profiling
|
||||
|
||||
**Code Analysis & Manipulation:**
|
||||
- `codeflash/code_utils/` - Code parsing, AST manipulation, static analysis
|
||||
- `codeflash/context/` - Code context extraction and analysis
|
||||
- `codeflash/verification/` - Code correctness verification through testing
|
||||
|
||||
**External Integrations:**
|
||||
- `codeflash/api/aiservice.py` - LLM communication with rate limiting and retries
|
||||
- `codeflash/github/` - GitHub integration for PR creation
|
||||
- `codeflash/benchmarking/` - Performance benchmarking and measurement
|
||||
|
||||
**Supporting Systems:**
|
||||
- `codeflash/models/models.py` - Pydantic models and type definitions
|
||||
- `codeflash/telemetry/` - Usage analytics (PostHog) and error reporting (Sentry)
|
||||
- `codeflash/ui/` - User interface components (Rich console output)
|
||||
- `codeflash/lsp/` - Language Server Protocol support for IDE integration
|
||||
|
||||
### Key Optimization Workflows
|
||||
|
||||
**1. Full Codebase Optimization (`--all`)**
|
||||
- Discovers all optimizable functions in the project
|
||||
- Runs benchmarks if configured
|
||||
- Optimizes functions in parallel
|
||||
- Creates PRs for successful optimizations
|
||||
|
||||
**2. Targeted Optimization (`--file`, `--function`)**
|
||||
- Focuses on specific files or functions
|
||||
- Performs detailed analysis and context extraction
|
||||
- Applies targeted optimizations
|
||||
|
||||
**3. Workflow Tracing (`optimize`)**
|
||||
- Traces Python script execution
|
||||
- Identifies performance bottlenecks
|
||||
- Generates optimizations for traced functions
|
||||
- Uses checkpoint system to resume interrupted runs
|
||||
|
||||
## Critical Development Patterns
|
||||
|
||||
### Package Management with uv (NOT pip)
|
||||
```bash
|
||||
# Always use uv, never pip
|
||||
uv sync # Install dependencies
|
||||
uv sync --group dev # Install dev dependencies
|
||||
uv run pytest # Run commands
|
||||
uv add package # Add new packages
|
||||
uv build # Build package
|
||||
```
|
||||
|
||||
### Code Manipulation with LibCST (NOT ast)
|
||||
Always use `libcst` for code parsing/modification to preserve formatting:
|
||||
```python
|
||||
from libcst import parse_module, PartialPythonCodeGen
|
||||
# Never use ast module for code transformations
|
||||
```
|
||||
|
||||
### Testing with Deterministic Execution
|
||||
Custom pytest plugin (`codeflash/verification/pytest_plugin.py`) ensures reproducible tests:
|
||||
- Patches time, random, uuid for deterministic behavior
|
||||
- Environment variables: `CODEFLASH_TEST_MODULE`, `CODEFLASH_TEST_CLASS`, `CODEFLASH_TEST_FUNCTION`
|
||||
- Always use `uv run pytest`, never `python -m pytest`
|
||||
|
||||
### Git Worktree Isolation
|
||||
Optimizations run in isolated git worktrees to avoid affecting main repo:
|
||||
```python
|
||||
from codeflash.code_utils.git_utils import create_detached_worktree, remove_worktree
|
||||
# Pattern: create_detached_worktree() → optimize → create_diff_patch_from_worktree()
|
||||
```
|
||||
|
||||
### Error Handling with Either Pattern
|
||||
Use functional error handling instead of exceptions:
|
||||
```python
|
||||
from codeflash.either import is_successful, Either
|
||||
result = aiservice_client.call_llm(...)
|
||||
if is_successful(result):
|
||||
optimized_code = result.value
|
||||
else:
|
||||
error = result.error
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
All configuration in `pyproject.toml` under `[tool.codeflash]`:
|
||||
```toml
|
||||
[tool.codeflash]
|
||||
module-root = "codeflash" # Source code location
|
||||
tests-root = "tests" # Test directory
|
||||
benchmarks-root = "tests/benchmarks" # Benchmark tests
|
||||
test-framework = "pytest" # Always pytest
|
||||
formatter-cmds = [ # Auto-formatting commands
|
||||
"uvx ruff check --exit-zero --fix $file",
|
||||
"uvx ruff format $file",
|
||||
]
|
||||
```
|
||||
|
||||
## Development Commands
|
||||
|
||||
### Environment Setup
|
||||
```bash
|
||||
# Install dependencies (always use uv)
|
||||
uv sync
|
||||
|
||||
# Install development dependencies
|
||||
uv sync --group dev
|
||||
|
||||
# Install pre-commit hooks
|
||||
uv run pre-commit install
|
||||
```
|
||||
|
||||
### Code Quality & Linting
|
||||
```bash
|
||||
# Run linting and formatting with ruff (primary tool)
|
||||
uv run ruff check --fix .
|
||||
uv run ruff format .
|
||||
|
||||
# Type checking with mypy (strict mode)
|
||||
uv run mypy .
|
||||
|
||||
# Clean Python cache files
|
||||
uvx pyclean .
|
||||
```
|
||||
|
||||
### Testing
|
||||
```bash
|
||||
# Run all tests
|
||||
uv run pytest
|
||||
|
||||
# Run tests with coverage
|
||||
uv run coverage run -m pytest tests/
|
||||
|
||||
# Run specific test file
|
||||
uv run pytest tests/test_code_utils.py
|
||||
|
||||
# Run tests with verbose output
|
||||
uv run pytest -v
|
||||
|
||||
# Run benchmarks
|
||||
uv run pytest tests/benchmarks/
|
||||
|
||||
# Run end-to-end tests
|
||||
uv run pytest tests/scripts/
|
||||
|
||||
# Run with specific markers
|
||||
uv run pytest -m "not ci_skip"
|
||||
```
|
||||
|
||||
### Running CodeFlash
|
||||
```bash
|
||||
# Initialize CodeFlash in a project
|
||||
uv run codeflash init
|
||||
|
||||
# Optimize entire codebase
|
||||
uv run codeflash --all
|
||||
|
||||
# Optimize specific file
|
||||
uv run codeflash --file path/to/file.py
|
||||
|
||||
# Optimize specific function
|
||||
uv run codeflash --file path/to/file.py --function function_name
|
||||
|
||||
# Trace and optimize a workflow
|
||||
uv run codeflash optimize script.py
|
||||
|
||||
# Verify setup with test optimization
|
||||
uv run codeflash --verify-setup
|
||||
|
||||
# Run with verbose logging
|
||||
uv run codeflash --verbose --all
|
||||
|
||||
# Run with benchmarking enabled
|
||||
uv run codeflash --benchmark --file target_file.py
|
||||
|
||||
# Use replay tests for debugging
|
||||
uv run codeflash --replay-test tests/specific_test.py
|
||||
```
|
||||
|
||||
## Development Guidelines
|
||||
|
||||
### Code Style
|
||||
- Uses Ruff for linting and formatting (configured in pyproject.toml)
|
||||
- Strict mypy type checking enabled
|
||||
- Pre-commit hooks enforce code quality
|
||||
- Line length: 120 characters
|
||||
- Python 3.10+ syntax
|
||||
|
||||
### Testing Strategy
|
||||
- Primary test framework: pytest
|
||||
- Tests located in `tests/` directory
|
||||
- End-to-end tests in `tests/scripts/`
|
||||
- Benchmarks in `tests/benchmarks/`
|
||||
- Extensive use of `@pytest.mark.parametrize`
|
||||
- Shared fixtures in conftest.py
|
||||
- Test isolation via custom pytest plugin
|
||||
|
||||
### Key Dependencies
|
||||
- **Core**: `libcst`, `jedi`, `gitpython`, `pydantic`
|
||||
- **Testing**: `pytest`, `coverage`, `crosshair-tool`
|
||||
- **Performance**: `line_profiler`, `timeout-decorator`
|
||||
- **UI**: `rich`, `inquirer`, `click`
|
||||
- **AI**: Custom API client for LLM interactions
|
||||
|
||||
### Data Models & Types
|
||||
- `codeflash/models/models.py` - Pydantic models for all data structures
|
||||
- Extensive use of `@dataclass(frozen=True)` for immutable data
|
||||
- Core types: `FunctionToOptimize`, `ValidCode`, `BenchmarkKey`
|
||||
|
||||
## AI Service Integration
|
||||
|
||||
### Rate Limiting & Retries
|
||||
- Built-in rate limiting and exponential backoff
|
||||
- Handle `Either` return types for error handling
|
||||
- AI service endpoint: `codeflash/api/aiservice.py`
|
||||
|
||||
### Telemetry & Monitoring
|
||||
- **Sentry**: Error tracking with `codeflash.telemetry.sentry`
|
||||
- **PostHog**: Usage analytics with `codeflash.telemetry.posthog_cf`
|
||||
- **Environment Variables**: `CODEFLASH_EXPERIMENT_ID` for testing modes
|
||||
|
||||
## Performance & Benchmarking
|
||||
|
||||
### Line Profiler Integration
|
||||
- Uses `line_profiler` for detailed performance analysis
|
||||
- Instruments functions with `@profile` decorator
|
||||
- Generates before/after profiling reports
|
||||
- Calculates precise speedup measurements
|
||||
|
||||
### Benchmark Test Framework
|
||||
- Custom benchmarking in `tests/benchmarks/`
|
||||
- Generates replay tests from execution traces
|
||||
- Validates performance improvements statistically
|
||||
|
||||
## Debugging & Development
|
||||
|
||||
### Verbose Logging
|
||||
```bash
|
||||
uv run codeflash --verbose --file target_file.py
|
||||
```
|
||||
|
||||
### Important Environment Variables
|
||||
- `CODEFLASH_TEST_MODULE` - Current test module during verification
|
||||
- `CODEFLASH_TEST_CLASS` - Current test class during verification
|
||||
- `CODEFLASH_TEST_FUNCTION` - Current test function during verification
|
||||
- `CODEFLASH_LOOP_INDEX` - Current iteration in pytest loops
|
||||
- `CODEFLASH_EXPERIMENT_ID` - Enables local AI service for testing
|
||||
|
||||
### LSP Integration
|
||||
Language Server Protocol support in `codeflash/lsp/` enables IDE integration during optimization.
|
||||
|
||||
### Common Debugging Patterns
|
||||
1. Use verbose logging to trace optimization flow
|
||||
2. Check git worktree operations for isolation issues
|
||||
3. Verify deterministic test execution with environment variables
|
||||
4. Use replay tests to debug specific optimization scenarios
|
||||
5. Monitor AI service calls with rate limiting logs
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Path Handling
|
||||
- Always use absolute paths
|
||||
- Handle encoding explicitly (UTF-8)
|
||||
- Extensive path validation and cleanup utilities in `codeflash/code_utils/`
|
||||
|
||||
### Git Operations
|
||||
- All optimizations run in isolated worktrees
|
||||
- Never modify the main repository directly
|
||||
- Use git utilities in `codeflash/code_utils/git_utils.py`
|
||||
|
||||
### Code Transformations
|
||||
- Always use libcst, never ast module
|
||||
- Preserve code formatting and comments
|
||||
- Validate transformations with deterministic tests
|
||||
|
||||
### Error Handling
|
||||
- Use Either pattern for functional error handling
|
||||
- Log errors to Sentry for monitoring
|
||||
- Provide clear user feedback via Rich console
|
||||
|
||||
### Performance Optimization
|
||||
- Profile before and after changes
|
||||
- Use benchmarks to validate improvements
|
||||
- Generate detailed performance reports
|
||||
15
README.md
|
|
@ -3,9 +3,7 @@
|
|||
<a href="https://github.com/codeflash-ai/codeflash">
|
||||
<img src="https://img.shields.io/github/commit-activity/m/codeflash-ai/codeflash" alt="GitHub commit activity">
|
||||
</a>
|
||||
<a href="https://pypi.org/project/codeflash/">
|
||||
<img src="https://img.shields.io/pypi/dm/codeflash" alt="PyPI Downloads">
|
||||
</a>
|
||||
<a href="https://pypi.org/project/codeflash/"><img src="https://static.pepy.tech/badge/codeflash" alt="PyPI Downloads"></a>
|
||||
<a href="https://pypi.org/project/codeflash/">
|
||||
<img src="https://img.shields.io/pypi/v/codeflash?label=PyPI%20version" alt="PyPI Downloads">
|
||||
</a>
|
||||
|
|
@ -19,7 +17,7 @@ How to use Codeflash -
|
|||
- Automate optimizing all __future__ code you will write by installing Codeflash as a GitHub action.
|
||||
- Optimize a Python workflow `python myscript.py` end-to-end by running `codeflash optimize myscript.py`
|
||||
|
||||
Codeflash is used by top engineering teams at [Pydantic](https://github.com/pydantic/pydantic/pulls?q=is%3Apr+author%3Amisrasaurabh1+is%3Amerged), [Langflow](https://github.com/langflow-ai/langflow/issues?q=state%3Aclosed%20is%3Apr%20author%3Amisrasaurabh1), [Roboflow](https://github.com/roboflow/inference/pulls?q=is%3Apr+is%3Amerged+codeflash+sort%3Acreated-asc), [Albumentations](https://github.com/albumentations-team/albumentations/issues?q=state%3Amerged%20is%3Apr%20author%3Akrrt7%20OR%20state%3Amerged%20is%3Apr%20author%3Aaseembits93%20) and many others to ship performant, expert level code.
|
||||
Codeflash is used by top engineering teams at **Pydantic** [(PRs Merged)](https://github.com/pydantic/pydantic/pulls?q=is%3Apr+author%3Amisrasaurabh1+is%3Amerged), **Roboflow** [(PRs Merged 1](https://github.com/roboflow/inference/issues?q=state%3Aclosed%20is%3Apr%20author%3Amisrasaurabh1%20is%3Amerged), [PRs Merged 2)](https://github.com/roboflow/inference/issues?q=state%3Amerged%20is%3Apr%20author%3Acodeflash-ai%5Bbot%5D), **Unstructured** [(PRs Merged 1](https://github.com/Unstructured-IO/unstructured/pulls?q=is%3Apr+Explanation+and+details+in%3Abody+is%3Amerged), [PRs Merged 2)](https://github.com/Unstructured-IO/unstructured-ingest/pulls?q=is%3Apr+Explanation+and+details+in%3Abody+is%3Amerged), **Langflow** [(PRs Merged)](https://github.com/langflow-ai/langflow/issues?q=state%3Aclosed%20is%3Apr%20author%3Amisrasaurabh1) and many others to ship performant, expert level code.
|
||||
|
||||
Codeflash is great at optimizing AI Agents, Computer Vision algorithms, PyTorch code, numerical code, backend code or anything else you might write with Python.
|
||||
|
||||
|
|
@ -65,6 +63,13 @@ For detailed installation and usage instructions, visit our documentation at [do
|
|||
|
||||
https://github.com/user-attachments/assets/38f44f4e-be1c-4f84-8db9-63d5ee3e61e5
|
||||
|
||||
- Optiming a workflow end to end automatically with `codeflash optimize`
|
||||
|
||||
|
||||
https://github.com/user-attachments/assets/355ba295-eb5a-453a-8968-7fb35c70d16c
|
||||
|
||||
|
||||
|
||||
## Support
|
||||
|
||||
Join our community for support and discussions. If you have any questions, feel free to reach out to us using one of the following methods:
|
||||
|
|
@ -76,4 +81,4 @@ Join our community for support and discussions. If you have any questions, feel
|
|||
|
||||
## License
|
||||
|
||||
Codeflash is licensed under the BSL-1.1 License. See the LICENSE file for details.
|
||||
Codeflash is licensed under the BSL-1.1 License. See the [LICENSE](https://github.com/codeflash-ai/codeflash/blob/main/codeflash/LICENSE) file for details.
|
||||
|
|
|
|||
19
SECURITY.md
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
# Security Policy
|
||||
|
||||
This document outlines Codeflash's vulnerability disclosure policy. For more information about Codeflash's approach to security, please visit [codeflash.ai/security](https://www.codeflash.ai/security).
|
||||
|
||||
## Supported Versions
|
||||
|
||||
Since Codeflash is moving quickly, we can only commit to fixing security issues for the latest version of codeflash client.
|
||||
If a vulnerability is discovered in our backend, we will release the fix for all the users.
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
|
||||
Please do not report security vulnerabilities through public GitHub issues.
|
||||
|
||||
Instead, please report them to our [GitHub Security page](https://github.com/codeflash-ai/codeflash/security). If you prefer to submit one without using GitHub, you can also email us at security@codeflash.ai.
|
||||
|
||||
We commit to acknowledging vulnerability reports immediately, and will work to fix active vulnerabilities as soon as we can. We will publish resolved vulnerabilities in the form of security advisories on our GitHub security page. Critical incidents will be communicated both on the GitHub security page and via email to all affected users.
|
||||
|
||||
We appreciate your help in making Codeflash more secure for everyone. Thank you for your support and responsible disclosure.
|
||||
|
|
@ -1,8 +1,2 @@
|
|||
DEFAULT_API_URL = "https://api.galileo.ai/"
|
||||
DEFAULT_APP_URL = "https://app.galileo.ai/"
|
||||
|
||||
|
||||
# function_names: GalileoApiClient.get_console_url
|
||||
# module_abs_path : /home/mohammed/Work/galileo-python/src/galileo/api_client.py
|
||||
# preexisting_objects: {('GalileoApiClient', ()), ('_set_destination', ()), ('get_console_url', (FunctionParent(name='GalileoApiClient', type='ClassDef'),))}
|
||||
# project_root_path: /home/mohammed/Work/galileo-python/src
|
||||
|
|
|
|||
|
|
@ -8,21 +8,54 @@ from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
|
|||
|
||||
PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
|
||||
|
||||
benchmark_options = [
|
||||
("--benchmark-columns", "store", None, "Benchmark columns"),
|
||||
("--benchmark-group-by", "store", None, "Benchmark group by"),
|
||||
("--benchmark-name", "store", None, "Benchmark name pattern"),
|
||||
("--benchmark-sort", "store", None, "Benchmark sort column"),
|
||||
("--benchmark-json", "store", None, "Benchmark JSON output file"),
|
||||
("--benchmark-save", "store", None, "Benchmark save name"),
|
||||
("--benchmark-warmup", "store", None, "Benchmark warmup"),
|
||||
("--benchmark-warmup-iterations", "store", None, "Benchmark warmup iterations"),
|
||||
("--benchmark-min-time", "store", None, "Benchmark minimum time"),
|
||||
("--benchmark-max-time", "store", None, "Benchmark maximum time"),
|
||||
("--benchmark-min-rounds", "store", None, "Benchmark minimum rounds"),
|
||||
("--benchmark-timer", "store", None, "Benchmark timer"),
|
||||
("--benchmark-calibration-precision", "store", None, "Benchmark calibration precision"),
|
||||
("--benchmark-disable", "store_true", False, "Disable benchmarks"),
|
||||
("--benchmark-skip", "store_true", False, "Skip benchmarks"),
|
||||
("--benchmark-only", "store_true", False, "Only run benchmarks"),
|
||||
("--benchmark-verbose", "store_true", False, "Verbose benchmark output"),
|
||||
("--benchmark-histogram", "store", None, "Benchmark histogram"),
|
||||
("--benchmark-compare", "store", None, "Benchmark compare"),
|
||||
("--benchmark-compare-fail", "store", None, "Benchmark compare fail threshold"),
|
||||
]
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
|
||||
"""Register the benchmark marker and disable conflicting plugins."""
|
||||
config.addinivalue_line("markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing")
|
||||
|
||||
if config.getoption("--codeflash-trace") and PYTEST_BENCHMARK_INSTALLED:
|
||||
config.option.benchmark_disable = True
|
||||
config.pluginmanager.set_blocked("pytest_benchmark")
|
||||
config.pluginmanager.set_blocked("pytest-benchmark")
|
||||
if config.getoption("--codeflash-trace"):
|
||||
# When --codeflash-trace is used, ignore all benchmark options by resetting them to defaults
|
||||
for option, _, default, _ in benchmark_options:
|
||||
option_name = option.replace("--", "").replace("-", "_")
|
||||
if hasattr(config.option, option_name):
|
||||
setattr(config.option, option_name, default)
|
||||
|
||||
if PYTEST_BENCHMARK_INSTALLED:
|
||||
config.pluginmanager.set_blocked("pytest_benchmark")
|
||||
config.pluginmanager.set_blocked("pytest-benchmark")
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
parser.addoption(
|
||||
"--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks"
|
||||
)
|
||||
# These options are ignored when --codeflash-trace is used
|
||||
for option, action, default, help_text in benchmark_options:
|
||||
help_suffix = " (ignored when --codeflash-trace is used)"
|
||||
parser.addoption(option, action=action, default=default, help=help_text + help_suffix)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -37,7 +70,7 @@ def benchmark(request: pytest.FixtureRequest) -> object:
|
|||
# If pytest-benchmark is installed and --codeflash-trace is not enabled,
|
||||
# return the normal pytest-benchmark fixture
|
||||
if PYTEST_BENCHMARK_INSTALLED:
|
||||
from pytest_benchmark.fixture import BenchmarkFixture as BSF # noqa: N814
|
||||
from pytest_benchmark.fixture import BenchmarkFixture as BSF # pyright: ignore[reportMissingImports] # noqa: I001, N814
|
||||
|
||||
bs = getattr(config, "_benchmarksession", None)
|
||||
if bs and bs.skip:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "codeflash-benchmark"
|
||||
version = "0.1.0"
|
||||
version = "0.2.0"
|
||||
description = "Pytest benchmarking plugin for codeflash.ai - automatic code performance optimization"
|
||||
authors = [{ name = "CodeFlash Inc.", email = "contact@codeflash.ai" }]
|
||||
requires-python = ">=3.9"
|
||||
|
|
@ -25,8 +25,8 @@ Repository = "https://github.com/codeflash-ai/codeflash-benchmark"
|
|||
codeflash-benchmark = "codeflash_benchmark.plugin"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=45", "wheel", "setuptools_scm"]
|
||||
requires = ["setuptools>=45", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["codeflash_benchmark"]
|
||||
packages = ["codeflash_benchmark"]
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ Business Source License 1.1
|
|||
Parameters
|
||||
|
||||
Licensor: CodeFlash Inc.
|
||||
Licensed Work: Codeflash Client version 0.15.x
|
||||
Licensed Work: Codeflash Client version 0.17.x
|
||||
The Licensed Work is (c) 2024 CodeFlash Inc.
|
||||
|
||||
Additional Use Grant: None. Production use of the Licensed Work is only permitted
|
||||
|
|
@ -13,7 +13,7 @@ Additional Use Grant: None. Production use of the Licensed Work is only permitte
|
|||
Platform. Please visit codeflash.ai for further
|
||||
information.
|
||||
|
||||
Change Date: 2029-07-03
|
||||
Change Date: 2029-09-23
|
||||
|
||||
Change License: MIT
|
||||
|
||||
|
|
|
|||
|
|
@ -10,10 +10,12 @@ import requests
|
|||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.code_utils.env_utils import get_codeflash_api_key, is_LSP_enabled
|
||||
from codeflash.code_utils.config_consts import N_CANDIDATES_EFFECTIVE, N_CANDIDATES_LP_EFFECTIVE
|
||||
from codeflash.code_utils.env_utils import get_codeflash_api_key
|
||||
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.models.ExperimentMetadata import ExperimentMetadata
|
||||
from codeflash.models.models import AIServiceRefinerRequest, OptimizedCandidate
|
||||
from codeflash.models.models import AIServiceRefinerRequest, CodeStringsMarkdown, OptimizedCandidate
|
||||
from codeflash.telemetry.posthog_cf import ph
|
||||
from codeflash.version import __version__ as codeflash_version
|
||||
|
||||
|
|
@ -80,6 +82,19 @@ class AiServiceClient:
|
|||
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
|
||||
return response
|
||||
|
||||
def _get_valid_candidates(self, optimizations_json: list[dict[str, Any]]) -> list[OptimizedCandidate]:
|
||||
candidates: list[OptimizedCandidate] = []
|
||||
for opt in optimizations_json:
|
||||
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"])
|
||||
if not code.code_strings:
|
||||
continue
|
||||
candidates.append(
|
||||
OptimizedCandidate(
|
||||
source_code=code, explanation=opt["explanation"], optimization_id=opt["optimization_id"]
|
||||
)
|
||||
)
|
||||
return candidates
|
||||
|
||||
def optimize_python_code( # noqa: D417
|
||||
self,
|
||||
source_code: str,
|
||||
|
|
@ -117,9 +132,10 @@ class AiServiceClient:
|
|||
"current_username": get_last_commit_author_if_pr_exists(None),
|
||||
"repo_owner": git_repo_owner,
|
||||
"repo_name": git_repo_name,
|
||||
"n_candidates": N_CANDIDATES_EFFECTIVE,
|
||||
}
|
||||
|
||||
logger.info("Generating optimized candidates…")
|
||||
logger.info("!lsp|Generating optimized candidates…")
|
||||
console.rule()
|
||||
try:
|
||||
response = self.make_ai_service_request("/optimize", payload=payload, timeout=600)
|
||||
|
|
@ -130,18 +146,11 @@ class AiServiceClient:
|
|||
|
||||
if response.status_code == 200:
|
||||
optimizations_json = response.json()["optimizations"]
|
||||
logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
|
||||
logger.info(f"!lsp|Generated {len(optimizations_json)} candidate optimizations.")
|
||||
console.rule()
|
||||
end_time = time.perf_counter()
|
||||
logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
|
||||
return [
|
||||
OptimizedCandidate(
|
||||
source_code=opt["source_code"],
|
||||
explanation=opt["explanation"],
|
||||
optimization_id=opt["optimization_id"],
|
||||
)
|
||||
for opt in optimizations_json
|
||||
]
|
||||
logger.debug(f"!lsp|Generating possible optimizations took {end_time - start_time:.2f} seconds.")
|
||||
return self._get_valid_candidates(optimizations_json)
|
||||
try:
|
||||
error = response.json()["error"]
|
||||
except Exception:
|
||||
|
|
@ -185,9 +194,9 @@ class AiServiceClient:
|
|||
"experiment_metadata": experiment_metadata,
|
||||
"codeflash_version": codeflash_version,
|
||||
"lsp_mode": is_LSP_enabled(),
|
||||
"n_candidates_lp": N_CANDIDATES_LP_EFFECTIVE,
|
||||
}
|
||||
|
||||
logger.info("Generating optimized candidates…")
|
||||
console.rule()
|
||||
if line_profiler_results == "":
|
||||
logger.info("No LineProfiler results were provided, Skipping optimization.")
|
||||
|
|
@ -202,16 +211,11 @@ class AiServiceClient:
|
|||
|
||||
if response.status_code == 200:
|
||||
optimizations_json = response.json()["optimizations"]
|
||||
logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
|
||||
logger.info(
|
||||
f"!lsp|Generated {len(optimizations_json)} candidate optimizations using line profiler information."
|
||||
)
|
||||
console.rule()
|
||||
return [
|
||||
OptimizedCandidate(
|
||||
source_code=opt["source_code"],
|
||||
explanation=opt["explanation"],
|
||||
optimization_id=opt["optimization_id"],
|
||||
)
|
||||
for opt in optimizations_json
|
||||
]
|
||||
return self._get_valid_candidates(optimizations_json)
|
||||
try:
|
||||
error = response.json()["error"]
|
||||
except Exception:
|
||||
|
|
@ -248,7 +252,7 @@ class AiServiceClient:
|
|||
}
|
||||
for opt in request
|
||||
]
|
||||
logger.info(f"Refining {len(request)} optimizations…")
|
||||
logger.debug(f"Refining {len(request)} optimizations…")
|
||||
console.rule()
|
||||
try:
|
||||
response = self.make_ai_service_request("/refinement", payload=payload, timeout=600)
|
||||
|
|
@ -259,16 +263,19 @@ class AiServiceClient:
|
|||
|
||||
if response.status_code == 200:
|
||||
refined_optimizations = response.json()["refinements"]
|
||||
logger.info(f"Generated {len(refined_optimizations)} candidate refinements.")
|
||||
logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
|
||||
console.rule()
|
||||
|
||||
refinements = self._get_valid_candidates(refined_optimizations)
|
||||
return [
|
||||
OptimizedCandidate(
|
||||
source_code=opt["source_code"],
|
||||
explanation=opt["explanation"],
|
||||
optimization_id=opt["optimization_id"][:-4] + "refi",
|
||||
source_code=c.source_code,
|
||||
explanation=c.explanation,
|
||||
optimization_id=c.optimization_id[:-4] + "refi",
|
||||
)
|
||||
for opt in refined_optimizations
|
||||
for c in refinements
|
||||
]
|
||||
|
||||
try:
|
||||
error = response.json()["error"]
|
||||
except Exception:
|
||||
|
|
@ -278,6 +285,123 @@ class AiServiceClient:
|
|||
console.rule()
|
||||
return []
|
||||
|
||||
def get_new_explanation( # noqa: D417
|
||||
self,
|
||||
source_code: str,
|
||||
optimized_code: str,
|
||||
dependency_code: str,
|
||||
trace_id: str,
|
||||
original_line_profiler_results: str,
|
||||
optimized_line_profiler_results: str,
|
||||
original_code_runtime: str,
|
||||
optimized_code_runtime: str,
|
||||
speedup: str,
|
||||
annotated_tests: str,
|
||||
optimization_id: str,
|
||||
original_explanation: str,
|
||||
) -> str:
|
||||
"""Optimize the given python code for performance by making a request to the Django endpoint.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
- source_code (str): The python code to optimize.
|
||||
- optimized_code (str): The python code generated by the AI service.
|
||||
- dependency_code (str): The dependency code used as read-only context for the optimization
|
||||
- original_line_profiler_results: str - line profiler results for the baseline code
|
||||
- optimized_line_profiler_results: str - line profiler results for the optimized code
|
||||
- original_code_runtime: str - runtime for the baseline code
|
||||
- optimized_code_runtime: str - runtime for the optimized code
|
||||
- speedup: str - speedup of the optimized code
|
||||
- annotated_tests: str - test functions annotated with runtime
|
||||
- optimization_id: str - unique id of opt candidate
|
||||
- original_explanation: str - original_explanation generated for the opt candidate
|
||||
|
||||
Returns
|
||||
-------
|
||||
- List[OptimizationCandidate]: A list of Optimization Candidates.
|
||||
|
||||
"""
|
||||
payload = {
|
||||
"trace_id": trace_id,
|
||||
"source_code": source_code,
|
||||
"optimized_code": optimized_code,
|
||||
"original_line_profiler_results": original_line_profiler_results,
|
||||
"optimized_line_profiler_results": optimized_line_profiler_results,
|
||||
"original_code_runtime": original_code_runtime,
|
||||
"optimized_code_runtime": optimized_code_runtime,
|
||||
"speedup": speedup,
|
||||
"annotated_tests": annotated_tests,
|
||||
"optimization_id": optimization_id,
|
||||
"original_explanation": original_explanation,
|
||||
"dependency_code": dependency_code,
|
||||
}
|
||||
logger.info("loading|Generating explanation")
|
||||
console.rule()
|
||||
try:
|
||||
response = self.make_ai_service_request("/explain", payload=payload, timeout=60)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.exception(f"Error generating explanations: {e}")
|
||||
ph("cli-optimize-error-caught", {"error": str(e)})
|
||||
return ""
|
||||
|
||||
if response.status_code == 200:
|
||||
explanation: str = response.json()["explanation"]
|
||||
console.rule()
|
||||
return explanation
|
||||
try:
|
||||
error = response.json()["error"]
|
||||
except Exception:
|
||||
error = response.text
|
||||
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
|
||||
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
|
||||
console.rule()
|
||||
return ""
|
||||
|
||||
def generate_ranking( # noqa: D417
|
||||
self, trace_id: str, diffs: list[str], optimization_ids: list[str], speedups: list[float]
|
||||
) -> list[int] | None:
|
||||
"""Optimize the given python code for performance by making a request to the Django endpoint.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
- trace_id : unique uuid of function
|
||||
- diffs : list of unified diff strings of opt candidates
|
||||
- speedups : list of speedups of opt candidates
|
||||
|
||||
Returns
|
||||
-------
|
||||
- List[int]: Ranking of opt candidates in decreasing order
|
||||
|
||||
"""
|
||||
payload = {
|
||||
"trace_id": trace_id,
|
||||
"diffs": diffs,
|
||||
"speedups": speedups,
|
||||
"optimization_ids": optimization_ids,
|
||||
"python_version": platform.python_version(),
|
||||
}
|
||||
logger.info("loading|Generating ranking")
|
||||
console.rule()
|
||||
try:
|
||||
response = self.make_ai_service_request("/rank", payload=payload, timeout=60)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.exception(f"Error generating ranking: {e}")
|
||||
ph("cli-optimize-error-caught", {"error": str(e)})
|
||||
return None
|
||||
|
||||
if response.status_code == 200:
|
||||
ranking: list[int] = response.json()["ranking"]
|
||||
console.rule()
|
||||
return ranking
|
||||
try:
|
||||
error = response.json()["error"]
|
||||
except Exception:
|
||||
error = response.text
|
||||
logger.error(f"Error generating ranking: {response.status_code} - {error}")
|
||||
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
|
||||
console.rule()
|
||||
return None
|
||||
|
||||
def log_results( # noqa: D417
|
||||
self,
|
||||
function_trace_id: str,
|
||||
|
|
@ -287,6 +411,7 @@ class AiServiceClient:
|
|||
is_correct: dict[str, bool] | None,
|
||||
optimized_line_profiler_results: dict[str, str] | None,
|
||||
metadata: dict[str, Any] | None,
|
||||
optimizations_post: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Log features to the database.
|
||||
|
||||
|
|
@ -299,6 +424,7 @@ class AiServiceClient:
|
|||
- is_correct (Optional[Dict[str, bool]]): Whether the optimized code is correct.
|
||||
- optimized_line_profiler_results: line_profiler results for every candidate mapped to their optimization_id
|
||||
- metadata: contains the best optimization id
|
||||
- optimizations_post - dict mapping opt id to code str after postprocessing
|
||||
|
||||
"""
|
||||
payload = {
|
||||
|
|
@ -310,6 +436,7 @@ class AiServiceClient:
|
|||
"codeflash_version": codeflash_version,
|
||||
"optimized_line_profiler_results": optimized_line_profiler_results,
|
||||
"metadata": metadata,
|
||||
"optimizations_post": optimizations_post,
|
||||
}
|
||||
try:
|
||||
self.make_ai_service_request("/log_features", payload=payload, timeout=5)
|
||||
|
|
|
|||
|
|
@ -14,16 +14,19 @@ from pydantic.json import pydantic_encoder
|
|||
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
|
||||
from codeflash.code_utils.git_utils import get_repo_owner_and_name
|
||||
from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name
|
||||
from codeflash.github.PrComment import FileDiffContent, PrComment
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.version import __version__
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from requests import Response
|
||||
|
||||
from codeflash.github.PrComment import FileDiffContent, PrComment
|
||||
from codeflash.result.explanation import Explanation
|
||||
|
||||
from packaging import version
|
||||
|
||||
if os.environ.get("CODEFLASH_CFAPI_SERVER", default="prod").lower() == "local":
|
||||
if os.environ.get("CODEFLASH_CFAPI_SERVER", "prod").lower() == "local":
|
||||
CFAPI_BASE_URL = "http://localhost:3001"
|
||||
logger.info(f"Using local CF API at {CFAPI_BASE_URL}.")
|
||||
console.rule()
|
||||
|
|
@ -37,6 +40,7 @@ def make_cfapi_request(
|
|||
payload: dict[str, Any] | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
suppress_errors: bool = False,
|
||||
) -> Response:
|
||||
"""Make an HTTP request using the specified method, URL, headers, and JSON payload.
|
||||
|
|
@ -48,7 +52,7 @@ def make_cfapi_request(
|
|||
:return: The response object from the API.
|
||||
"""
|
||||
url = f"{CFAPI_BASE_URL}/cfapi{endpoint}"
|
||||
cfapi_headers = {"Authorization": f"Bearer {get_codeflash_api_key()}"}
|
||||
cfapi_headers = {"Authorization": f"Bearer {api_key or get_codeflash_api_key()}"}
|
||||
if extra_headers:
|
||||
cfapi_headers.update(extra_headers)
|
||||
try:
|
||||
|
|
@ -80,7 +84,7 @@ def make_cfapi_request(
|
|||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_user_id() -> Optional[str]:
|
||||
def get_user_id(api_key: Optional[str] = None) -> Optional[str]:
|
||||
"""Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint.
|
||||
|
||||
:return: The userid or None if the request fails.
|
||||
|
|
@ -88,7 +92,9 @@ def get_user_id() -> Optional[str]:
|
|||
if not ensure_codeflash_api_key():
|
||||
return None
|
||||
|
||||
response = make_cfapi_request(endpoint="/cli-get-user", method="GET", extra_headers={"cli_version": __version__})
|
||||
response = make_cfapi_request(
|
||||
endpoint="/cli-get-user", method="GET", extra_headers={"cli_version": __version__}, api_key=api_key
|
||||
)
|
||||
if response.status_code == 200:
|
||||
if "min_version" not in response.text:
|
||||
return response.text
|
||||
|
|
@ -99,6 +105,9 @@ def get_user_id() -> Optional[str]:
|
|||
if min_version and version.parse(min_version) > version.parse(__version__):
|
||||
msg = "Your Codeflash CLI version is outdated. Please update to the latest version using `pip install --upgrade codeflash`."
|
||||
console.print(f"[bold red]{msg}[/bold red]")
|
||||
if is_LSP_enabled():
|
||||
logger.debug(msg)
|
||||
return f"Error: {msg}"
|
||||
sys.exit(1)
|
||||
return userid
|
||||
|
||||
|
|
@ -119,6 +128,8 @@ def suggest_changes(
|
|||
generated_tests: str,
|
||||
trace_id: str,
|
||||
coverage_message: str,
|
||||
replay_tests: str = "",
|
||||
concolic_tests: str = "",
|
||||
) -> Response:
|
||||
"""Suggest changes to a pull request.
|
||||
|
||||
|
|
@ -142,6 +153,8 @@ def suggest_changes(
|
|||
"generatedTests": generated_tests,
|
||||
"traceId": trace_id,
|
||||
"coverage_message": coverage_message,
|
||||
"replayTests": replay_tests,
|
||||
"concolicTests": concolic_tests,
|
||||
}
|
||||
return make_cfapi_request(endpoint="/suggest-pr-changes", method="POST", payload=payload)
|
||||
|
||||
|
|
@ -156,6 +169,8 @@ def create_pr(
|
|||
generated_tests: str,
|
||||
trace_id: str,
|
||||
coverage_message: str,
|
||||
replay_tests: str = "",
|
||||
concolic_tests: str = "",
|
||||
) -> Response:
|
||||
"""Create a pull request, targeting the specified branch. (usually 'main').
|
||||
|
||||
|
|
@ -178,10 +193,68 @@ def create_pr(
|
|||
"generatedTests": generated_tests,
|
||||
"traceId": trace_id,
|
||||
"coverage_message": coverage_message,
|
||||
"replayTests": replay_tests,
|
||||
"concolicTests": concolic_tests,
|
||||
}
|
||||
return make_cfapi_request(endpoint="/create-pr", method="POST", payload=payload)
|
||||
|
||||
|
||||
def create_staging(
|
||||
original_code: dict[Path, str],
|
||||
new_code: dict[Path, str],
|
||||
explanation: Explanation,
|
||||
existing_tests_source: str,
|
||||
generated_original_test_source: str,
|
||||
function_trace_id: str,
|
||||
coverage_message: str,
|
||||
replay_tests: str,
|
||||
concolic_tests: str,
|
||||
root_dir: Path,
|
||||
) -> Response:
|
||||
"""Create a staging pull request, targeting the specified branch. (usually 'staging').
|
||||
|
||||
:param original_code: A mapping of file paths to original source code.
|
||||
:param new_code: A mapping of file paths to optimized source code.
|
||||
:param explanation: An Explanation object with optimization details.
|
||||
:param existing_tests_source: Existing test code.
|
||||
:param generated_original_test_source: Generated tests for the original function.
|
||||
:param function_trace_id: Unique identifier for this optimization trace.
|
||||
:param coverage_message: Coverage report or summary.
|
||||
:return: The response object from the backend.
|
||||
"""
|
||||
relative_path = explanation.file_path.relative_to(root_dir).as_posix()
|
||||
|
||||
build_file_changes = {
|
||||
Path(p).relative_to(root_dir).as_posix(): FileDiffContent(oldContent=original_code[p], newContent=new_code[p])
|
||||
for p in original_code
|
||||
}
|
||||
|
||||
payload = {
|
||||
"baseBranch": get_current_branch(),
|
||||
"diffContents": build_file_changes,
|
||||
"prCommentFields": PrComment(
|
||||
optimization_explanation=explanation.explanation_message(),
|
||||
best_runtime=explanation.best_runtime_ns,
|
||||
original_runtime=explanation.original_runtime_ns,
|
||||
function_name=explanation.function_name,
|
||||
relative_file_path=relative_path,
|
||||
speedup_x=explanation.speedup_x,
|
||||
speedup_pct=explanation.speedup_pct,
|
||||
winning_behavior_test_results=explanation.winning_behavior_test_results,
|
||||
winning_benchmarking_test_results=explanation.winning_benchmarking_test_results,
|
||||
benchmark_details=explanation.benchmark_details,
|
||||
).to_json(),
|
||||
"existingTests": existing_tests_source,
|
||||
"generatedTests": generated_original_test_source,
|
||||
"traceId": function_trace_id,
|
||||
"coverage_message": coverage_message,
|
||||
"replayTests": replay_tests,
|
||||
"concolicTests": concolic_tests,
|
||||
}
|
||||
|
||||
return make_cfapi_request(endpoint="/create-staging", method="POST", payload=payload)
|
||||
|
||||
|
||||
def is_github_app_installed_on_repo(owner: str, repo: str, *, suppress_errors: bool = False) -> bool:
|
||||
"""Check if the Codeflash GitHub App is installed on the specified repository.
|
||||
|
||||
|
|
|
|||
|
|
@ -32,12 +32,14 @@ class AddDecoratorTransformer(cst.CSTTransformer):
|
|||
def visit_ClassDef(self, node: ClassDef) -> Optional[bool]:
|
||||
if self.class_name: # Don't go into nested class
|
||||
return False
|
||||
self.class_name = node.name.value # noqa: RET503
|
||||
self.class_name = node.name.value
|
||||
return None
|
||||
|
||||
def visit_FunctionDef(self, node: FunctionDef) -> Optional[bool]:
|
||||
if self.function_name: # Don't go into nested function
|
||||
return False
|
||||
self.function_name = node.name.value # noqa: RET503
|
||||
self.function_name = node.name.value
|
||||
return None
|
||||
|
||||
def leave_FunctionDef(self, original_node: FunctionDef, updated_node: FunctionDef) -> FunctionDef:
|
||||
if self.function_name == original_node.name.value:
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from codeflash.cli_cmds.console import logger
|
|||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.code_utils import exit_with_message
|
||||
from codeflash.code_utils.config_parser import parse_config_file
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.version import __version__ as version
|
||||
|
||||
|
||||
|
|
@ -77,6 +78,7 @@ def parse_args() -> Namespace:
|
|||
parser.add_argument(
|
||||
"--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally."
|
||||
)
|
||||
parser.add_argument("--staging-review", action="store_true", help="Upload optimizations to staging for review")
|
||||
parser.add_argument(
|
||||
"--verify-setup",
|
||||
action="store_true",
|
||||
|
|
@ -93,6 +95,7 @@ def parse_args() -> Namespace:
|
|||
help="Path to the directory of the project, where all the pytest-benchmark tests are located.",
|
||||
)
|
||||
parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
|
||||
parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization")
|
||||
|
||||
args, unknown_args = parser.parse_known_args()
|
||||
sys.argv[:] = [sys.argv[0], *unknown_args]
|
||||
|
|
@ -209,6 +212,9 @@ def process_pyproject_config(args: Namespace) -> Namespace:
|
|||
if args.benchmarks_root:
|
||||
args.benchmarks_root = Path(args.benchmarks_root).resolve()
|
||||
args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path)
|
||||
if is_LSP_enabled():
|
||||
args.all = None
|
||||
return args
|
||||
return handle_optimize_all_arg_parsing(args)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ if TYPE_CHECKING:
|
|||
from argparse import Namespace
|
||||
|
||||
CODEFLASH_LOGO: str = (
|
||||
f"{LF}" # noqa: ISC003
|
||||
f"{LF}"
|
||||
r" _ ___ _ _ " + f"{LF}"
|
||||
r" | | / __)| | | | " + f"{LF}"
|
||||
r" ____ ___ _ | | ____ | |__ | | ____ ___ | | _ " + f"{LF}"
|
||||
|
|
@ -85,12 +85,16 @@ def init_codeflash() -> None:
|
|||
|
||||
did_add_new_key = prompt_api_key()
|
||||
|
||||
if should_modify_pyproject_toml():
|
||||
setup_info: SetupInfo = collect_setup_info()
|
||||
should_modify, config = should_modify_pyproject_toml()
|
||||
|
||||
git_remote = config.get("git_remote", "origin") if config else "origin"
|
||||
|
||||
if should_modify:
|
||||
setup_info: SetupInfo = collect_setup_info()
|
||||
git_remote = setup_info.git_remote
|
||||
configure_pyproject_toml(setup_info)
|
||||
|
||||
install_github_app()
|
||||
install_github_app(git_remote)
|
||||
|
||||
install_github_actions(override_formatter_check=True)
|
||||
|
||||
|
|
@ -151,7 +155,23 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
|
|||
run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path)
|
||||
|
||||
|
||||
def should_modify_pyproject_toml() -> bool:
|
||||
def is_valid_pyproject_toml(pyproject_toml_path: Path) -> dict[str, Any] | None:
|
||||
if not pyproject_toml_path.exists():
|
||||
return None
|
||||
try:
|
||||
config, _ = parse_config_file(pyproject_toml_path)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if "module_root" not in config or config["module_root"] is None or not Path(config["module_root"]).is_dir():
|
||||
return None
|
||||
if "tests_root" not in config or config["tests_root"] is None or not Path(config["tests_root"]).is_dir():
|
||||
return None
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]:
|
||||
"""Check if the current directory contains a valid pyproject.toml file with codeflash config.
|
||||
|
||||
If it does, ask the user if they want to re-configure it.
|
||||
|
|
@ -159,23 +179,16 @@ def should_modify_pyproject_toml() -> bool:
|
|||
from rich.prompt import Confirm
|
||||
|
||||
pyproject_toml_path = Path.cwd() / "pyproject.toml"
|
||||
if not pyproject_toml_path.exists():
|
||||
return True
|
||||
try:
|
||||
config, config_file_path = parse_config_file(pyproject_toml_path)
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
if "module_root" not in config or config["module_root"] is None or not Path(config["module_root"]).is_dir():
|
||||
return True
|
||||
if "tests_root" not in config or config["tests_root"] is None or not Path(config["tests_root"]).is_dir():
|
||||
return True
|
||||
config = is_valid_pyproject_toml(pyproject_toml_path)
|
||||
if config is None:
|
||||
return True, None
|
||||
|
||||
return Confirm.ask(
|
||||
"✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?",
|
||||
default=False,
|
||||
show_default=True,
|
||||
)
|
||||
), config
|
||||
|
||||
|
||||
# Custom theme for better UX
|
||||
|
|
@ -958,16 +971,23 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
|
|||
click.echo()
|
||||
|
||||
|
||||
def install_github_app() -> None:
|
||||
def install_github_app(git_remote: str) -> None:
|
||||
try:
|
||||
git_repo = git.Repo(search_parent_directories=True)
|
||||
except git.InvalidGitRepositoryError:
|
||||
click.echo("Skipping GitHub app installation because you're not in a git repository.")
|
||||
return
|
||||
owner, repo = get_repo_owner_and_name(git_repo)
|
||||
|
||||
if git_remote not in get_git_remotes(git_repo):
|
||||
click.echo(f"Skipping GitHub app installation, remote ({git_remote}) does not exist in this repository.")
|
||||
return
|
||||
|
||||
owner, repo = get_repo_owner_and_name(git_repo, git_remote)
|
||||
|
||||
if is_github_app_installed_on_repo(owner, repo, suppress_errors=True):
|
||||
click.echo("🐙 Looks like you've already installed the Codeflash GitHub app on this repository! Continuing…")
|
||||
click.echo(
|
||||
f"🐙 Looks like you've already installed the Codeflash GitHub app on this repository ({owner}/{repo})! Continuing…"
|
||||
)
|
||||
|
||||
else:
|
||||
click.prompt(
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
import logging
|
||||
from contextlib import contextmanager
|
||||
from itertools import cycle
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
|
|
@ -19,15 +19,24 @@ from rich.progress import (
|
|||
|
||||
from codeflash.cli_cmds.console_constants import SPINNER_TYPES
|
||||
from codeflash.cli_cmds.logging_config import BARE_LOGGING_FORMAT
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.lsp.lsp_logger import enhanced_log
|
||||
from codeflash.lsp.lsp_message import LspCodeMessage, LspTextMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
from rich.progress import TaskID
|
||||
|
||||
from codeflash.lsp.lsp_message import LspMessage
|
||||
|
||||
DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG
|
||||
|
||||
console = Console()
|
||||
|
||||
if is_LSP_enabled():
|
||||
console.quiet = True
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)],
|
||||
|
|
@ -37,6 +46,24 @@ logging.basicConfig(
|
|||
logger = logging.getLogger("rich")
|
||||
logging.getLogger("parso").setLevel(logging.WARNING)
|
||||
|
||||
# override the logger to reformat the messages for the lsp
|
||||
for level in ("info", "debug", "warning", "error"):
|
||||
real_fn = getattr(logger, level)
|
||||
setattr(
|
||||
logger,
|
||||
level,
|
||||
lambda msg, *args, _real_fn=real_fn, _level=level, **kwargs: enhanced_log(
|
||||
msg, _real_fn, _level, *args, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def lsp_log(message: LspMessage) -> None:
|
||||
if not is_LSP_enabled():
|
||||
return
|
||||
json_msg = message.serialize()
|
||||
logger.info(json_msg)
|
||||
|
||||
|
||||
def paneled_text(
|
||||
text: str, panel_args: dict[str, str | bool] | None = None, text_args: dict[str, str] | None = None
|
||||
|
|
@ -53,7 +80,10 @@ def paneled_text(
|
|||
console.print(panel)
|
||||
|
||||
|
||||
def code_print(code_str: str) -> None:
|
||||
def code_print(code_str: str, file_name: Optional[str] = None, function_name: Optional[str] = None) -> None:
|
||||
if is_LSP_enabled():
|
||||
lsp_log(LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name))
|
||||
return
|
||||
"""Print code with syntax highlighting."""
|
||||
from rich.syntax import Syntax
|
||||
|
||||
|
|
@ -74,6 +104,11 @@ def progress_bar(
|
|||
If revert_to_print is True, falls back to printing a single logger.info message
|
||||
instead of showing a progress bar.
|
||||
"""
|
||||
if is_LSP_enabled():
|
||||
lsp_log(LspTextMessage(text=message, takes_time=True))
|
||||
yield
|
||||
return
|
||||
|
||||
if revert_to_print:
|
||||
logger.info(message)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ from typing import TYPE_CHECKING, Any, Optional
|
|||
|
||||
from rich.prompt import Confirm
|
||||
|
||||
from codeflash.cli_cmds.console import console
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import argparse
|
||||
|
||||
|
|
@ -142,8 +144,11 @@ def ask_should_use_checkpoint_get_functions(args: argparse.Namespace) -> Optiona
|
|||
if previous_checkpoint_functions and Confirm.ask(
|
||||
"Previous Checkpoint detected from an incomplete optimization run, shall I continue the optimization from that point?",
|
||||
default=True,
|
||||
console=console,
|
||||
):
|
||||
pass
|
||||
console.rule()
|
||||
else:
|
||||
previous_checkpoint_functions = None
|
||||
|
||||
console.rule()
|
||||
return previous_checkpoint_functions
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import libcst as cst
|
||||
|
|
@ -119,6 +120,32 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
|
|||
|
||||
return updated_node
|
||||
|
||||
def _find_insertion_index(self, updated_node: cst.Module) -> int:
|
||||
"""Find the position of the last import statement in the top-level of the module."""
|
||||
insert_index = 0
|
||||
for i, stmt in enumerate(updated_node.body):
|
||||
is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
|
||||
isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
|
||||
)
|
||||
|
||||
is_conditional_import = isinstance(stmt, cst.If) and all(
|
||||
isinstance(inner, cst.SimpleStatementLine)
|
||||
and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
|
||||
for inner in stmt.body.body
|
||||
)
|
||||
|
||||
if is_top_level_import or is_conditional_import:
|
||||
insert_index = i + 1
|
||||
|
||||
# Stop scanning once we reach a class or function definition.
|
||||
# Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file.
|
||||
# Without this check, a stray import later in the file
|
||||
# would incorrectly shift our insertion index below actual code definitions.
|
||||
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
|
||||
break
|
||||
|
||||
return insert_index
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
# Add any new assignments that weren't in the original file
|
||||
new_statements = list(updated_node.body)
|
||||
|
|
@ -131,18 +158,26 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
|
|||
]
|
||||
|
||||
if assignments_to_append:
|
||||
# Add a blank line before appending new assignments if needed
|
||||
if new_statements and not isinstance(new_statements[-1], cst.EmptyLine):
|
||||
new_statements.append(cst.SimpleStatementLine([cst.Pass()], leading_lines=[cst.EmptyLine()]))
|
||||
new_statements.pop() # Remove the Pass statement but keep the empty line
|
||||
# after last top-level imports
|
||||
insert_index = self._find_insertion_index(updated_node)
|
||||
|
||||
# Add the new assignments
|
||||
new_statements.extend(
|
||||
[
|
||||
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
|
||||
for assignment in assignments_to_append
|
||||
]
|
||||
)
|
||||
assignment_lines = [
|
||||
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
|
||||
for assignment in assignments_to_append
|
||||
]
|
||||
|
||||
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
|
||||
|
||||
# Add a blank line after the last assignment if needed
|
||||
after_index = insert_index + len(assignment_lines)
|
||||
if after_index < len(new_statements):
|
||||
next_stmt = new_statements[after_index]
|
||||
# If there's no empty line, add one
|
||||
has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines)
|
||||
if not has_empty:
|
||||
new_statements[after_index] = next_stmt.with_changes(
|
||||
leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines]
|
||||
)
|
||||
|
||||
return updated_node.with_changes(body=new_statements)
|
||||
|
||||
|
|
@ -195,6 +230,79 @@ class LastImportFinder(cst.CSTVisitor):
|
|||
self.last_import_line = self.current_line
|
||||
|
||||
|
||||
class DottedImportCollector(cst.CSTVisitor):
|
||||
"""Collects all top-level imports from a Python module in normalized dotted format, including top-level conditional imports like `if TYPE_CHECKING:`.
|
||||
|
||||
Examples
|
||||
--------
|
||||
import os ==> "os"
|
||||
import dbt.adapters.factory ==> "dbt.adapters.factory"
|
||||
from pathlib import Path ==> "pathlib.Path"
|
||||
from recce.adapter.base import BaseAdapter ==> "recce.adapter.base.BaseAdapter"
|
||||
from typing import Any, List, Optional ==> "typing.Any", "typing.List", "typing.Optional"
|
||||
from recce.util.lineage import ( build_column_key, filter_dependency_maps) ==> "recce.util.lineage.build_column_key", "recce.util.lineage.filter_dependency_maps"
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.imports: set[str] = set()
|
||||
self.depth = 0 # top-level
|
||||
|
||||
def get_full_dotted_name(self, expr: cst.BaseExpression) -> str:
|
||||
if isinstance(expr, cst.Name):
|
||||
return expr.value
|
||||
if isinstance(expr, cst.Attribute):
|
||||
return f"{self.get_full_dotted_name(expr.value)}.{expr.attr.value}"
|
||||
return ""
|
||||
|
||||
def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
|
||||
for statement in block.body:
|
||||
if isinstance(statement, cst.SimpleStatementLine):
|
||||
for child in statement.body:
|
||||
if isinstance(child, cst.Import):
|
||||
for alias in child.names:
|
||||
module = self.get_full_dotted_name(alias.name)
|
||||
asname = alias.asname.name.value if alias.asname else alias.name.value
|
||||
if isinstance(asname, cst.Attribute):
|
||||
self.imports.add(module)
|
||||
else:
|
||||
self.imports.add(module if module == asname else f"{module}.{asname}")
|
||||
|
||||
elif isinstance(child, cst.ImportFrom):
|
||||
if child.module is None:
|
||||
continue
|
||||
module = self.get_full_dotted_name(child.module)
|
||||
for alias in child.names:
|
||||
if isinstance(alias, cst.ImportAlias):
|
||||
name = alias.name.value
|
||||
asname = alias.asname.name.value if alias.asname else name
|
||||
self.imports.add(f"{module}.{asname}")
|
||||
|
||||
def visit_Module(self, node: cst.Module) -> None:
|
||||
self.depth = 0
|
||||
self._collect_imports_from_block(node)
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
self.depth += 1
|
||||
|
||||
def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
self.depth -= 1
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
self.depth += 1
|
||||
|
||||
def leave_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
self.depth -= 1
|
||||
|
||||
def visit_If(self, node: cst.If) -> None:
|
||||
if self.depth == 0:
|
||||
self._collect_imports_from_block(node.body)
|
||||
|
||||
def visit_Try(self, node: cst.Try) -> None:
|
||||
if self.depth == 0:
|
||||
self._collect_imports_from_block(node.body)
|
||||
|
||||
|
||||
class ImportInserter(cst.CSTTransformer):
|
||||
"""Transformer that inserts global statements after the last import."""
|
||||
|
||||
|
|
@ -227,12 +335,12 @@ class ImportInserter(cst.CSTTransformer):
|
|||
return updated_node
|
||||
|
||||
|
||||
def extract_global_statements(source_code: str) -> list[cst.SimpleStatementLine]:
|
||||
def extract_global_statements(source_code: str) -> tuple[cst.Module, list[cst.SimpleStatementLine]]:
|
||||
"""Extract global statements from source code."""
|
||||
module = cst.parse_module(source_code)
|
||||
collector = GlobalStatementCollector()
|
||||
module.visit(collector)
|
||||
return collector.global_statements
|
||||
return module, collector.global_statements
|
||||
|
||||
|
||||
def find_last_import_line(target_code: str) -> int:
|
||||
|
|
@ -265,30 +373,41 @@ def delete___future___aliased_imports(module_code: str) -> str:
|
|||
|
||||
|
||||
def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
|
||||
non_assignment_global_statements = extract_global_statements(src_module_code)
|
||||
src_module, new_added_global_statements = extract_global_statements(src_module_code)
|
||||
dst_module, existing_global_statements = extract_global_statements(dst_module_code)
|
||||
|
||||
# Find the last import line in target
|
||||
last_import_line = find_last_import_line(dst_module_code)
|
||||
unique_global_statements = []
|
||||
for stmt in new_added_global_statements:
|
||||
if any(
|
||||
stmt is existing_stmt or stmt.deep_equals(existing_stmt) for existing_stmt in existing_global_statements
|
||||
):
|
||||
continue
|
||||
unique_global_statements.append(stmt)
|
||||
|
||||
# Parse the target code
|
||||
target_module = cst.parse_module(dst_module_code)
|
||||
|
||||
# Create transformer to insert non_assignment_global_statements
|
||||
transformer = ImportInserter(non_assignment_global_statements, last_import_line)
|
||||
#
|
||||
# # Apply transformation
|
||||
modified_module = target_module.visit(transformer)
|
||||
dst_module_code = modified_module.code
|
||||
|
||||
# Parse the code
|
||||
original_module = cst.parse_module(dst_module_code)
|
||||
new_module = cst.parse_module(src_module_code)
|
||||
mod_dst_code = dst_module_code
|
||||
# Insert unique global statements if any
|
||||
if unique_global_statements:
|
||||
last_import_line = find_last_import_line(dst_module_code)
|
||||
# Reuse already-parsed dst_module
|
||||
transformer = ImportInserter(unique_global_statements, last_import_line)
|
||||
# Use visit inplace, don't parse again
|
||||
modified_module = dst_module.visit(transformer)
|
||||
mod_dst_code = modified_module.code
|
||||
# Parse the code after insertion
|
||||
original_module = cst.parse_module(mod_dst_code)
|
||||
else:
|
||||
# No new statements to insert, reuse already-parsed dst_module
|
||||
original_module = dst_module
|
||||
|
||||
# Parse the src_module_code once only (already done above: src_module)
|
||||
# Collect assignments from the new file
|
||||
new_collector = GlobalAssignmentCollector()
|
||||
new_module.visit(new_collector)
|
||||
src_module.visit(new_collector)
|
||||
# Only create transformer if there are assignments to insert/transform
|
||||
if not new_collector.assignments: # nothing to transform
|
||||
return mod_dst_code
|
||||
|
||||
# Transform the original file
|
||||
# Transform the original destination module
|
||||
transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
|
||||
transformed_module = original_module.visit(transformer)
|
||||
|
||||
|
|
@ -329,9 +448,19 @@ def add_needed_imports_from_module(
|
|||
except Exception as e:
|
||||
logger.error(f"Error parsing source module code: {e}")
|
||||
return dst_module_code
|
||||
|
||||
dotted_import_collector = DottedImportCollector()
|
||||
try:
|
||||
parsed_dst_module = cst.parse_module(dst_module_code)
|
||||
parsed_dst_module.visit(dotted_import_collector)
|
||||
except cst.ParserSyntaxError as e:
|
||||
logger.exception(f"Syntax error in destination module code: {e}")
|
||||
return dst_module_code # Return the original code if there's a syntax error
|
||||
|
||||
try:
|
||||
for mod in gatherer.module_imports:
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod)
|
||||
if mod not in dotted_import_collector.imports:
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod)
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
|
||||
for mod, obj_seq in gatherer.object_mapping.items():
|
||||
for obj in obj_seq:
|
||||
|
|
@ -339,28 +468,29 @@ def add_needed_imports_from_module(
|
|||
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
|
||||
):
|
||||
continue # Skip adding imports for helper functions already in the context
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
|
||||
if f"{mod}.{obj}" not in dotted_import_collector.imports:
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error adding imports to destination module code: {e}")
|
||||
return dst_module_code
|
||||
|
||||
for mod, asname in gatherer.module_aliases.items():
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
|
||||
if f"{mod}.{asname}" not in dotted_import_collector.imports:
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
|
||||
|
||||
for mod, alias_pairs in gatherer.alias_mapping.items():
|
||||
for alias_pair in alias_pairs:
|
||||
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
|
||||
continue
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
|
||||
|
||||
if f"{mod}.{alias_pair[1]}" not in dotted_import_collector.imports:
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
|
||||
|
||||
try:
|
||||
parsed_module = cst.parse_module(dst_module_code)
|
||||
except cst.ParserSyntaxError as e:
|
||||
logger.exception(f"Syntax error in destination module code: {e}")
|
||||
return dst_module_code # Return the original code if there's a syntax error
|
||||
try:
|
||||
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_module)
|
||||
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_dst_module)
|
||||
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
|
||||
return transformed_module.code.lstrip("\n")
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ if TYPE_CHECKING:
|
|||
from pathlib import Path
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext, OptimizedCandidate, ValidCode
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, OptimizedCandidate, ValidCode
|
||||
|
||||
ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST)
|
||||
|
||||
|
|
@ -408,16 +408,23 @@ def replace_functions_and_add_imports(
|
|||
|
||||
def replace_function_definitions_in_module(
|
||||
function_names: list[str],
|
||||
optimized_code: str,
|
||||
optimized_code: CodeStringsMarkdown,
|
||||
module_abspath: Path,
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
|
||||
project_root_path: Path,
|
||||
should_add_global_assignments: bool = True, # noqa: FBT001, FBT002
|
||||
) -> bool:
|
||||
source_code: str = module_abspath.read_text(encoding="utf8")
|
||||
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
|
||||
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
add_global_assignments(optimized_code, source_code),
|
||||
# adding the global assignments before replacing the code, not after
|
||||
# becuase of an "edge case" where the optimized code intoduced a new import and a global assignment using that import
|
||||
# and that import wasn't used before, so it was ignored when calling AddImportsVisitor.add_needed_import inside replace_functions_and_add_imports (because the global assignment wasn't added yet)
|
||||
# this was added at https://github.com/codeflash-ai/codeflash/pull/448
|
||||
add_global_assignments(code_to_apply, source_code) if should_add_global_assignments else source_code,
|
||||
function_names,
|
||||
optimized_code,
|
||||
code_to_apply,
|
||||
module_abspath,
|
||||
preexisting_objects,
|
||||
project_root_path,
|
||||
|
|
@ -428,6 +435,19 @@ def replace_function_definitions_in_module(
|
|||
return True
|
||||
|
||||
|
||||
def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStringsMarkdown) -> str:
|
||||
file_to_code_context = optimized_code.file_to_path()
|
||||
module_optimized_code = file_to_code_context.get(str(relative_path))
|
||||
if module_optimized_code is None:
|
||||
logger.warning(
|
||||
f"Optimized code not found for {relative_path} In the context\n-------\n{optimized_code}\n-------\n"
|
||||
"re-check your 'markdown code structure'"
|
||||
f"existing files are {file_to_code_context.keys()}"
|
||||
)
|
||||
module_optimized_code = ""
|
||||
return module_optimized_code
|
||||
|
||||
|
||||
def is_zero_diff(original_code: str, new_code: str) -> bool:
|
||||
return normalize_code(original_code) == normalize_code(new_code)
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,23 @@ from codeflash.code_utils.config_parser import find_pyproject_toml
|
|||
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)
|
||||
|
||||
|
||||
def unified_diff_strings(code1: str, code2: str, fromfile: str = "original", tofile: str = "modified") -> str:
|
||||
"""Return the unified diff between two code strings as a single string.
|
||||
|
||||
:param code1: First code string (original).
|
||||
:param code2: Second code string (modified).
|
||||
:param fromfile: Label for the first code string.
|
||||
:param tofile: Label for the second code string.
|
||||
:return: Unified diff as a string.
|
||||
"""
|
||||
code1_lines = code1.splitlines(keepends=True)
|
||||
code2_lines = code2.splitlines(keepends=True)
|
||||
|
||||
diff = difflib.unified_diff(code1_lines, code2_lines, fromfile=fromfile, tofile=tofile, lineterm="")
|
||||
|
||||
return "".join(diff)
|
||||
|
||||
|
||||
def diff_length(a: str, b: str) -> int:
|
||||
"""Compute the length (in characters) of the unified diff between two strings.
|
||||
|
||||
|
|
|
|||
|
|
@ -11,3 +11,25 @@ COVERAGE_THRESHOLD = 60.0
|
|||
MIN_TESTCASE_PASSED_THRESHOLD = 6
|
||||
REPEAT_OPTIMIZATION_PROBABILITY = 0.1
|
||||
DEFAULT_IMPORTANCE_THRESHOLD = 0.001
|
||||
N_CANDIDATES_LP = 6
|
||||
|
||||
# LSP-specific
|
||||
N_CANDIDATES_LSP = 3
|
||||
N_TESTS_TO_GENERATE_LSP = 2
|
||||
TOTAL_LOOPING_TIME_LSP = 10.0 # Kept same timing for LSP mode to avoid in increase in performance reporting
|
||||
N_CANDIDATES_LP_LSP = 3
|
||||
|
||||
MAX_N_CANDIDATES = 5
|
||||
MAX_N_CANDIDATES_LP = 6
|
||||
|
||||
try:
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
|
||||
_IS_LSP_ENABLED = is_LSP_enabled()
|
||||
except ImportError:
|
||||
_IS_LSP_ENABLED = False
|
||||
|
||||
N_CANDIDATES_EFFECTIVE = min(N_CANDIDATES_LSP if _IS_LSP_ENABLED else N_CANDIDATES, MAX_N_CANDIDATES)
|
||||
N_CANDIDATES_LP_EFFECTIVE = min(N_CANDIDATES_LP_LSP if _IS_LSP_ENABLED else N_CANDIDATES_LP, MAX_N_CANDIDATES_LP)
|
||||
N_TESTS_TO_GENERATE_EFFECTIVE = N_TESTS_TO_GENERATE_LSP if _IS_LSP_ENABLED else N_TESTS_TO_GENERATE
|
||||
TOTAL_LOOPING_TIME_EFFECTIVE = TOTAL_LOOPING_TIME_LSP if _IS_LSP_ENABLED else TOTAL_LOOPING_TIME
|
||||
|
|
|
|||
|
|
@ -39,16 +39,20 @@ def build_fully_qualified_name(function_name: str, code_context: CodeOptimizatio
|
|||
return full_name
|
||||
|
||||
|
||||
def generate_candidates(source_code_path: Path) -> list[str]:
|
||||
def generate_candidates(source_code_path: Path) -> set[str]:
|
||||
"""Generate all the possible candidates for coverage data based on the source code path."""
|
||||
candidates = [source_code_path.name]
|
||||
candidates = set()
|
||||
candidates.add(source_code_path.name)
|
||||
current_path = source_code_path.parent
|
||||
|
||||
last_added = source_code_path.name
|
||||
while current_path != current_path.parent:
|
||||
candidate_path = (Path(current_path.name) / candidates[-1]).as_posix()
|
||||
candidates.append(candidate_path)
|
||||
candidate_path = str(Path(current_path.name) / last_added)
|
||||
candidates.add(candidate_path)
|
||||
last_added = candidate_path
|
||||
current_path = current_path.parent
|
||||
|
||||
candidates.add(str(source_code_path))
|
||||
return candidates
|
||||
|
||||
|
||||
|
|
|
|||
250
codeflash/code_utils/deduplicate_code.py
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
import ast
|
||||
import hashlib
|
||||
|
||||
|
||||
class VariableNormalizer(ast.NodeTransformer):
|
||||
"""Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.
|
||||
|
||||
Preserves function names, class names, parameters, built-ins, and imported names.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.var_counter = 0
|
||||
self.var_mapping: dict[str, str] = {}
|
||||
self.scope_stack = []
|
||||
self.builtins = set(dir(__builtins__))
|
||||
self.imports: set[str] = set()
|
||||
self.global_vars: set[str] = set()
|
||||
self.nonlocal_vars: set[str] = set()
|
||||
self.parameters: set[str] = set() # Track function parameters
|
||||
|
||||
def enter_scope(self): # noqa : ANN201
|
||||
"""Enter a new scope (function/class)."""
|
||||
self.scope_stack.append(
|
||||
{"var_mapping": dict(self.var_mapping), "var_counter": self.var_counter, "parameters": set(self.parameters)}
|
||||
)
|
||||
|
||||
def exit_scope(self): # noqa : ANN201
|
||||
"""Exit current scope and restore parent scope."""
|
||||
if self.scope_stack:
|
||||
scope = self.scope_stack.pop()
|
||||
self.var_mapping = scope["var_mapping"]
|
||||
self.var_counter = scope["var_counter"]
|
||||
self.parameters = scope["parameters"]
|
||||
|
||||
def get_normalized_name(self, name: str) -> str:
|
||||
"""Get or create normalized name for a variable."""
|
||||
# Don't normalize if it's a builtin, import, global, nonlocal, or parameter
|
||||
if (
|
||||
name in self.builtins
|
||||
or name in self.imports
|
||||
or name in self.global_vars
|
||||
or name in self.nonlocal_vars
|
||||
or name in self.parameters
|
||||
):
|
||||
return name
|
||||
|
||||
# Only normalize local variables
|
||||
if name not in self.var_mapping:
|
||||
self.var_mapping[name] = f"var_{self.var_counter}"
|
||||
self.var_counter += 1
|
||||
return self.var_mapping[name]
|
||||
|
||||
def visit_Import(self, node): # noqa : ANN001, ANN201
|
||||
"""Track imported names."""
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
self.imports.add(name.split(".")[0])
|
||||
return node
|
||||
|
||||
def visit_ImportFrom(self, node): # noqa : ANN001, ANN201
|
||||
"""Track imported names from modules."""
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
self.imports.add(name)
|
||||
return node
|
||||
|
||||
def visit_Global(self, node): # noqa : ANN001, ANN201
|
||||
"""Track global variable declarations."""
|
||||
# Avoid repeated .add calls by using set.update with list
|
||||
self.global_vars.update(node.names)
|
||||
return node
|
||||
|
||||
def visit_Nonlocal(self, node): # noqa : ANN001, ANN201
|
||||
"""Track nonlocal variable declarations."""
|
||||
# Using set.update for batch insertion (faster than add-in-loop)
|
||||
self.nonlocal_vars.update(node.names)
|
||||
return node
|
||||
|
||||
def visit_FunctionDef(self, node): # noqa : ANN001, ANN201
|
||||
"""Process function but keep function name and parameters unchanged."""
|
||||
self.enter_scope()
|
||||
|
||||
# Track all parameters (don't modify them)
|
||||
for arg in node.args.args:
|
||||
self.parameters.add(arg.arg)
|
||||
if node.args.vararg:
|
||||
self.parameters.add(node.args.vararg.arg)
|
||||
if node.args.kwarg:
|
||||
self.parameters.add(node.args.kwarg.arg)
|
||||
for arg in node.args.kwonlyargs:
|
||||
self.parameters.add(arg.arg)
|
||||
|
||||
# Visit function body
|
||||
node = self.generic_visit(node)
|
||||
self.exit_scope()
|
||||
return node
|
||||
|
||||
def visit_AsyncFunctionDef(self, node): # noqa : ANN001, ANN201
|
||||
"""Handle async functions same as regular functions."""
|
||||
return self.visit_FunctionDef(node)
|
||||
|
||||
def visit_ClassDef(self, node): # noqa : ANN001, ANN201
|
||||
"""Process class but keep class name unchanged."""
|
||||
self.enter_scope()
|
||||
node = self.generic_visit(node)
|
||||
self.exit_scope()
|
||||
return node
|
||||
|
||||
def visit_Name(self, node): # noqa : ANN001, ANN201
|
||||
"""Normalize variable names in Name nodes."""
|
||||
if isinstance(node.ctx, (ast.Store, ast.Del)):
|
||||
# For assignments and deletions, check if we should normalize
|
||||
if (
|
||||
node.id not in self.builtins
|
||||
and node.id not in self.imports
|
||||
and node.id not in self.parameters
|
||||
and node.id not in self.global_vars
|
||||
and node.id not in self.nonlocal_vars
|
||||
):
|
||||
node.id = self.get_normalized_name(node.id)
|
||||
elif isinstance(node.ctx, ast.Load): # noqa : SIM102
|
||||
# For loading, use existing mapping if available
|
||||
if node.id in self.var_mapping:
|
||||
node.id = self.var_mapping[node.id]
|
||||
return node
|
||||
|
||||
def visit_ExceptHandler(self, node): # noqa : ANN001, ANN201
|
||||
"""Normalize exception variable names."""
|
||||
if node.name:
|
||||
node.name = self.get_normalized_name(node.name)
|
||||
return self.generic_visit(node)
|
||||
|
||||
def visit_comprehension(self, node): # noqa : ANN001, ANN201
|
||||
"""Normalize comprehension target variables."""
|
||||
# Create new scope for comprehension
|
||||
old_mapping = dict(self.var_mapping)
|
||||
old_counter = self.var_counter
|
||||
|
||||
# Process the comprehension
|
||||
node = self.generic_visit(node)
|
||||
|
||||
# Restore scope
|
||||
self.var_mapping = old_mapping
|
||||
self.var_counter = old_counter
|
||||
return node
|
||||
|
||||
def visit_For(self, node): # noqa : ANN001, ANN201
|
||||
"""Handle for loop target variables."""
|
||||
# The target in a for loop is a local variable that should be normalized
|
||||
return self.generic_visit(node)
|
||||
|
||||
def visit_With(self, node): # noqa : ANN001, ANN201
|
||||
"""Handle with statement as variables."""
|
||||
return self.generic_visit(node)
|
||||
|
||||
|
||||
def normalize_code(code: str, remove_docstrings: bool = True, return_ast_dump: bool = False) -> str: # noqa : FBT002, FBT001
|
||||
"""Normalize Python code by parsing, cleaning, and normalizing only variable names.
|
||||
|
||||
Function names, class names, and parameters are preserved.
|
||||
|
||||
Args:
|
||||
code: Python source code as string
|
||||
remove_docstrings: Whether to remove docstrings
|
||||
return_ast_dump: return_ast_dump
|
||||
|
||||
Returns:
|
||||
Normalized code as string
|
||||
|
||||
"""
|
||||
try:
|
||||
# Parse the code
|
||||
tree = ast.parse(code)
|
||||
|
||||
# Remove docstrings if requested
|
||||
if remove_docstrings:
|
||||
remove_docstrings_from_ast(tree)
|
||||
|
||||
# Normalize variable names
|
||||
normalizer = VariableNormalizer()
|
||||
normalized_tree = normalizer.visit(tree)
|
||||
if return_ast_dump:
|
||||
# This is faster than unparsing etc
|
||||
return ast.dump(normalized_tree, annotate_fields=False, include_attributes=False)
|
||||
|
||||
# Fix missing locations in the AST
|
||||
ast.fix_missing_locations(normalized_tree)
|
||||
|
||||
# Unparse back to code
|
||||
return ast.unparse(normalized_tree)
|
||||
except SyntaxError as e:
|
||||
msg = f"Invalid Python syntax: {e}"
|
||||
raise ValueError(msg) from e
|
||||
|
||||
|
||||
def remove_docstrings_from_ast(node): # noqa : ANN001, ANN201
|
||||
"""Remove docstrings from AST nodes."""
|
||||
# Only FunctionDef, AsyncFunctionDef, ClassDef, and Module can contain docstrings in their body[0]
|
||||
node_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
|
||||
# Use our own stack-based DFS instead of ast.walk for efficiency
|
||||
stack = [node]
|
||||
while stack:
|
||||
current_node = stack.pop()
|
||||
if isinstance(current_node, node_types):
|
||||
# Remove docstring if it's the first stmt in body
|
||||
body = current_node.body
|
||||
if (
|
||||
body
|
||||
and isinstance(body[0], ast.Expr)
|
||||
and isinstance(body[0].value, ast.Constant)
|
||||
and isinstance(body[0].value.value, str)
|
||||
):
|
||||
current_node.body = body[1:]
|
||||
# Only these nodes can nest more docstring-containing nodes
|
||||
# Add their body elements to stack, avoiding unnecessary traversal
|
||||
stack.extend([child for child in body if isinstance(child, node_types)])
|
||||
|
||||
|
||||
def get_code_fingerprint(code: str) -> str:
|
||||
"""Generate a fingerprint for normalized code.
|
||||
|
||||
Args:
|
||||
code: Python source code
|
||||
|
||||
Returns:
|
||||
SHA-256 hash of normalized code
|
||||
|
||||
"""
|
||||
normalized = normalize_code(code)
|
||||
return hashlib.sha256(normalized.encode()).hexdigest()
|
||||
|
||||
|
||||
def are_codes_duplicate(code1: str, code2: str) -> bool:
|
||||
"""Check if two code segments are duplicates after normalization.
|
||||
|
||||
Args:
|
||||
code1: First code segment
|
||||
code2: Second code segment
|
||||
|
||||
Returns:
|
||||
True if codes are structurally identical (ignoring local variable names)
|
||||
|
||||
"""
|
||||
try:
|
||||
normalized1 = normalize_code(code1, return_ast_dump=True)
|
||||
normalized2 = normalize_code(code2, return_ast_dump=True)
|
||||
except Exception:
|
||||
return False
|
||||
else:
|
||||
return normalized1 == normalized2
|
||||
|
|
@ -7,10 +7,10 @@ from functools import lru_cache
|
|||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_utils import exit_with_message
|
||||
from codeflash.code_utils.formatter import format_code
|
||||
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
|
||||
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc
|
||||
|
||||
|
||||
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
|
||||
|
|
@ -33,7 +33,23 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
|
|||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_codeflash_api_key() -> str:
|
||||
api_key = os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()
|
||||
# Check environment variable first
|
||||
env_api_key = os.environ.get("CODEFLASH_API_KEY")
|
||||
shell_api_key = read_api_key_from_shell_config()
|
||||
|
||||
# If we have an env var but it's not in shell config, save it for persistence
|
||||
if env_api_key and not shell_api_key:
|
||||
try:
|
||||
from codeflash.either import is_successful
|
||||
|
||||
result = save_api_key_to_rc(env_api_key)
|
||||
if is_successful(result):
|
||||
logger.debug(f"Automatically saved API key from environment to shell config: {result.unwrap()}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to automatically save API key to shell config: {e}")
|
||||
|
||||
api_key = env_api_key or shell_api_key
|
||||
|
||||
api_secret_docs_message = "For more information, refer to the documentation at [https://docs.codeflash.ai/getting-started/codeflash-github-actions#add-your-api-key-to-your-repository-secrets]." # noqa
|
||||
if not api_key:
|
||||
msg = (
|
||||
|
|
@ -119,11 +135,6 @@ def is_ci() -> bool:
|
|||
return bool(os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS"))
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def is_LSP_enabled() -> bool:
|
||||
return console.quiet
|
||||
|
||||
|
||||
def is_pr_draft() -> bool:
|
||||
"""Check if the PR is draft. in the github action context."""
|
||||
event = get_cached_gh_event_data()
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from typing import Optional, Union
|
|||
import isort
|
||||
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
|
||||
|
||||
def generate_unified_diff(original: str, modified: str, from_file: str, to_file: str) -> str:
|
||||
|
|
@ -44,9 +45,7 @@ def apply_formatter_cmds(
|
|||
test_dir_str: Optional[str],
|
||||
print_status: bool, # noqa
|
||||
exit_on_failure: bool = True, # noqa
|
||||
) -> tuple[Path, str]:
|
||||
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
|
||||
formatter_name = cmds[0].lower()
|
||||
) -> tuple[Path, str, bool]:
|
||||
should_make_copy = False
|
||||
file_path = path
|
||||
|
||||
|
|
@ -54,9 +53,6 @@ def apply_formatter_cmds(
|
|||
should_make_copy = True
|
||||
file_path = Path(test_dir_str) / "temp.py"
|
||||
|
||||
if not cmds or formatter_name == "disabled":
|
||||
return path, path.read_text(encoding="utf8")
|
||||
|
||||
if not path.exists():
|
||||
msg = f"File {path} does not exist. Cannot apply formatter commands."
|
||||
raise FileNotFoundError(msg)
|
||||
|
|
@ -66,6 +62,7 @@ def apply_formatter_cmds(
|
|||
|
||||
file_token = "$file" # noqa: S105
|
||||
|
||||
changed = False
|
||||
for command in cmds:
|
||||
formatter_cmd_list = shlex.split(command, posix=os.name != "nt")
|
||||
formatter_cmd_list = [file_path.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list]
|
||||
|
|
@ -74,6 +71,7 @@ def apply_formatter_cmds(
|
|||
if result.returncode == 0:
|
||||
if print_status:
|
||||
console.rule(f"Formatted Successfully with: {command.replace('$file', path.name)}")
|
||||
changed = True
|
||||
else:
|
||||
logger.error(f"Failed to format code with {' '.join(formatter_cmd_list)}")
|
||||
except FileNotFoundError as e:
|
||||
|
|
@ -88,7 +86,7 @@ def apply_formatter_cmds(
|
|||
if exit_on_failure:
|
||||
raise e from None
|
||||
|
||||
return file_path, file_path.read_text(encoding="utf8")
|
||||
return file_path, file_path.read_text(encoding="utf8"), changed
|
||||
|
||||
|
||||
def get_diff_lines_count(diff_output: str) -> int:
|
||||
|
|
@ -104,32 +102,43 @@ def get_diff_lines_count(diff_output: str) -> int:
|
|||
def format_code(
|
||||
formatter_cmds: list[str],
|
||||
path: Union[str, Path],
|
||||
optimized_function: str = "",
|
||||
optimized_code: str = "",
|
||||
check_diff: bool = False, # noqa
|
||||
print_status: bool = True, # noqa
|
||||
exit_on_failure: bool = True, # noqa
|
||||
) -> str:
|
||||
if console.quiet:
|
||||
# lsp mode
|
||||
if is_LSP_enabled():
|
||||
exit_on_failure = False
|
||||
with tempfile.TemporaryDirectory() as test_dir_str:
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
|
||||
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
|
||||
formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled"
|
||||
if formatter_name == "disabled":
|
||||
return path.read_text(encoding="utf8")
|
||||
|
||||
with tempfile.TemporaryDirectory() as test_dir_str:
|
||||
original_code = path.read_text(encoding="utf8")
|
||||
original_code_lines = len(original_code.split("\n"))
|
||||
|
||||
if check_diff and original_code_lines > 50:
|
||||
# we dont' count the formatting diff for the optimized function as it should be well-formatted
|
||||
original_code_without_opfunc = original_code.replace(optimized_function, "")
|
||||
original_code_without_opfunc = original_code.replace(optimized_code, "")
|
||||
|
||||
original_temp = Path(test_dir_str) / "original_temp.py"
|
||||
original_temp.write_text(original_code_without_opfunc, encoding="utf8")
|
||||
|
||||
formatted_temp, formatted_code = apply_formatter_cmds(
|
||||
formatter_cmds, original_temp, test_dir_str, print_status=False
|
||||
formatted_temp, formatted_code, changed = apply_formatter_cmds(
|
||||
formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=exit_on_failure
|
||||
)
|
||||
|
||||
if not changed:
|
||||
logger.warning(
|
||||
f"No changes detected in {path} after formatting, are you sure you have valid formatter commands?"
|
||||
)
|
||||
return original_code
|
||||
|
||||
diff_output = generate_unified_diff(
|
||||
original_code_without_opfunc, formatted_code, from_file=str(original_temp), to_file=str(formatted_temp)
|
||||
)
|
||||
|
|
@ -137,15 +146,22 @@ def format_code(
|
|||
|
||||
max_diff_lines = min(int(original_code_lines * 0.3), 50)
|
||||
|
||||
if diff_lines_count > max_diff_lines and max_diff_lines != -1:
|
||||
logger.debug(
|
||||
if diff_lines_count > max_diff_lines:
|
||||
logger.warning(
|
||||
f"Skipping formatting {path}: {diff_lines_count} lines would change (max: {max_diff_lines})"
|
||||
)
|
||||
return original_code
|
||||
|
||||
# TODO : We can avoid formatting the whole file again and only formatting the optimized code standalone and replace in formatted file above.
|
||||
_, formatted_code = apply_formatter_cmds(
|
||||
_, formatted_code, changed = apply_formatter_cmds(
|
||||
formatter_cmds, path, test_dir_str=None, print_status=print_status, exit_on_failure=exit_on_failure
|
||||
)
|
||||
if not changed:
|
||||
logger.warning(
|
||||
f"No changes detected in {path} after formatting, are you sure you have valid formatter commands?"
|
||||
)
|
||||
return original_code
|
||||
|
||||
logger.debug(f"Formatted {path} with commands: {formatter_cmds}")
|
||||
return formatted_code
|
||||
|
||||
|
|
|
|||
|
|
@ -9,23 +9,31 @@ import time
|
|||
from functools import cache
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import git
|
||||
from rich.prompt import Confirm
|
||||
from unidiff import PatchSet
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.config_consts import N_CANDIDATES
|
||||
from codeflash.code_utils.config_consts import N_CANDIDATES_EFFECTIVE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from git import Repo
|
||||
|
||||
|
||||
def get_git_diff(repo_directory: Path = Path.cwd(), uncommitted_changes: bool = False) -> dict[str, list[int]]: # noqa: B008, FBT001, FBT002
|
||||
def get_git_diff(
|
||||
repo_directory: Path | None = None, *, only_this_commit: Optional[str] = None, uncommitted_changes: bool = False
|
||||
) -> dict[str, list[int]]:
|
||||
if repo_directory is None:
|
||||
repo_directory = Path.cwd()
|
||||
repository = git.Repo(repo_directory, search_parent_directories=True)
|
||||
commit = repository.head.commit
|
||||
if uncommitted_changes:
|
||||
if only_this_commit:
|
||||
uni_diff_text = repository.git.diff(
|
||||
only_this_commit + "^1", only_this_commit, ignore_blank_lines=True, ignore_space_at_eol=True
|
||||
)
|
||||
elif uncommitted_changes:
|
||||
uni_diff_text = repository.git.diff(None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True)
|
||||
else:
|
||||
uni_diff_text = repository.git.diff(
|
||||
|
|
@ -117,30 +125,31 @@ def confirm_proceeding_with_no_git_repo() -> str | bool:
|
|||
return True
|
||||
|
||||
|
||||
def check_and_push_branch(repo: git.Repo, git_remote: str | None = "origin", wait_for_push: bool = False) -> bool: # noqa: FBT001, FBT002
|
||||
current_branch = repo.active_branch.name
|
||||
def check_and_push_branch(repo: git.Repo, git_remote: str | None = "origin", *, wait_for_push: bool = False) -> bool:
|
||||
current_branch = repo.active_branch
|
||||
current_branch_name = current_branch.name
|
||||
remote = repo.remote(name=git_remote)
|
||||
|
||||
# Check if the branch is pushed
|
||||
if f"{git_remote}/{current_branch}" not in repo.refs:
|
||||
logger.warning(f"⚠️ The branch '{current_branch}' is not pushed to the remote repository.")
|
||||
if f"{git_remote}/{current_branch_name}" not in repo.refs:
|
||||
logger.warning(f"⚠️ The branch '{current_branch_name}' is not pushed to the remote repository.")
|
||||
if not sys.__stdin__.isatty():
|
||||
logger.warning("Non-interactive shell detected. Branch will not be pushed.")
|
||||
return False
|
||||
if sys.__stdin__.isatty() and Confirm.ask(
|
||||
f"⚡️ In order for me to create PRs, your current branch needs to be pushed. Do you want to push "
|
||||
f"the branch '{current_branch}' to the remote repository?",
|
||||
f"the branch '{current_branch_name}' to the remote repository?",
|
||||
default=False,
|
||||
):
|
||||
remote.push(current_branch)
|
||||
logger.info(f"⬆️ Branch '{current_branch}' has been pushed to {git_remote}.")
|
||||
logger.info(f"⬆️ Branch '{current_branch_name}' has been pushed to {git_remote}.")
|
||||
if wait_for_push:
|
||||
time.sleep(3) # adding this to give time for the push to register with GitHub,
|
||||
# so that our modifications to it are not rejected
|
||||
return True
|
||||
logger.info(f"🔘 Branch '{current_branch}' has not been pushed to {git_remote}.")
|
||||
logger.info(f"🔘 Branch '{current_branch_name}' has not been pushed to {git_remote}.")
|
||||
return False
|
||||
logger.debug(f"The branch '{current_branch}' is present in the remote repository.")
|
||||
logger.debug(f"The branch '{current_branch_name}' is present in the remote repository.")
|
||||
return True
|
||||
|
||||
|
||||
|
|
@ -155,7 +164,7 @@ def create_git_worktrees(
|
|||
) -> tuple[Path | None, list[Path]]:
|
||||
if git_root and worktree_root_dir:
|
||||
worktree_root = Path(tempfile.mkdtemp(dir=worktree_root_dir))
|
||||
worktrees = [Path(tempfile.mkdtemp(dir=worktree_root)) for _ in range(N_CANDIDATES + 1)]
|
||||
worktrees = [Path(tempfile.mkdtemp(dir=worktree_root)) for _ in range(N_CANDIDATES_EFFECTIVE + 1)]
|
||||
for worktree in worktrees:
|
||||
subprocess.run(["git", "worktree", "add", "-d", worktree], cwd=module_root, check=True)
|
||||
else:
|
||||
|
|
|
|||
170
codeflash/code_utils/git_worktree_utils.py
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import git
|
||||
from filelock import FileLock
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.compat import codeflash_cache_dir
|
||||
from codeflash.code_utils.git_utils import check_running_in_git_repo, git_root_dir
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
from git import Repo
|
||||
|
||||
|
||||
worktree_dirs = codeflash_cache_dir / "worktrees"
|
||||
patches_dir = codeflash_cache_dir / "patches"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from git import Repo
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_git_project_id() -> str:
|
||||
"""Return the first commit sha of the repo."""
|
||||
repo: Repo = git.Repo(search_parent_directories=True)
|
||||
root_commits = list(repo.iter_commits(rev="HEAD", max_parents=0))
|
||||
return root_commits[0].hexsha
|
||||
|
||||
|
||||
def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
|
||||
repository = git.Repo(worktree_dir, search_parent_directories=True)
|
||||
repository.git.add(".")
|
||||
repository.git.commit("-m", commit_message, "--no-verify")
|
||||
|
||||
|
||||
def create_detached_worktree(module_root: Path) -> Optional[Path]:
|
||||
if not check_running_in_git_repo(module_root):
|
||||
logger.warning("Module is not in a git repository. Skipping worktree creation.")
|
||||
return None
|
||||
git_root = git_root_dir()
|
||||
current_time_str = time.strftime("%Y%m%d-%H%M%S")
|
||||
worktree_dir = worktree_dirs / f"{git_root.name}-{current_time_str}"
|
||||
|
||||
repository = git.Repo(git_root, search_parent_directories=True)
|
||||
|
||||
repository.git.worktree("add", "-d", str(worktree_dir))
|
||||
|
||||
# Get uncommitted diff from the original repo
|
||||
repository.git.add("-N", ".") # add the index for untracked files to be included in the diff
|
||||
exclude_binary_files = [":!*.pyc", ":!*.pyo", ":!*.pyd", ":!*.so", ":!*.dll", ":!*.whl", ":!*.egg", ":!*.egg-info", ":!*.pyz", ":!*.pkl", ":!*.pickle", ":!*.joblib", ":!*.npy", ":!*.npz", ":!*.h5", ":!*.hdf5", ":!*.pth", ":!*.pt", ":!*.pb", ":!*.onnx", ":!*.db", ":!*.sqlite", ":!*.sqlite3", ":!*.feather", ":!*.parquet", ":!*.jpg", ":!*.jpeg", ":!*.png", ":!*.gif", ":!*.bmp", ":!*.tiff", ":!*.webp", ":!*.wav", ":!*.mp3", ":!*.ogg", ":!*.flac", ":!*.mp4", ":!*.avi", ":!*.mov", ":!*.mkv", ":!*.pdf", ":!*.doc", ":!*.docx", ":!*.xls", ":!*.xlsx", ":!*.ppt", ":!*.pptx", ":!*.zip", ":!*.rar", ":!*.tar", ":!*.tar.gz", ":!*.tgz", ":!*.bz2", ":!*.xz"] # fmt: off
|
||||
uni_diff_text = repository.git.diff(
|
||||
None, "HEAD", "--", *exclude_binary_files, ignore_blank_lines=True, ignore_space_at_eol=True
|
||||
)
|
||||
|
||||
if not uni_diff_text.strip():
|
||||
logger.info("!lsp|No uncommitted changes to copy to worktree.")
|
||||
return worktree_dir
|
||||
|
||||
# Write the diff to a temporary file
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".codeflash.patch", delete=False) as tmp_patch_file:
|
||||
tmp_patch_file.write(uni_diff_text + "\n") # the new line here is a must otherwise the last hunk won't be valid
|
||||
tmp_patch_file.flush()
|
||||
|
||||
patch_path = Path(tmp_patch_file.name).resolve()
|
||||
|
||||
# Apply the patch inside the worktree
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "apply", "--ignore-space-change", "--ignore-whitespace", "--whitespace=nowarn", patch_path],
|
||||
cwd=worktree_dir,
|
||||
check=True,
|
||||
)
|
||||
create_worktree_snapshot_commit(worktree_dir, "Initial Snapshot")
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"Failed to apply patch to worktree: {e}")
|
||||
|
||||
return worktree_dir
|
||||
|
||||
|
||||
def remove_worktree(worktree_dir: Path) -> None:
|
||||
try:
|
||||
repository = git.Repo(worktree_dir, search_parent_directories=True)
|
||||
repository.git.worktree("remove", "--force", worktree_dir)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to remove worktree: {worktree_dir}")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_patches_dir_for_project() -> Path:
|
||||
project_id = get_git_project_id() or ""
|
||||
return Path(patches_dir / project_id)
|
||||
|
||||
|
||||
def get_patches_metadata() -> dict[str, Any]:
|
||||
project_patches_dir = get_patches_dir_for_project()
|
||||
meta_file = project_patches_dir / "metadata.json"
|
||||
if meta_file.exists():
|
||||
with meta_file.open("r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
return {"id": get_git_project_id() or "", "patches": []}
|
||||
|
||||
|
||||
def save_patches_metadata(patch_metadata: dict) -> dict:
|
||||
project_patches_dir = get_patches_dir_for_project()
|
||||
meta_file = project_patches_dir / "metadata.json"
|
||||
lock_file = project_patches_dir / "metadata.json.lock"
|
||||
|
||||
# we are not supporting multiple concurrent optimizations within the same process, but keep that in case we decide to do so in the future.
|
||||
with FileLock(lock_file, timeout=10):
|
||||
metadata = get_patches_metadata()
|
||||
|
||||
patch_metadata["id"] = time.strftime("%Y%m%d-%H%M%S")
|
||||
metadata["patches"].append(patch_metadata)
|
||||
|
||||
meta_file.write_text(json.dumps(metadata, indent=2))
|
||||
|
||||
return patch_metadata
|
||||
|
||||
|
||||
def overwrite_patch_metadata(patches: list[dict]) -> bool:
|
||||
project_patches_dir = get_patches_dir_for_project()
|
||||
meta_file = project_patches_dir / "metadata.json"
|
||||
lock_file = project_patches_dir / "metadata.json.lock"
|
||||
|
||||
with FileLock(lock_file, timeout=10):
|
||||
metadata = get_patches_metadata()
|
||||
metadata["patches"] = patches
|
||||
meta_file.write_text(json.dumps(metadata, indent=2))
|
||||
return True
|
||||
|
||||
|
||||
def create_diff_patch_from_worktree(
|
||||
worktree_dir: Path,
|
||||
files: list[str],
|
||||
fto_name: Optional[str] = None,
|
||||
metadata_input: Optional[dict[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
repository = git.Repo(worktree_dir, search_parent_directories=True)
|
||||
uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)
|
||||
|
||||
if not uni_diff_text:
|
||||
logger.warning("No changes found in worktree.")
|
||||
return {}
|
||||
|
||||
if not uni_diff_text.endswith("\n"):
|
||||
uni_diff_text += "\n"
|
||||
|
||||
project_patches_dir = get_patches_dir_for_project()
|
||||
project_patches_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
final_function_name = fto_name or metadata_input.get("fto_name", "unknown")
|
||||
patch_path = project_patches_dir / f"{worktree_dir.name}.{final_function_name}.patch"
|
||||
with patch_path.open("w", encoding="utf8") as f:
|
||||
f.write(uni_diff_text)
|
||||
|
||||
final_metadata = {"patch_path": str(patch_path)}
|
||||
if metadata_input:
|
||||
final_metadata.update(metadata_input)
|
||||
final_metadata = save_patches_metadata(final_metadata)
|
||||
|
||||
return final_metadata
|
||||
|
|
@ -367,15 +367,15 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
|
|||
targets=[ast.Name(id="test_id", ctx=ast.Store())],
|
||||
value=ast.JoinedStr(
|
||||
values=[
|
||||
ast.FormattedValue(value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1),
|
||||
ast.FormattedValue(value=ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), conversion=-1),
|
||||
ast.Constant(value=":"),
|
||||
ast.FormattedValue(value=ast.Name(id="test_class_name", ctx=ast.Load()), conversion=-1),
|
||||
ast.FormattedValue(value=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()), conversion=-1),
|
||||
ast.Constant(value=":"),
|
||||
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
|
||||
ast.FormattedValue(value=ast.Name(id="codeflash_test_name", ctx=ast.Load()), conversion=-1),
|
||||
ast.Constant(value=":"),
|
||||
ast.FormattedValue(value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1),
|
||||
ast.FormattedValue(value=ast.Name(id="codeflash_line_id", ctx=ast.Load()), conversion=-1),
|
||||
ast.Constant(value=":"),
|
||||
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
|
||||
ast.FormattedValue(value=ast.Name(id="codeflash_loop_index", ctx=ast.Load()), conversion=-1),
|
||||
]
|
||||
),
|
||||
lineno=lineno + 1,
|
||||
|
|
@ -455,7 +455,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
|
|||
targets=[ast.Name(id="invocation_id", ctx=ast.Store())],
|
||||
value=ast.JoinedStr(
|
||||
values=[
|
||||
ast.FormattedValue(value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1),
|
||||
ast.FormattedValue(value=ast.Name(id="codeflash_line_id", ctx=ast.Load()), conversion=-1),
|
||||
ast.Constant(value="_"),
|
||||
ast.FormattedValue(value=ast.Name(id="codeflash_test_index", ctx=ast.Load()), conversion=-1),
|
||||
]
|
||||
|
|
@ -468,13 +468,15 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
|
|||
targets=[ast.Name(id="test_stdout_tag", ctx=ast.Store())],
|
||||
value=ast.JoinedStr(
|
||||
values=[
|
||||
ast.FormattedValue(value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1),
|
||||
ast.FormattedValue(
|
||||
value=ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), conversion=-1
|
||||
),
|
||||
ast.Constant(value=":"),
|
||||
ast.FormattedValue(
|
||||
value=ast.IfExp(
|
||||
test=ast.Name(id="test_class_name", ctx=ast.Load()),
|
||||
test=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()),
|
||||
body=ast.BinOp(
|
||||
left=ast.Name(id="test_class_name", ctx=ast.Load()),
|
||||
left=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()),
|
||||
op=ast.Add(),
|
||||
right=ast.Constant(value="."),
|
||||
),
|
||||
|
|
@ -482,11 +484,15 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
|
|||
),
|
||||
conversion=-1,
|
||||
),
|
||||
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
|
||||
ast.FormattedValue(value=ast.Name(id="codeflash_test_name", ctx=ast.Load()), conversion=-1),
|
||||
ast.Constant(value=":"),
|
||||
ast.FormattedValue(value=ast.Name(id="function_name", ctx=ast.Load()), conversion=-1),
|
||||
ast.FormattedValue(
|
||||
value=ast.Name(id="codeflash_function_name", ctx=ast.Load()), conversion=-1
|
||||
),
|
||||
ast.Constant(value=":"),
|
||||
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
|
||||
ast.FormattedValue(
|
||||
value=ast.Name(id="codeflash_loop_index", ctx=ast.Load()), conversion=-1
|
||||
),
|
||||
ast.Constant(value=":"),
|
||||
ast.FormattedValue(value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1),
|
||||
]
|
||||
|
|
@ -539,7 +545,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
|
|||
ast.Assign(
|
||||
targets=[ast.Name(id="return_value", ctx=ast.Store())],
|
||||
value=ast.Call(
|
||||
func=ast.Name(id="wrapped", ctx=ast.Load()),
|
||||
func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()),
|
||||
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
|
||||
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
|
||||
),
|
||||
|
|
@ -666,11 +672,11 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
|
|||
ast.Constant(value="INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"),
|
||||
ast.Tuple(
|
||||
elts=[
|
||||
ast.Name(id="test_module_name", ctx=ast.Load()),
|
||||
ast.Name(id="test_class_name", ctx=ast.Load()),
|
||||
ast.Name(id="test_name", ctx=ast.Load()),
|
||||
ast.Name(id="function_name", ctx=ast.Load()),
|
||||
ast.Name(id="loop_index", ctx=ast.Load()),
|
||||
ast.Name(id="codeflash_test_module_name", ctx=ast.Load()),
|
||||
ast.Name(id="codeflash_test_class_name", ctx=ast.Load()),
|
||||
ast.Name(id="codeflash_test_name", ctx=ast.Load()),
|
||||
ast.Name(id="codeflash_function_name", ctx=ast.Load()),
|
||||
ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
|
||||
ast.Name(id="invocation_id", ctx=ast.Load()),
|
||||
ast.Name(id="codeflash_duration", ctx=ast.Load()),
|
||||
ast.Name(id="pickled_return_value", ctx=ast.Load()),
|
||||
|
|
@ -709,13 +715,13 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
|
|||
name="codeflash_wrap",
|
||||
args=ast.arguments(
|
||||
args=[
|
||||
ast.arg(arg="wrapped", annotation=None),
|
||||
ast.arg(arg="test_module_name", annotation=None),
|
||||
ast.arg(arg="test_class_name", annotation=None),
|
||||
ast.arg(arg="test_name", annotation=None),
|
||||
ast.arg(arg="function_name", annotation=None),
|
||||
ast.arg(arg="line_id", annotation=None),
|
||||
ast.arg(arg="loop_index", annotation=None),
|
||||
ast.arg(arg="codeflash_wrapped", annotation=None),
|
||||
ast.arg(arg="codeflash_test_module_name", annotation=None),
|
||||
ast.arg(arg="codeflash_test_class_name", annotation=None),
|
||||
ast.arg(arg="codeflash_test_name", annotation=None),
|
||||
ast.arg(arg="codeflash_function_name", annotation=None),
|
||||
ast.arg(arg="codeflash_line_id", annotation=None),
|
||||
ast.arg(arg="codeflash_loop_index", annotation=None),
|
||||
*([ast.arg(arg="codeflash_cur", annotation=None)] if mode == TestingMode.BEHAVIOR else []),
|
||||
*([ast.arg(arg="codeflash_con", annotation=None)] if mode == TestingMode.BEHAVIOR else []),
|
||||
],
|
||||
|
|
|
|||
|
|
@ -15,7 +15,9 @@ if os.name == "nt": # Windows
|
|||
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE)
|
||||
SHELL_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY="
|
||||
else:
|
||||
SHELL_RC_EXPORT_PATTERN = re.compile(r'^(?!#)export CODEFLASH_API_KEY=[\'"]?(cf-[^\s"]+)[\'"]$', re.MULTILINE)
|
||||
SHELL_RC_EXPORT_PATTERN = re.compile(
|
||||
r'^(?!#)export CODEFLASH_API_KEY=(?:"|\')?(cf-[^\s"\']+)(?:"|\')?$', re.MULTILINE
|
||||
)
|
||||
SHELL_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY="
|
||||
|
||||
|
||||
|
|
@ -42,7 +44,7 @@ def get_shell_rc_path() -> Path:
|
|||
|
||||
|
||||
def get_api_key_export_line(api_key: str) -> str:
|
||||
return f"{SHELL_RC_EXPORT_PREFIX}{api_key}"
|
||||
return f'{SHELL_RC_EXPORT_PREFIX}"{api_key}"'
|
||||
|
||||
|
||||
def save_api_key_to_rc(api_key: str) -> Result[str, str]:
|
||||
|
|
|
|||
77
codeflash/code_utils/version_check.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
"""Version checking utilities for codeflash."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
from packaging import version
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.version import __version__
|
||||
|
||||
# Simple cache to avoid checking too frequently
|
||||
_version_cache = {"version": "0.0.0", "timestamp": float(0)}
|
||||
_cache_duration = 3600 # 1 hour cache
|
||||
|
||||
|
||||
def get_latest_version_from_pypi() -> str | None:
|
||||
"""Get the latest version of codeflash from PyPI.
|
||||
|
||||
Returns:
|
||||
The latest version string from PyPI, or None if the request fails.
|
||||
|
||||
"""
|
||||
# Check cache first
|
||||
current_time = time.time()
|
||||
if _version_cache["version"] is not None and current_time - _version_cache["timestamp"] < _cache_duration:
|
||||
return _version_cache["version"]
|
||||
|
||||
try:
|
||||
response = requests.get("https://pypi.org/pypi/codeflash/json", timeout=2)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
latest_version = data["info"]["version"]
|
||||
|
||||
# Update cache
|
||||
_version_cache["version"] = latest_version
|
||||
_version_cache["timestamp"] = current_time
|
||||
|
||||
return latest_version
|
||||
logger.debug(f"Failed to fetch version from PyPI: {response.status_code}")
|
||||
return None # noqa: TRY300
|
||||
except requests.RequestException as e:
|
||||
logger.debug(f"Network error fetching version from PyPI: {e}")
|
||||
return None
|
||||
except (KeyError, ValueError) as e:
|
||||
logger.debug(f"Invalid response format from PyPI: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.debug(f"Unexpected error fetching version from PyPI: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def check_for_newer_minor_version() -> None:
|
||||
"""Check if a newer minor version is available on PyPI and notify the user.
|
||||
|
||||
This function compares the current version with the latest version on PyPI.
|
||||
If a newer minor version is available, it prints an informational message
|
||||
suggesting the user upgrade.
|
||||
"""
|
||||
latest_version = get_latest_version_from_pypi()
|
||||
|
||||
if not latest_version:
|
||||
return
|
||||
|
||||
try:
|
||||
current_parsed = version.parse(__version__)
|
||||
latest_parsed = version.parse(latest_version)
|
||||
|
||||
# Check if there's a newer minor version available
|
||||
# We only notify for minor version updates, not patch updates
|
||||
if latest_parsed > current_parsed: # < > == operators can be directly applied on version objects
|
||||
logger.warning(f"A newer version({latest_version}) of Codeflash is available, please update soon!")
|
||||
|
||||
except version.InvalidVersion as e:
|
||||
logger.debug(f"Invalid version format: {e}")
|
||||
return
|
||||
|
|
@ -61,13 +61,14 @@ def get_code_optimization_context(
|
|||
)
|
||||
|
||||
# Extract code context for optimization
|
||||
final_read_writable_code = extract_code_string_context_from_files(
|
||||
final_read_writable_code = extract_code_markdown_context_from_files(
|
||||
helpers_of_fto_dict,
|
||||
{},
|
||||
project_root_path,
|
||||
remove_docstrings=False,
|
||||
code_context_type=CodeContextType.READ_WRITABLE,
|
||||
).code
|
||||
)
|
||||
|
||||
read_only_code_markdown = extract_code_markdown_context_from_files(
|
||||
helpers_of_fto_dict,
|
||||
helpers_of_helpers_dict,
|
||||
|
|
@ -84,14 +85,14 @@ def get_code_optimization_context(
|
|||
)
|
||||
|
||||
# Handle token limits
|
||||
final_read_writable_tokens = encoded_tokens_len(final_read_writable_code)
|
||||
final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.markdown)
|
||||
if final_read_writable_tokens > optim_token_limit:
|
||||
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
|
||||
|
||||
# Setup preexisting objects for code replacer
|
||||
preexisting_objects = set(
|
||||
chain(
|
||||
find_preexisting_objects(final_read_writable_code),
|
||||
*(find_preexisting_objects(codestring.code) for codestring in final_read_writable_code.code_strings),
|
||||
*(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings),
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,16 +3,15 @@ from __future__ import annotations
|
|||
import ast
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
|
@ -530,10 +529,15 @@ def revert_unused_helper_functions(
|
|||
helper_names = [helper.qualified_name for helper in helpers_in_file]
|
||||
reverted_code = replace_function_definitions_in_module(
|
||||
function_names=helper_names,
|
||||
optimized_code=original_code, # Use original code as the "optimized" code to revert
|
||||
optimized_code=CodeStringsMarkdown(
|
||||
code_strings=[
|
||||
CodeString(code=original_code, file_path=Path(file_path).relative_to(project_root))
|
||||
]
|
||||
), # Use original code as the "optimized" code to revert
|
||||
module_abspath=file_path,
|
||||
preexisting_objects=set(), # Empty set since we're reverting
|
||||
project_root_path=project_root,
|
||||
should_add_global_assignments=False, # since we revert helpers functions after applying the optimization, we know that the file already has global assignments added, otherwise they would be added twice.
|
||||
)
|
||||
|
||||
if reverted_code:
|
||||
|
|
@ -608,8 +612,38 @@ def _analyze_imports_in_optimized_code(
|
|||
return dict(imported_names_map)
|
||||
|
||||
|
||||
def find_target_node(
|
||||
root: ast.AST, function_to_optimize: FunctionToOptimize
|
||||
) -> Optional[ast.FunctionDef | ast.AsyncFunctionDef]:
|
||||
parents = function_to_optimize.parents
|
||||
node = root
|
||||
for parent in parents:
|
||||
# Fast loop: directly look for the matching ClassDef in node.body
|
||||
body = getattr(node, "body", None)
|
||||
if not body:
|
||||
return None
|
||||
for child in body:
|
||||
if isinstance(child, ast.ClassDef) and child.name == parent.name:
|
||||
node = child
|
||||
break
|
||||
else:
|
||||
return None
|
||||
|
||||
# Now node is either the root or the target parent class; look for function
|
||||
body = getattr(node, "body", None)
|
||||
if not body:
|
||||
return None
|
||||
target_name = function_to_optimize.function_name
|
||||
for child in body:
|
||||
if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) and child.name == target_name:
|
||||
return child
|
||||
return None
|
||||
|
||||
|
||||
def detect_unused_helper_functions(
|
||||
function_to_optimize: FunctionToOptimize, code_context: CodeOptimizationContext, optimized_code: str
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
code_context: CodeOptimizationContext,
|
||||
optimized_code: str | CodeStringsMarkdown,
|
||||
) -> list[FunctionSource]:
|
||||
"""Detect helper functions that are no longer called by the optimized entrypoint function.
|
||||
|
||||
|
|
@ -622,16 +656,20 @@ def detect_unused_helper_functions(
|
|||
List of FunctionSource objects representing unused helper functions
|
||||
|
||||
"""
|
||||
if isinstance(optimized_code, CodeStringsMarkdown) and len(optimized_code.code_strings) > 0:
|
||||
return list(
|
||||
chain.from_iterable(
|
||||
detect_unused_helper_functions(function_to_optimize, code_context, code.code)
|
||||
for code in optimized_code.code_strings
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# Parse the optimized code to analyze function calls and imports
|
||||
optimized_ast = ast.parse(optimized_code)
|
||||
|
||||
# Find the optimized entrypoint function
|
||||
entrypoint_function_ast = None
|
||||
for node in ast.walk(optimized_ast):
|
||||
if isinstance(node, ast.FunctionDef) and node.name == function_to_optimize.function_name:
|
||||
entrypoint_function_ast = node
|
||||
break
|
||||
entrypoint_function_ast = find_target_node(optimized_ast, function_to_optimize)
|
||||
|
||||
if not entrypoint_function_ast:
|
||||
logger.debug(f"Could not find entrypoint function {function_to_optimize.function_name} in optimized code")
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional
|
|||
import git
|
||||
import libcst as cst
|
||||
from pydantic.dataclasses import dataclass
|
||||
from rich.tree import Tree
|
||||
|
||||
from codeflash.api.cfapi import get_blocklisted_functions, is_function_being_optimized_again
|
||||
from codeflash.cli_cmds.console import DEBUG_MODE, console, logger
|
||||
|
|
@ -26,6 +27,7 @@ from codeflash.code_utils.env_utils import get_pr_number
|
|||
from codeflash.code_utils.git_utils import get_git_diff, get_repo_owner_and_name
|
||||
from codeflash.code_utils.time_utils import humanize_runtime
|
||||
from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.telemetry.posthog_cf import ph
|
||||
|
||||
|
|
@ -37,6 +39,7 @@ if TYPE_CHECKING:
|
|||
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
from rich.text import Text
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -166,10 +169,11 @@ def get_functions_to_optimize(
|
|||
)
|
||||
functions: dict[str, list[FunctionToOptimize]]
|
||||
trace_file_path: Path | None = None
|
||||
is_lsp = is_LSP_enabled()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter(action="ignore", category=SyntaxWarning)
|
||||
if optimize_all:
|
||||
logger.info("Finding all functions in the module '%s'…", optimize_all)
|
||||
logger.info("!lsp|Finding all functions in the module '%s'…", optimize_all)
|
||||
console.rule()
|
||||
functions = get_all_files_and_functions(Path(optimize_all))
|
||||
elif replay_test:
|
||||
|
|
@ -177,12 +181,14 @@ def get_functions_to_optimize(
|
|||
replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
|
||||
)
|
||||
elif file is not None:
|
||||
logger.info("Finding all functions in the file '%s'…", file)
|
||||
logger.info("!lsp|Finding all functions in the file '%s'…", file)
|
||||
console.rule()
|
||||
functions = find_all_functions_in_file(file)
|
||||
if only_get_this_function is not None:
|
||||
split_function = only_get_this_function.split(".")
|
||||
if len(split_function) > 2:
|
||||
if is_lsp:
|
||||
return functions, 0, None
|
||||
exit_with_message(
|
||||
"Function name should be in the format 'function_name' or 'class_name.function_name'"
|
||||
)
|
||||
|
|
@ -198,6 +204,8 @@ def get_functions_to_optimize(
|
|||
):
|
||||
found_function = fn
|
||||
if found_function is None:
|
||||
if is_lsp:
|
||||
return functions, 0, None
|
||||
exit_with_message(
|
||||
f"Function {only_function_name} not found in file {file}\nor the function does not have a 'return' statement or is a property"
|
||||
)
|
||||
|
|
@ -206,12 +214,12 @@ def get_functions_to_optimize(
|
|||
logger.info("Finding all functions modified in the current git diff ...")
|
||||
console.rule()
|
||||
ph("cli-optimizing-git-diff")
|
||||
functions = get_functions_within_git_diff()
|
||||
functions = get_functions_within_git_diff(uncommitted_changes=False)
|
||||
filtered_modified_functions, functions_count = filter_functions(
|
||||
functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions
|
||||
)
|
||||
|
||||
logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
|
||||
logger.info(f"!lsp|Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
|
||||
if optimize_all:
|
||||
three_min_in_ns = int(1.8e11)
|
||||
console.rule()
|
||||
|
|
@ -222,9 +230,18 @@ def get_functions_to_optimize(
|
|||
return filtered_modified_functions, functions_count, trace_file_path
|
||||
|
||||
|
||||
def get_functions_within_git_diff() -> dict[str, list[FunctionToOptimize]]:
|
||||
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=False)
|
||||
modified_functions: dict[str, list[FunctionToOptimize]] = {}
|
||||
def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[FunctionToOptimize]]: # noqa: FBT001
|
||||
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes)
|
||||
return get_functions_within_lines(modified_lines)
|
||||
|
||||
|
||||
def get_functions_inside_a_commit(commit_hash: str) -> dict[str, list[FunctionToOptimize]]:
|
||||
modified_lines: dict[str, list[int]] = get_git_diff(only_this_commit=commit_hash)
|
||||
return get_functions_within_lines(modified_lines)
|
||||
|
||||
|
||||
def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str, list[FunctionToOptimize]]:
|
||||
functions: dict[str, list[FunctionToOptimize]] = {}
|
||||
for path_str, lines_in_file in modified_lines.items():
|
||||
path = Path(path_str)
|
||||
if not path.exists():
|
||||
|
|
@ -238,14 +255,14 @@ def get_functions_within_git_diff() -> dict[str, list[FunctionToOptimize]]:
|
|||
continue
|
||||
function_lines = FunctionVisitor(file_path=str(path))
|
||||
wrapper.visit(function_lines)
|
||||
modified_functions[str(path)] = [
|
||||
functions[str(path)] = [
|
||||
function_to_optimize
|
||||
for function_to_optimize in function_lines.functions
|
||||
if (start_line := function_to_optimize.starting_line) is not None
|
||||
and (end_line := function_to_optimize.ending_line) is not None
|
||||
and any(start_line <= line <= end_line for line in lines_in_file)
|
||||
]
|
||||
return modified_functions
|
||||
return functions
|
||||
|
||||
|
||||
def get_all_files_and_functions(module_root_path: Path) -> dict[str, list[FunctionToOptimize]]:
|
||||
|
|
@ -468,6 +485,10 @@ def was_function_previously_optimized(
|
|||
Tuple of (filtered_functions_dict, remaining_count)
|
||||
|
||||
"""
|
||||
if is_LSP_enabled():
|
||||
# was_function_previously_optimized is for the checking the optimization duplicates in the github action, no need to do this in the LSP mode
|
||||
return False
|
||||
|
||||
# Check optimization status if repository info is provided
|
||||
# already_optimized_count = 0
|
||||
try:
|
||||
|
|
@ -594,20 +615,22 @@ def filter_functions(
|
|||
|
||||
if not disable_logs:
|
||||
log_info = {
|
||||
f"{test_functions_removed_count} test function{'s' if test_functions_removed_count != 1 else ''}": test_functions_removed_count,
|
||||
f"{site_packages_removed_count} site-package function{'s' if site_packages_removed_count != 1 else ''}": site_packages_removed_count,
|
||||
f"{malformed_paths_count} non-importable file path{'s' if malformed_paths_count != 1 else ''}": malformed_paths_count,
|
||||
f"{non_modules_removed_count} function{'s' if non_modules_removed_count != 1 else ''} outside module-root": non_modules_removed_count,
|
||||
f"{ignore_paths_removed_count} file{'s' if ignore_paths_removed_count != 1 else ''} from ignored paths": ignore_paths_removed_count,
|
||||
f"{submodule_ignored_paths_count} file{'s' if submodule_ignored_paths_count != 1 else ''} from ignored submodules": submodule_ignored_paths_count,
|
||||
f"{blocklist_funcs_removed_count} function{'s' if blocklist_funcs_removed_count != 1 else ''} as previously optimized": blocklist_funcs_removed_count,
|
||||
f"{previous_checkpoint_functions_removed_count} function{'s' if previous_checkpoint_functions_removed_count != 1 else ''} skipped from checkpoint": previous_checkpoint_functions_removed_count,
|
||||
"Test functions removed": (test_functions_removed_count, "yellow"),
|
||||
"Site-package functions removed": (site_packages_removed_count, "magenta"),
|
||||
"Non-importable file paths": (malformed_paths_count, "red"),
|
||||
"Functions outside module-root": (non_modules_removed_count, "cyan"),
|
||||
"Files from ignored paths": (ignore_paths_removed_count, "blue"),
|
||||
"Files from ignored submodules": (submodule_ignored_paths_count, "bright_black"),
|
||||
"Blocklisted functions removed": (blocklist_funcs_removed_count, "bright_red"),
|
||||
"Functions skipped from checkpoint": (previous_checkpoint_functions_removed_count, "green"),
|
||||
}
|
||||
log_string = "\n".join([k for k, v in log_info.items() if v > 0])
|
||||
if log_string:
|
||||
logger.info(f"Ignoring: {log_string}")
|
||||
tree = Tree(Text("Ignored functions and files", style="bold"))
|
||||
for label, (count, color) in log_info.items():
|
||||
if count > 0:
|
||||
tree.add(Text(f"{label}: {count}", style=color))
|
||||
if len(tree.children) > 0:
|
||||
console.print(tree)
|
||||
console.rule()
|
||||
|
||||
return {Path(k): v for k, v in filtered_modified_functions.items() if v}, functions_count
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -41,7 +41,8 @@ if __name__ == "__main__":
|
|||
|
||||
try:
|
||||
exitcode = pytest.main(
|
||||
[tests_root, "-p no:logging", "--collect-only", "-m", "not skip"], plugins=[PytestCollectionPlugin()]
|
||||
[tests_root, "-p no:logging", "--collect-only", "-m", "not skip", "-p", "no:codeflash-benchmark"],
|
||||
plugins=[PytestCollectionPlugin()],
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to collect tests: {e!s}")
|
||||
|
|
|
|||
|
|
@ -1,4 +0,0 @@
|
|||
# Silence the console module to prevent stdout pollution
|
||||
from codeflash.cli_cmds.console import console
|
||||
|
||||
console.quiet = True
|
||||
|
|
@ -4,17 +4,34 @@ import contextlib
|
|||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import git
|
||||
from pygls import uris
|
||||
|
||||
from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
|
||||
from codeflash.cli_cmds.cli import process_pyproject_config
|
||||
from codeflash.cli_cmds.console import code_print
|
||||
from codeflash.code_utils.git_worktree_utils import (
|
||||
create_diff_patch_from_worktree,
|
||||
get_patches_metadata,
|
||||
overwrite_patch_metadata,
|
||||
)
|
||||
from codeflash.code_utils.shell_utils import save_api_key_to_rc
|
||||
from codeflash.discovery.functions_to_optimize import (
|
||||
filter_functions,
|
||||
get_functions_inside_a_commit,
|
||||
get_functions_within_git_diff,
|
||||
)
|
||||
from codeflash.either import is_successful
|
||||
from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol
|
||||
from codeflash.lsp.server import CodeflashLanguageServer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
||||
from lsprotocol import types
|
||||
|
||||
from codeflash.models.models import GeneratedTestsList, OptimizationSet
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -28,7 +45,65 @@ class FunctionOptimizationParams:
|
|||
functionName: str # noqa: N815
|
||||
|
||||
|
||||
server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)
|
||||
@dataclass
|
||||
class ProvideApiKeyParams:
|
||||
api_key: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidateProjectParams:
|
||||
root_path_abs: str
|
||||
config_file: Optional[str] = None
|
||||
skip_validation: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class OnPatchAppliedParams:
|
||||
patch_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptimizableFunctionsInCommitParams:
|
||||
commit_hash: str
|
||||
|
||||
|
||||
# server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)
|
||||
server = CodeflashLanguageServer("codeflash-language-server", "v1.0")
|
||||
|
||||
|
||||
@server.feature("getOptimizableFunctionsInCurrentDiff")
|
||||
def get_functions_in_current_git_diff(
|
||||
server: CodeflashLanguageServer, _params: OptimizableFunctionsParams
|
||||
) -> dict[str, str | dict[str, list[str]]]:
|
||||
functions = get_functions_within_git_diff(uncommitted_changes=True)
|
||||
file_to_qualified_names = _group_functions_by_file(server, functions)
|
||||
return {"functions": file_to_qualified_names, "status": "success"}
|
||||
|
||||
|
||||
@server.feature("getOptimizableFunctionsInCommit")
|
||||
def get_functions_in_commit(
|
||||
server: CodeflashLanguageServer, params: OptimizableFunctionsInCommitParams
|
||||
) -> dict[str, str | dict[str, list[str]]]:
|
||||
functions = get_functions_inside_a_commit(params.commit_hash)
|
||||
file_to_qualified_names = _group_functions_by_file(server, functions)
|
||||
return {"functions": file_to_qualified_names, "status": "success"}
|
||||
|
||||
|
||||
def _group_functions_by_file(
|
||||
server: CodeflashLanguageServer, functions: dict[str, list[FunctionToOptimize]]
|
||||
) -> dict[str, list[str]]:
|
||||
file_to_funcs_to_optimize, _ = filter_functions(
|
||||
modified_functions=functions,
|
||||
tests_root=server.optimizer.test_cfg.tests_root,
|
||||
ignore_paths=[],
|
||||
project_root=server.optimizer.args.project_root,
|
||||
module_root=server.optimizer.args.module_root,
|
||||
previous_checkpoint_functions={},
|
||||
)
|
||||
file_to_qualified_names: dict[str, list[str]] = {
|
||||
str(path): [f.qualified_name for f in funcs] for path, funcs in file_to_funcs_to_optimize.items()
|
||||
}
|
||||
return file_to_qualified_names
|
||||
|
||||
|
||||
@server.feature("getOptimizableFunctions")
|
||||
|
|
@ -37,45 +112,24 @@ def get_optimizable_functions(
|
|||
) -> dict[str, list[str]]:
|
||||
file_path = Path(uris.to_fs_path(params.textDocument.uri))
|
||||
server.show_message_log(f"Getting optimizable functions for: {file_path}", "Info")
|
||||
if not server.optimizer:
|
||||
return {"status": "error", "message": "optimizer not initialized"}
|
||||
|
||||
# Save original args to restore later
|
||||
original_file = getattr(server.optimizer.args, "file", None)
|
||||
original_function = getattr(server.optimizer.args, "function", None)
|
||||
original_checkpoint = getattr(server.optimizer.args, "previous_checkpoint_functions", None)
|
||||
server.optimizer.args.file = file_path
|
||||
server.optimizer.args.function = None # Always get ALL functions, not just one
|
||||
server.optimizer.args.previous_checkpoint_functions = False
|
||||
|
||||
server.show_message_log(f"Original args - file: {original_file}, function: {original_function}", "Info")
|
||||
server.show_message_log(f"Calling get_optimizable_functions for {server.optimizer.args.file}...", "Info")
|
||||
optimizable_funcs, _, _ = server.optimizer.get_optimizable_functions()
|
||||
|
||||
try:
|
||||
# Set temporary args for this request only
|
||||
server.optimizer.args.file = file_path
|
||||
server.optimizer.args.function = None # Always get ALL functions, not just one
|
||||
server.optimizer.args.previous_checkpoint_functions = False
|
||||
path_to_qualified_names = {}
|
||||
for functions in optimizable_funcs.values():
|
||||
path_to_qualified_names[file_path] = [func.qualified_name for func in functions]
|
||||
|
||||
server.show_message_log("Calling get_optimizable_functions...", "Info")
|
||||
optimizable_funcs, _, _ = server.optimizer.get_optimizable_functions()
|
||||
|
||||
path_to_qualified_names = {}
|
||||
for path, functions in optimizable_funcs.items():
|
||||
path_to_qualified_names[path.as_posix()] = [func.qualified_name for func in functions]
|
||||
|
||||
server.show_message_log(
|
||||
f"Found {len(path_to_qualified_names)} files with functions: {path_to_qualified_names}", "Info"
|
||||
)
|
||||
return path_to_qualified_names
|
||||
finally:
|
||||
# Restore original args to prevent state corruption
|
||||
if original_file is not None:
|
||||
server.optimizer.args.file = original_file
|
||||
if original_function is not None:
|
||||
server.optimizer.args.function = original_function
|
||||
else:
|
||||
server.optimizer.args.function = None
|
||||
if original_checkpoint is not None:
|
||||
server.optimizer.args.previous_checkpoint_functions = original_checkpoint
|
||||
|
||||
server.show_message_log(
|
||||
f"Restored args - file: {server.optimizer.args.file}, function: {server.optimizer.args.function}", "Info"
|
||||
)
|
||||
server.show_message_log(
|
||||
f"Found {len(path_to_qualified_names)} files with functions: {path_to_qualified_names}", "Info"
|
||||
)
|
||||
return path_to_qualified_names
|
||||
|
||||
|
||||
@server.feature("initializeFunctionOptimization")
|
||||
|
|
@ -85,18 +139,28 @@ def initialize_function_optimization(
|
|||
file_path = Path(uris.to_fs_path(params.textDocument.uri))
|
||||
server.show_message_log(f"Initializing optimization for function: {params.functionName} in {file_path}", "Info")
|
||||
|
||||
# IMPORTANT: Store the specific function for optimization, but don't corrupt global state
|
||||
if server.optimizer is None:
|
||||
_initialize_optimizer_if_api_key_is_valid(server)
|
||||
|
||||
server.optimizer.worktree_mode()
|
||||
|
||||
original_args, _ = server.optimizer.original_args_and_test_cfg
|
||||
|
||||
server.optimizer.args.function = params.functionName
|
||||
server.optimizer.args.file = file_path
|
||||
original_relative_file_path = file_path.relative_to(original_args.project_root)
|
||||
server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
|
||||
server.optimizer.args.previous_checkpoint_functions = False
|
||||
|
||||
server.show_message_log(
|
||||
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
|
||||
)
|
||||
|
||||
optimizable_funcs, _, _ = server.optimizer.get_optimizable_functions()
|
||||
if not optimizable_funcs:
|
||||
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
|
||||
|
||||
if count == 0:
|
||||
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
|
||||
return {"functionName": params.functionName, "status": "not found", "args": None}
|
||||
server.cleanup_the_optimizer()
|
||||
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
|
||||
|
||||
fto = optimizable_funcs.popitem()[1][0]
|
||||
server.optimizer.current_function_being_optimized = fto
|
||||
|
|
@ -104,191 +168,311 @@ def initialize_function_optimization(
|
|||
return {"functionName": params.functionName, "status": "success"}
|
||||
|
||||
|
||||
@server.feature("discoverFunctionTests")
|
||||
def discover_function_tests(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]:
|
||||
fto = server.optimizer.current_function_being_optimized
|
||||
optimizable_funcs = {fto.file_path: [fto]}
|
||||
def _find_pyproject_toml(workspace_path: str) -> Path | None:
|
||||
workspace_path_obj = Path(workspace_path)
|
||||
max_depth = 2
|
||||
base_depth = len(workspace_path_obj.parts)
|
||||
|
||||
devnull_writer = open(os.devnull, "w") # noqa
|
||||
with contextlib.redirect_stdout(devnull_writer):
|
||||
function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
|
||||
for root, dirs, files in os.walk(workspace_path_obj):
|
||||
depth = len(Path(root).parts) - base_depth
|
||||
if depth > max_depth:
|
||||
# stop going deeper into this branch
|
||||
dirs.clear()
|
||||
continue
|
||||
|
||||
server.optimizer.discovered_tests = function_to_tests
|
||||
|
||||
return {"functionName": params.functionName, "status": "success", "discovered_tests": num_discovered_tests}
|
||||
if "pyproject.toml" in files:
|
||||
file_path = Path(root) / "pyproject.toml"
|
||||
with file_path.open("r", encoding="utf-8", errors="ignore") as f:
|
||||
for line in f:
|
||||
if line.strip() == "[tool.codeflash]":
|
||||
return file_path.resolve()
|
||||
return None
|
||||
|
||||
|
||||
@server.feature("prepareOptimization")
|
||||
def prepare_optimization(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]:
|
||||
current_function = server.optimizer.current_function_being_optimized
|
||||
# should be called the first thing to initialize and validate the project
|
||||
@server.feature("initProject")
|
||||
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]:
|
||||
from codeflash.cli_cmds.cmd_init import is_valid_pyproject_toml
|
||||
|
||||
module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)
|
||||
validated_original_code, original_module_ast = module_prep_result
|
||||
pyproject_toml_path: Path | None = getattr(params, "config_file", None)
|
||||
|
||||
function_optimizer = server.optimizer.create_function_optimizer(
|
||||
current_function,
|
||||
function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code,
|
||||
original_module_ast=original_module_ast,
|
||||
original_module_path=current_function.file_path,
|
||||
)
|
||||
if server.args is None:
|
||||
if pyproject_toml_path is not None:
|
||||
# if there is a config file provided use it
|
||||
server.prepare_optimizer_arguments(pyproject_toml_path)
|
||||
else:
|
||||
# otherwise look for it
|
||||
pyproject_toml_path = _find_pyproject_toml(params.root_path_abs)
|
||||
server.show_message_log(f"Found pyproject.toml at: {pyproject_toml_path}", "Info")
|
||||
if pyproject_toml_path:
|
||||
server.prepare_optimizer_arguments(pyproject_toml_path)
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No pyproject.toml found in workspace.",
|
||||
} # TODO: enhancec this message to say there is not tool.codeflash in pyproject.toml or smth
|
||||
|
||||
server.optimizer.current_function_optimizer = function_optimizer
|
||||
if not function_optimizer:
|
||||
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
|
||||
if getattr(params, "skip_validation", False):
|
||||
return {"status": "success", "moduleRoot": server.args.module_root, "pyprojectPath": pyproject_toml_path}
|
||||
|
||||
initialization_result = function_optimizer.can_be_optimized()
|
||||
if not is_successful(initialization_result):
|
||||
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
|
||||
server.show_message_log("Validating project...", "Info")
|
||||
config = is_valid_pyproject_toml(pyproject_toml_path)
|
||||
if config is None:
|
||||
server.show_message_log("pyproject.toml is not valid", "Error")
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "pyproject.toml is not valid", # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions,
|
||||
}
|
||||
|
||||
return {"functionName": params.functionName, "status": "success", "message": "Optimization preparation completed"}
|
||||
args = process_args(server)
|
||||
repo = git.Repo(args.module_root, search_parent_directories=True)
|
||||
if repo.bare:
|
||||
return {"status": "error", "message": "Repository is in bare state"}
|
||||
|
||||
try:
|
||||
_ = repo.head.commit
|
||||
except Exception:
|
||||
return {"status": "error", "message": "Repository has no commits (unborn HEAD)"}
|
||||
|
||||
return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path}
|
||||
|
||||
|
||||
@server.feature("generateTests")
|
||||
def generate_tests(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]:
|
||||
function_optimizer = server.optimizer.current_function_optimizer
|
||||
if not function_optimizer:
|
||||
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
|
||||
def _initialize_optimizer_if_api_key_is_valid(
|
||||
server: CodeflashLanguageServer, api_key: Optional[str] = None
|
||||
) -> dict[str, str]:
|
||||
user_id = get_user_id(api_key=api_key)
|
||||
if user_id is None:
|
||||
return {"status": "error", "message": "api key not found or invalid"}
|
||||
|
||||
initialization_result = function_optimizer.can_be_optimized()
|
||||
if not is_successful(initialization_result):
|
||||
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
|
||||
if user_id.startswith("Error: "):
|
||||
error_msg = user_id[7:]
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
||||
test_setup_result = function_optimizer.generate_and_instrument_tests(
|
||||
code_context, should_run_experiment=should_run_experiment
|
||||
)
|
||||
if not is_successful(test_setup_result):
|
||||
return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
|
||||
generated_tests_list: GeneratedTestsList
|
||||
optimizations_set: OptimizationSet
|
||||
generated_tests_list, _, concolic__test_str, optimizations_set = test_setup_result.unwrap()
|
||||
new_args = process_args(server)
|
||||
server.optimizer = Optimizer(new_args)
|
||||
return {"status": "success", "user_id": user_id}
|
||||
|
||||
generated_tests: list[str] = [
|
||||
generated_test.generated_original_test_source for generated_test in generated_tests_list.generated_tests
|
||||
]
|
||||
optimizations_dict = {
|
||||
candidate.optimization_id: {"source_code": candidate.source_code, "explanation": candidate.explanation}
|
||||
for candidate in optimizations_set.control + optimizations_set.experiment
|
||||
}
|
||||
|
||||
return {
|
||||
"functionName": params.functionName,
|
||||
"status": "success",
|
||||
"message": {"generated_tests": generated_tests, "optimizations": optimizations_dict},
|
||||
}
|
||||
def process_args(server: CodeflashLanguageServer) -> Namespace:
|
||||
if server.args_processed_before:
|
||||
return server.args
|
||||
new_args = process_pyproject_config(server.args)
|
||||
server.args = new_args
|
||||
server.args_processed_before = True
|
||||
return new_args
|
||||
|
||||
|
||||
@server.feature("apiKeyExistsAndValid")
|
||||
def check_api_key(server: CodeflashLanguageServer, _params: any) -> dict[str, str]:
|
||||
try:
|
||||
return _initialize_optimizer_if_api_key_is_valid(server)
|
||||
except Exception:
|
||||
return {"status": "error", "message": "something went wrong while validating the api key"}
|
||||
|
||||
|
||||
@server.feature("provideApiKey")
|
||||
def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams) -> dict[str, str]:
|
||||
try:
|
||||
api_key = params.api_key
|
||||
if not api_key.startswith("cf-"):
|
||||
return {"status": "error", "message": "Api key is not valid"}
|
||||
|
||||
# clear cache to ensure the new api key is used
|
||||
get_codeflash_api_key.cache_clear()
|
||||
get_user_id.cache_clear()
|
||||
|
||||
init_result = _initialize_optimizer_if_api_key_is_valid(server, api_key)
|
||||
if init_result["status"] == "error":
|
||||
return {"status": "error", "message": "Api key is not valid"}
|
||||
|
||||
user_id = init_result["user_id"]
|
||||
result = save_api_key_to_rc(api_key)
|
||||
if not is_successful(result):
|
||||
return {"status": "error", "message": result.failure()}
|
||||
return {"status": "success", "message": "Api key saved successfully", "user_id": user_id} # noqa: TRY300
|
||||
except Exception:
|
||||
return {"status": "error", "message": "something went wrong while saving the api key"}
|
||||
|
||||
|
||||
@server.feature("retrieveSuccessfulOptimizations")
|
||||
def retrieve_successful_optimizations(_server: CodeflashLanguageServer, _params: any) -> dict[str, str]:
|
||||
metadata = get_patches_metadata()
|
||||
return {"status": "success", "patches": metadata["patches"]}
|
||||
|
||||
|
||||
@server.feature("onPatchApplied")
|
||||
def on_patch_applied(_server: CodeflashLanguageServer, params: OnPatchAppliedParams) -> dict[str, str]:
|
||||
# first remove the patch from the metadata
|
||||
metadata = get_patches_metadata()
|
||||
|
||||
deleted_patch_file = None
|
||||
new_patches = []
|
||||
for patch in metadata["patches"]:
|
||||
if patch["id"] == params.patch_id:
|
||||
deleted_patch_file = patch["patch_path"]
|
||||
continue
|
||||
new_patches.append(patch)
|
||||
|
||||
# then remove the patch file
|
||||
if deleted_patch_file:
|
||||
overwrite_patch_metadata(new_patches)
|
||||
patch_path = Path(deleted_patch_file)
|
||||
patch_path.unlink(missing_ok=True)
|
||||
return {"status": "success"}
|
||||
return {"status": "error", "message": "Patch not found"}
|
||||
|
||||
|
||||
@server.feature("performFunctionOptimization")
|
||||
@server.thread()
|
||||
def perform_function_optimization( # noqa: PLR0911
|
||||
server: CodeflashLanguageServer, params: FunctionOptimizationParams
|
||||
) -> dict[str, str]:
|
||||
server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
|
||||
current_function = server.optimizer.current_function_being_optimized
|
||||
try:
|
||||
server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
|
||||
current_function = server.optimizer.current_function_being_optimized
|
||||
|
||||
if not current_function:
|
||||
server.show_message_log(f"No current function being optimized for {params.functionName}", "Error")
|
||||
return {
|
||||
"functionName": params.functionName,
|
||||
"status": "error",
|
||||
"message": "No function currently being optimized",
|
||||
}
|
||||
if not current_function:
|
||||
server.show_message_log(f"No current function being optimized for {params.functionName}", "Error")
|
||||
return {
|
||||
"functionName": params.functionName,
|
||||
"status": "error",
|
||||
"message": "No function currently being optimized",
|
||||
}
|
||||
|
||||
module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)
|
||||
module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)
|
||||
if not module_prep_result:
|
||||
return {
|
||||
"functionName": params.functionName,
|
||||
"status": "error",
|
||||
"message": "Failed to prepare module for optimization",
|
||||
}
|
||||
|
||||
validated_original_code, original_module_ast = module_prep_result
|
||||
validated_original_code, original_module_ast = module_prep_result
|
||||
|
||||
function_optimizer = server.optimizer.create_function_optimizer(
|
||||
current_function,
|
||||
function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code,
|
||||
original_module_ast=original_module_ast,
|
||||
original_module_path=current_function.file_path,
|
||||
function_to_tests=server.optimizer.discovered_tests or {},
|
||||
)
|
||||
|
||||
server.optimizer.current_function_optimizer = function_optimizer
|
||||
if not function_optimizer:
|
||||
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
|
||||
|
||||
initialization_result = function_optimizer.can_be_optimized()
|
||||
if not is_successful(initialization_result):
|
||||
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
|
||||
|
||||
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
|
||||
|
||||
test_setup_result = function_optimizer.generate_and_instrument_tests(
|
||||
code_context, should_run_experiment=should_run_experiment
|
||||
)
|
||||
if not is_successful(test_setup_result):
|
||||
return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
|
||||
(
|
||||
generated_tests,
|
||||
function_to_concolic_tests,
|
||||
concolic_test_str,
|
||||
optimizations_set,
|
||||
generated_test_paths,
|
||||
generated_perf_test_paths,
|
||||
instrumented_unittests_created_for_function,
|
||||
original_conftest_content,
|
||||
) = test_setup_result.unwrap()
|
||||
|
||||
baseline_setup_result = function_optimizer.setup_and_establish_baseline(
|
||||
code_context=code_context,
|
||||
original_helper_code=original_helper_code,
|
||||
function_to_concolic_tests=function_to_concolic_tests,
|
||||
generated_test_paths=generated_test_paths,
|
||||
generated_perf_test_paths=generated_perf_test_paths,
|
||||
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
|
||||
original_conftest_content=original_conftest_content,
|
||||
)
|
||||
|
||||
if not is_successful(baseline_setup_result):
|
||||
return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()}
|
||||
|
||||
(
|
||||
function_to_optimize_qualified_name,
|
||||
function_to_all_tests,
|
||||
original_code_baseline,
|
||||
test_functions_to_remove,
|
||||
file_path_to_helper_classes,
|
||||
) = baseline_setup_result.unwrap()
|
||||
|
||||
best_optimization = function_optimizer.find_and_process_best_optimization(
|
||||
optimizations_set=optimizations_set,
|
||||
code_context=code_context,
|
||||
original_code_baseline=original_code_baseline,
|
||||
original_helper_code=original_helper_code,
|
||||
file_path_to_helper_classes=file_path_to_helper_classes,
|
||||
function_to_optimize_qualified_name=function_to_optimize_qualified_name,
|
||||
function_to_all_tests=function_to_all_tests,
|
||||
generated_tests=generated_tests,
|
||||
test_functions_to_remove=test_functions_to_remove,
|
||||
concolic_test_str=concolic_test_str,
|
||||
)
|
||||
|
||||
if not best_optimization:
|
||||
server.show_message_log(
|
||||
f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning"
|
||||
function_optimizer = server.optimizer.create_function_optimizer(
|
||||
current_function,
|
||||
function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code,
|
||||
original_module_ast=original_module_ast,
|
||||
original_module_path=current_function.file_path,
|
||||
function_to_tests={},
|
||||
)
|
||||
|
||||
server.optimizer.current_function_optimizer = function_optimizer
|
||||
if not function_optimizer:
|
||||
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
|
||||
|
||||
initialization_result = function_optimizer.can_be_optimized()
|
||||
if not is_successful(initialization_result):
|
||||
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
|
||||
|
||||
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
|
||||
|
||||
code_print(
|
||||
code_context.read_writable_code.flat,
|
||||
file_name=current_function.file_path,
|
||||
function_name=current_function.function_name,
|
||||
)
|
||||
|
||||
optimizable_funcs = {current_function.file_path: [current_function]}
|
||||
|
||||
devnull_writer = open(os.devnull, "w") # noqa
|
||||
with contextlib.redirect_stdout(devnull_writer):
|
||||
function_to_tests, _num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
|
||||
function_optimizer.function_to_tests = function_to_tests
|
||||
|
||||
test_setup_result = function_optimizer.generate_and_instrument_tests(
|
||||
code_context, should_run_experiment=should_run_experiment
|
||||
)
|
||||
if not is_successful(test_setup_result):
|
||||
return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
|
||||
(
|
||||
generated_tests,
|
||||
function_to_concolic_tests,
|
||||
concolic_test_str,
|
||||
optimizations_set,
|
||||
generated_test_paths,
|
||||
generated_perf_test_paths,
|
||||
instrumented_unittests_created_for_function,
|
||||
original_conftest_content,
|
||||
) = test_setup_result.unwrap()
|
||||
|
||||
baseline_setup_result = function_optimizer.setup_and_establish_baseline(
|
||||
code_context=code_context,
|
||||
original_helper_code=original_helper_code,
|
||||
function_to_concolic_tests=function_to_concolic_tests,
|
||||
generated_test_paths=generated_test_paths,
|
||||
generated_perf_test_paths=generated_perf_test_paths,
|
||||
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
|
||||
original_conftest_content=original_conftest_content,
|
||||
)
|
||||
|
||||
if not is_successful(baseline_setup_result):
|
||||
return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()}
|
||||
|
||||
(
|
||||
function_to_optimize_qualified_name,
|
||||
function_to_all_tests,
|
||||
original_code_baseline,
|
||||
test_functions_to_remove,
|
||||
file_path_to_helper_classes,
|
||||
) = baseline_setup_result.unwrap()
|
||||
|
||||
best_optimization = function_optimizer.find_and_process_best_optimization(
|
||||
optimizations_set=optimizations_set,
|
||||
code_context=code_context,
|
||||
original_code_baseline=original_code_baseline,
|
||||
original_helper_code=original_helper_code,
|
||||
file_path_to_helper_classes=file_path_to_helper_classes,
|
||||
function_to_optimize_qualified_name=function_to_optimize_qualified_name,
|
||||
function_to_all_tests=function_to_all_tests,
|
||||
generated_tests=generated_tests,
|
||||
test_functions_to_remove=test_functions_to_remove,
|
||||
concolic_test_str=concolic_test_str,
|
||||
)
|
||||
|
||||
if not best_optimization:
|
||||
server.show_message_log(
|
||||
f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning"
|
||||
)
|
||||
return {
|
||||
"functionName": params.functionName,
|
||||
"status": "error",
|
||||
"message": f"No best optimizations found for function {function_to_optimize_qualified_name}",
|
||||
}
|
||||
|
||||
# generate a patch for the optimization
|
||||
relative_file_paths = [code_string.file_path for code_string in code_context.read_writable_code.code_strings]
|
||||
|
||||
speedup = original_code_baseline.runtime / best_optimization.runtime
|
||||
|
||||
# get the original file path in the actual project (not in the worktree)
|
||||
original_args, _ = server.optimizer.original_args_and_test_cfg
|
||||
relative_file_path = current_function.file_path.relative_to(server.optimizer.current_worktree)
|
||||
original_file_path = Path(original_args.project_root / relative_file_path).resolve()
|
||||
|
||||
metadata = create_diff_patch_from_worktree(
|
||||
server.optimizer.current_worktree,
|
||||
relative_file_paths,
|
||||
metadata_input={
|
||||
"fto_name": function_to_optimize_qualified_name,
|
||||
"explanation": best_optimization.explanation_v2,
|
||||
"file_path": str(original_file_path),
|
||||
"speedup": speedup,
|
||||
},
|
||||
)
|
||||
|
||||
server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")
|
||||
|
||||
return {
|
||||
"functionName": params.functionName,
|
||||
"status": "error",
|
||||
"message": f"No best optimizations found for function {function_to_optimize_qualified_name}",
|
||||
"status": "success",
|
||||
"message": "Optimization completed successfully",
|
||||
"extra": f"Speedup: {speedup:.2f}x faster",
|
||||
"patch_file": metadata["patch_path"],
|
||||
"patch_id": metadata["id"],
|
||||
"explanation": best_optimization.explanation_v2,
|
||||
}
|
||||
|
||||
optimized_source = best_optimization.candidate.source_code
|
||||
speedup = original_code_baseline.runtime / best_optimization.runtime
|
||||
|
||||
server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")
|
||||
|
||||
# CRITICAL: Clear the function filter after optimization to prevent state corruption
|
||||
server.optimizer.args.function = None
|
||||
server.show_message_log("Cleared function filter to prevent state corruption", "Info")
|
||||
|
||||
return {
|
||||
"functionName": params.functionName,
|
||||
"status": "success",
|
||||
"message": "Optimization completed successfully",
|
||||
"extra": f"Speedup: {speedup:.2f}x faster",
|
||||
"optimization": optimized_source,
|
||||
}
|
||||
finally:
|
||||
server.cleanup_the_optimizer()
|
||||
|
|
|
|||
61
codeflash/lsp/helpers.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
import os
|
||||
import re
|
||||
from functools import lru_cache
|
||||
|
||||
from rich.tree import Tree
|
||||
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
_double_quote_pat = re.compile(r'"(.*?)"')
|
||||
_single_quote_pat = re.compile(r"'(.*?)'")
|
||||
worktree_path_regex = re.compile(r'\/[^"]*worktrees\/[^"]\S*')
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def is_LSP_enabled() -> bool:
|
||||
return os.getenv("CODEFLASH_LSP", default="false").lower() == "true"
|
||||
|
||||
|
||||
def tree_to_markdown(tree: Tree, level: int = 0) -> str:
|
||||
"""Convert a rich Tree into a Markdown bullet list."""
|
||||
indent = " " * level
|
||||
if level == 0:
|
||||
lines: list[str] = [f"{indent}### {tree.label}"]
|
||||
else:
|
||||
lines: list[str] = [f"{indent}- {tree.label}"]
|
||||
for child in tree.children:
|
||||
lines.extend(tree_to_markdown(child, level + 1).splitlines())
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def report_to_markdown_table(report: dict[TestType, dict[str, int]], title: str) -> str:
|
||||
lines = ["| Test Type | Passed ✅ | Failed ❌ |", "|-----------|--------|--------|"]
|
||||
for test_type in TestType:
|
||||
if test_type is TestType.INIT_STATE_TEST:
|
||||
continue
|
||||
passed = report[test_type]["passed"]
|
||||
failed = report[test_type]["failed"]
|
||||
if passed == 0 and failed == 0:
|
||||
continue
|
||||
lines.append(f"| {test_type.to_name()} | {passed} | {failed} |")
|
||||
table = "\n".join(lines)
|
||||
if title:
|
||||
return f"### {title}\n{table}"
|
||||
return table
|
||||
|
||||
|
||||
def simplify_worktree_paths(msg: str, highlight: bool = True) -> str: # noqa: FBT001, FBT002
|
||||
path_in_msg = worktree_path_regex.search(msg)
|
||||
if path_in_msg:
|
||||
last_part_of_path = path_in_msg.group(0).split("/")[-1]
|
||||
if highlight:
|
||||
last_part_of_path = f"`{last_part_of_path}`"
|
||||
return msg.replace(path_in_msg.group(0), last_part_of_path)
|
||||
return msg
|
||||
|
||||
|
||||
def replace_quotes_with_backticks(text: str) -> str:
|
||||
# double-quoted strings
|
||||
text = _double_quote_pat.sub(r"`\1`", text)
|
||||
# single-quoted strings
|
||||
return _single_quote_pat.sub(r"`\1`", text)
|
||||
143
codeflash/lsp/lsp_logger.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.lsp.lsp_message import LspTextMessage, message_delimiter
|
||||
|
||||
root_logger = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LspMessageTags:
|
||||
# always set default values for message tags
|
||||
not_lsp: bool = False # !lsp (prevent the message from being sent to the LSP)
|
||||
lsp: bool = False # lsp (lsp only)
|
||||
force_lsp: bool = False # force_lsp (you can use this to force a message to be sent to the LSP even if the level is not supported)
|
||||
loading: bool = False # loading (you can use this to indicate that the message is a loading message)
|
||||
highlight: bool = False # highlight (you can use this to highlight the message by wrapping it in ``)
|
||||
h1: bool = False # h1
|
||||
h2: bool = False # h2
|
||||
h3: bool = False # h3
|
||||
h4: bool = False # h4
|
||||
|
||||
|
||||
def add_highlight_tags(msg: str, tags: LspMessageTags) -> str:
|
||||
if tags.highlight:
|
||||
return "`" + msg + "`"
|
||||
return msg
|
||||
|
||||
|
||||
def add_heading_tags(msg: str, tags: LspMessageTags) -> str:
|
||||
if tags.h1:
|
||||
return "# " + msg
|
||||
if tags.h2:
|
||||
return "## " + msg
|
||||
if tags.h3:
|
||||
return "### " + msg
|
||||
if tags.h4:
|
||||
return "#### " + msg
|
||||
return msg
|
||||
|
||||
|
||||
def extract_tags(msg: str) -> tuple[LspMessageTags, str]:
|
||||
delimiter = "|"
|
||||
first_delim_idx = msg.find(delimiter)
|
||||
if first_delim_idx != -1 and msg.count(delimiter) == 1:
|
||||
tags_str = msg[:first_delim_idx]
|
||||
content = msg[first_delim_idx + 1 :]
|
||||
tags = {tag.strip() for tag in tags_str.split(",")}
|
||||
message_tags = LspMessageTags()
|
||||
# manually check and set to avoid repeated membership tests
|
||||
if "lsp" in tags:
|
||||
message_tags.lsp = True
|
||||
if "!lsp" in tags:
|
||||
message_tags.not_lsp = True
|
||||
if "force_lsp" in tags:
|
||||
message_tags.force_lsp = True
|
||||
if "loading" in tags:
|
||||
message_tags.loading = True
|
||||
if "highlight" in tags:
|
||||
message_tags.highlight = True
|
||||
if "h1" in tags:
|
||||
message_tags.h1 = True
|
||||
if "h2" in tags:
|
||||
message_tags.h2 = True
|
||||
if "h3" in tags:
|
||||
message_tags.h3 = True
|
||||
if "h4" in tags:
|
||||
message_tags.h4 = True
|
||||
return message_tags, content
|
||||
|
||||
return LspMessageTags(), msg
|
||||
|
||||
|
||||
supported_lsp_log_levels = ("info", "debug")
|
||||
|
||||
|
||||
def enhanced_log(
|
||||
msg: str | Any, # noqa: ANN401
|
||||
actual_log_fn: Callable[[str, Any, Any], None],
|
||||
level: str,
|
||||
*args: Any, # noqa: ANN401
|
||||
**kwargs: Any, # noqa: ANN401
|
||||
) -> None:
|
||||
if not isinstance(msg, str):
|
||||
actual_log_fn(msg, *args, **kwargs)
|
||||
return
|
||||
|
||||
is_lsp_json_message = msg.startswith(message_delimiter) and msg.endswith(message_delimiter)
|
||||
is_normal_text_message = not is_lsp_json_message
|
||||
|
||||
# Extract tags only from text messages
|
||||
tags, clean_msg = extract_tags(msg) if is_normal_text_message else (LspMessageTags(), msg)
|
||||
|
||||
lsp_enabled = is_LSP_enabled()
|
||||
unsupported_level = level not in supported_lsp_log_levels
|
||||
|
||||
# ---- Normal logging path ----
|
||||
if not tags.lsp:
|
||||
if not lsp_enabled: # LSP disabled
|
||||
actual_log_fn(clean_msg, *args, **kwargs)
|
||||
return
|
||||
if tags.not_lsp: # explicitly marked as not for LSP
|
||||
actual_log_fn(clean_msg, *args, **kwargs)
|
||||
return
|
||||
if unsupported_level and not tags.force_lsp: # unsupported level
|
||||
actual_log_fn(clean_msg, *args, **kwargs)
|
||||
return
|
||||
|
||||
# ---- LSP logging path ----
|
||||
if is_normal_text_message:
|
||||
clean_msg = add_heading_tags(clean_msg, tags)
|
||||
clean_msg = add_highlight_tags(clean_msg, tags)
|
||||
clean_msg = LspTextMessage(text=clean_msg, takes_time=tags.loading).serialize()
|
||||
|
||||
actual_log_fn(clean_msg, *args, **kwargs)
|
||||
|
||||
|
||||
# Configure logging to stderr for VS Code output channel
|
||||
def setup_logging() -> logging.Logger:
|
||||
global root_logger # noqa: PLW0603
|
||||
if root_logger:
|
||||
return root_logger
|
||||
# Clear any existing handlers to prevent conflicts
|
||||
logger = logging.getLogger()
|
||||
logger.handlers.clear()
|
||||
|
||||
# Set up stderr handler for VS Code output channel
|
||||
handler = logging.StreamHandler(sys.stderr)
|
||||
handler.setLevel(logging.DEBUG)
|
||||
|
||||
# Configure root logger
|
||||
logger.addHandler(handler)
|
||||
|
||||
# Also configure the pygls logger specifically
|
||||
pygls_logger = logging.getLogger("pygls")
|
||||
pygls_logger.setLevel(logging.INFO)
|
||||
|
||||
root_logger = logger
|
||||
return logger
|
||||
96
codeflash/lsp/lsp_message.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from codeflash.lsp.helpers import replace_quotes_with_backticks, simplify_worktree_paths
|
||||
|
||||
json_primitive_types = (str, float, int, bool)
|
||||
max_code_lines_before_collapse = 45
|
||||
|
||||
# \u241F is the message delimiter becuase it can be more than one message sent over the same message, so we need something to separate each message
|
||||
message_delimiter = "\u241f"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LspMessage:
|
||||
# to show a loading indicator if the operation is taking time like generating candidates or tests
|
||||
takes_time: bool = False
|
||||
|
||||
def _loop_through(self, obj: Any) -> Any: # noqa: ANN401
|
||||
if isinstance(obj, list):
|
||||
return [self._loop_through(i) for i in obj]
|
||||
if isinstance(obj, dict):
|
||||
return {k: self._loop_through(v) for k, v in obj.items()}
|
||||
if isinstance(obj, json_primitive_types) or obj is None:
|
||||
return obj
|
||||
if isinstance(obj, Path):
|
||||
return obj.as_posix()
|
||||
return str(obj)
|
||||
|
||||
def type(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def serialize(self) -> str:
|
||||
data = self._loop_through(asdict(self))
|
||||
ordered = {"type": self.type(), **data}
|
||||
return message_delimiter + json.dumps(ordered) + message_delimiter
|
||||
|
||||
|
||||
@dataclass
|
||||
class LspTextMessage(LspMessage):
|
||||
text: str = ""
|
||||
|
||||
def type(self) -> str:
|
||||
return "text"
|
||||
|
||||
def serialize(self) -> str:
|
||||
self.text = simplify_worktree_paths(self.text)
|
||||
self.text = replace_quotes_with_backticks(self.text)
|
||||
return super().serialize()
|
||||
|
||||
|
||||
# TODO: use it instead of the lspcodemessage to display multiple files in the same message
|
||||
class LspMultiCodeMessage(LspMessage):
|
||||
files: list[LspCodeMessage]
|
||||
|
||||
def type(self) -> str:
|
||||
return "code"
|
||||
|
||||
def serialize(self) -> str:
|
||||
return super().serialize()
|
||||
|
||||
|
||||
@dataclass
|
||||
class LspCodeMessage(LspMessage):
|
||||
code: str = ""
|
||||
file_name: Optional[Path] = None
|
||||
function_name: Optional[str] = None
|
||||
collapsed: bool = False
|
||||
lines_count: Optional[int] = None
|
||||
|
||||
def type(self) -> str:
|
||||
return "code"
|
||||
|
||||
def serialize(self) -> str:
|
||||
code_lines_length = len(self.code.split("\n"))
|
||||
self.lines_count = code_lines_length
|
||||
if code_lines_length > max_code_lines_before_collapse:
|
||||
self.collapsed = True
|
||||
self.file_name = simplify_worktree_paths(str(self.file_name), highlight=False)
|
||||
return super().serialize()
|
||||
|
||||
|
||||
@dataclass
|
||||
class LspMarkdownMessage(LspMessage):
|
||||
markdown: str = ""
|
||||
|
||||
def type(self) -> str:
|
||||
return "markdown"
|
||||
|
||||
def serialize(self) -> str:
|
||||
self.markdown = simplify_worktree_paths(self.markdown)
|
||||
self.markdown = replace_quotes_with_backticks(self.markdown)
|
||||
return super().serialize()
|
||||
|
|
@ -1,60 +1,37 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from lsprotocol.types import INITIALIZE, LogMessageParams, MessageType
|
||||
from pygls import uris
|
||||
from pygls.protocol import LanguageServerProtocol, lsp_method
|
||||
from lsprotocol.types import LogMessageParams, MessageType
|
||||
from pygls.protocol import LanguageServerProtocol
|
||||
from pygls.server import LanguageServer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lsprotocol.types import InitializeParams, InitializeResult
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
||||
|
||||
class CodeflashLanguageServerProtocol(LanguageServerProtocol):
|
||||
_server: CodeflashLanguageServer
|
||||
|
||||
@lsp_method(INITIALIZE)
|
||||
def lsp_initialize(self, params: InitializeParams) -> InitializeResult:
|
||||
server = self._server
|
||||
initialize_result: InitializeResult = super().lsp_initialize(params)
|
||||
|
||||
workspace_uri = params.root_uri
|
||||
if workspace_uri:
|
||||
workspace_path = uris.to_fs_path(workspace_uri)
|
||||
pyproject_toml_path = self._find_pyproject_toml(workspace_path)
|
||||
if pyproject_toml_path:
|
||||
server.initialize_optimizer(pyproject_toml_path)
|
||||
server.show_message(f"Found pyproject.toml at: {pyproject_toml_path}")
|
||||
else:
|
||||
server.show_message("No pyproject.toml found in workspace.")
|
||||
else:
|
||||
server.show_message("No workspace URI provided.")
|
||||
|
||||
return initialize_result
|
||||
|
||||
def _find_pyproject_toml(self, workspace_path: str) -> Path | None:
|
||||
workspace_path_obj = Path(workspace_path)
|
||||
for file_path in workspace_path_obj.rglob("pyproject.toml"):
|
||||
return file_path.resolve()
|
||||
return None
|
||||
|
||||
|
||||
class CodeflashLanguageServer(LanguageServer):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
|
||||
super().__init__(*args, **kwargs)
|
||||
self.optimizer = None
|
||||
self.optimizer: Optimizer | None = None
|
||||
self.args_processed_before: bool = False
|
||||
self.args = None
|
||||
|
||||
def initialize_optimizer(self, config_file: Path) -> None:
|
||||
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
def prepare_optimizer_arguments(self, config_file: Path) -> None:
|
||||
from codeflash.cli_cmds.cli import parse_args
|
||||
|
||||
args = parse_args()
|
||||
args.config_file = config_file
|
||||
args.no_pr = True # LSP server should not create PRs
|
||||
args = process_pyproject_config(args)
|
||||
self.optimizer = Optimizer(args)
|
||||
args.worktree = True
|
||||
self.args = args
|
||||
# avoid initializing the optimizer during initialization, because it can cause an error if the api key is invalid
|
||||
|
||||
def show_message_log(self, message: str, message_type: str) -> None:
|
||||
"""Send a log message to the client's output channel.
|
||||
|
|
@ -70,6 +47,7 @@ class CodeflashLanguageServer(LanguageServer):
|
|||
"Warning": MessageType.Warning,
|
||||
"Error": MessageType.Error,
|
||||
"Log": MessageType.Log,
|
||||
"Debug": MessageType.Debug,
|
||||
}
|
||||
|
||||
lsp_message_type = type_mapping.get(message_type, MessageType.Info)
|
||||
|
|
@ -77,3 +55,22 @@ class CodeflashLanguageServer(LanguageServer):
|
|||
# Send log message to client (appears in output channel)
|
||||
log_params = LogMessageParams(type=lsp_message_type, message=message)
|
||||
self.lsp.notify("window/logMessage", log_params)
|
||||
|
||||
def cleanup_the_optimizer(self) -> None:
|
||||
if not self.optimizer:
|
||||
return
|
||||
try:
|
||||
self.optimizer.cleanup_temporary_paths()
|
||||
# restore args and test cfg
|
||||
if self.optimizer.original_args_and_test_cfg:
|
||||
self.optimizer.args, self.optimizer.test_cfg = self.optimizer.original_args_and_test_cfg
|
||||
self.optimizer.args.function = None
|
||||
self.optimizer.current_worktree = None
|
||||
self.optimizer.current_function_optimizer = None
|
||||
except Exception:
|
||||
self.show_message_log("Failed to cleanup optimizer", "Error")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Gracefully shutdown the server."""
|
||||
self.cleanup_the_optimizer()
|
||||
super().shutdown()
|
||||
|
|
|
|||
|
|
@ -7,37 +7,13 @@ This script is run by the VS Code extension and is not intended to be
|
|||
executed directly by users.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from codeflash.lsp.beta import server
|
||||
|
||||
|
||||
# Configure logging to stderr for VS Code output channel
|
||||
def setup_logging() -> logging.Logger:
|
||||
# Clear any existing handlers to prevent conflicts
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.handlers.clear()
|
||||
|
||||
# Set up stderr handler for VS Code output channel with [LSP-Server] prefix
|
||||
handler = logging.StreamHandler(sys.stderr)
|
||||
handler.setFormatter(logging.Formatter("[LSP-Server] %(asctime)s [%(levelname)s]: %(message)s"))
|
||||
|
||||
# Configure root logger
|
||||
root_logger.addHandler(handler)
|
||||
root_logger.setLevel(logging.INFO)
|
||||
|
||||
# Also configure the pygls logger specifically
|
||||
pygls_logger = logging.getLogger("pygls")
|
||||
pygls_logger.setLevel(logging.INFO)
|
||||
|
||||
return root_logger
|
||||
|
||||
from codeflash.lsp.lsp_logger import setup_logging
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Set up logging
|
||||
log = setup_logging()
|
||||
log.info("Starting Codeflash Language Server...")
|
||||
root_logger = setup_logging()
|
||||
root_logger.info("Starting Codeflash Language Server...")
|
||||
|
||||
# Start the language server
|
||||
server.start_io()
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO, ask_run_end_to_end_test
|
|||
from codeflash.cli_cmds.console import paneled_text
|
||||
from codeflash.code_utils.checkpoint import ask_should_use_checkpoint_get_functions
|
||||
from codeflash.code_utils.config_parser import parse_config_file
|
||||
from codeflash.code_utils.version_check import check_for_newer_minor_version
|
||||
from codeflash.telemetry import posthog_cf
|
||||
from codeflash.telemetry.sentry import init_sentry
|
||||
|
||||
|
|
@ -21,12 +22,15 @@ def main() -> None:
|
|||
CODEFLASH_LOGO, panel_args={"title": "https://codeflash.ai", "expand": False}, text_args={"style": "bold gold3"}
|
||||
)
|
||||
args = parse_args()
|
||||
|
||||
# Check for newer version for all commands
|
||||
check_for_newer_minor_version()
|
||||
|
||||
if args.command:
|
||||
disable_telemetry = False
|
||||
if args.config_file and Path.exists(args.config_file):
|
||||
pyproject_config, _ = parse_config_file(args.config_file)
|
||||
disable_telemetry = pyproject_config.get("disable_telemetry", False)
|
||||
else:
|
||||
disable_telemetry = False
|
||||
init_sentry(not disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(not disable_telemetry)
|
||||
args.func()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,10 @@ from typing import TYPE_CHECKING
|
|||
|
||||
from rich.tree import Tree
|
||||
|
||||
from codeflash.cli_cmds.console import DEBUG_MODE
|
||||
from codeflash.cli_cmds.console import DEBUG_MODE, lsp_log
|
||||
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table
|
||||
from codeflash.lsp.lsp_message import LspMarkdownMessage
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
|
@ -19,7 +22,7 @@ from re import Pattern
|
|||
from typing import Annotated, Optional, cast
|
||||
|
||||
from jedi.api.classes import Name
|
||||
from pydantic import AfterValidator, BaseModel, ConfigDict, Field
|
||||
from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr, ValidationError
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
|
|
@ -91,6 +94,7 @@ class FunctionSource:
|
|||
|
||||
class BestOptimization(BaseModel):
|
||||
candidate: OptimizedCandidate
|
||||
explanation_v2: Optional[str] = None
|
||||
helper_functions: list[FunctionSource]
|
||||
code_context: CodeOptimizationContext
|
||||
runtime: int
|
||||
|
|
@ -157,12 +161,51 @@ class CodeString(BaseModel):
|
|||
file_path: Optional[Path] = None
|
||||
|
||||
|
||||
def get_code_block_splitter(file_path: Path) -> str:
|
||||
return f"# file: {file_path}"
|
||||
|
||||
|
||||
markdown_pattern = re.compile(r"```python:([^\n]+)\n(.*?)\n```", re.DOTALL)
|
||||
|
||||
|
||||
class CodeStringsMarkdown(BaseModel):
|
||||
code_strings: list[CodeString] = []
|
||||
_cache: dict = PrivateAttr(default_factory=dict)
|
||||
|
||||
@property
|
||||
def flat(self) -> str:
|
||||
"""Returns the combined Python module from all code blocks.
|
||||
|
||||
Each block is prefixed by a file path comment to indicate its origin.
|
||||
This representation is syntactically valid Python code.
|
||||
|
||||
Returns:
|
||||
str: The concatenated code of all blocks with file path annotations.
|
||||
|
||||
!! Important !!:
|
||||
Avoid parsing the flat code with multiple files,
|
||||
parsing may result in unexpected behavior.
|
||||
|
||||
|
||||
"""
|
||||
if self._cache.get("flat") is not None:
|
||||
return self._cache["flat"]
|
||||
self._cache["flat"] = "\n".join(
|
||||
get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings
|
||||
)
|
||||
return self._cache["flat"]
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
"""Returns the markdown representation of the code, including the file path where possible."""
|
||||
"""Returns a Markdown-formatted string containing all code blocks.
|
||||
|
||||
Each block is enclosed in a triple-backtick code block with an optional
|
||||
file path suffix (e.g., ```python:filename.py).
|
||||
|
||||
Returns:
|
||||
str: Markdown representation of the code blocks.
|
||||
|
||||
"""
|
||||
return "\n".join(
|
||||
[
|
||||
f"```python{':' + code_string.file_path.as_posix() if code_string.file_path else ''}\n{code_string.code.strip()}\n```"
|
||||
|
|
@ -170,10 +213,48 @@ class CodeStringsMarkdown(BaseModel):
|
|||
]
|
||||
)
|
||||
|
||||
def file_to_path(self) -> dict[str, str]:
|
||||
"""Return a dictionary mapping file paths to their corresponding code blocks.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: Mapping from file path (as string) to code.
|
||||
|
||||
"""
|
||||
if self._cache.get("file_to_path") is not None:
|
||||
return self._cache["file_to_path"]
|
||||
self._cache["file_to_path"] = {
|
||||
str(code_string.file_path): code_string.code for code_string in self.code_strings
|
||||
}
|
||||
return self._cache["file_to_path"]
|
||||
|
||||
@staticmethod
|
||||
def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown:
|
||||
"""Parse a Markdown string into a CodeStringsMarkdown object.
|
||||
|
||||
Extracts code blocks and their associated file paths and constructs a new CodeStringsMarkdown instance.
|
||||
|
||||
Args:
|
||||
markdown_code (str): The Markdown-formatted string to parse.
|
||||
|
||||
Returns:
|
||||
CodeStringsMarkdown: Parsed object containing code blocks.
|
||||
|
||||
"""
|
||||
matches = markdown_pattern.findall(markdown_code)
|
||||
results = CodeStringsMarkdown()
|
||||
try:
|
||||
for file_path, code in matches:
|
||||
path = file_path.strip()
|
||||
results.code_strings.append(CodeString(code=code, file_path=Path(path)))
|
||||
return results # noqa: TRY300
|
||||
except ValidationError:
|
||||
# if any file is invalid, return an empty CodeStringsMarkdown for the entire context
|
||||
return CodeStringsMarkdown()
|
||||
|
||||
|
||||
class CodeOptimizationContext(BaseModel):
|
||||
testgen_context_code: str = ""
|
||||
read_writable_code: str = Field(min_length=1)
|
||||
read_writable_code: CodeStringsMarkdown
|
||||
read_only_context_code: str = ""
|
||||
hashing_code_context: str = ""
|
||||
hashing_code_context_hash: str = ""
|
||||
|
|
@ -272,7 +353,7 @@ class TestsInFile:
|
|||
|
||||
@dataclass(frozen=True)
|
||||
class OptimizedCandidate:
|
||||
source_code: str
|
||||
source_code: CodeStringsMarkdown
|
||||
explanation: str
|
||||
optimization_id: str
|
||||
|
||||
|
|
@ -404,27 +485,6 @@ class VerificationType(str, Enum):
|
|||
INIT_STATE_HELPER = "init_state_helper" # Correctness verification for helper class instance attributes after init
|
||||
|
||||
|
||||
class TestType(Enum):
|
||||
EXISTING_UNIT_TEST = 1
|
||||
INSPIRED_REGRESSION = 2
|
||||
GENERATED_REGRESSION = 3
|
||||
REPLAY_TEST = 4
|
||||
CONCOLIC_COVERAGE_TEST = 5
|
||||
INIT_STATE_TEST = 6
|
||||
|
||||
def to_name(self) -> str:
|
||||
if self is TestType.INIT_STATE_TEST:
|
||||
return ""
|
||||
names = {
|
||||
TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests",
|
||||
TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests",
|
||||
TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
|
||||
TestType.REPLAY_TEST: "⏪ Replay Tests",
|
||||
TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests",
|
||||
}
|
||||
return names[self]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class InvocationId:
|
||||
test_module_path: str # The fully qualified name of the test module
|
||||
|
|
@ -480,7 +540,7 @@ class FunctionTestInvocation:
|
|||
return f"{self.loop_index}:{self.id.id()}"
|
||||
|
||||
|
||||
class TestResults(BaseModel):
|
||||
class TestResults(BaseModel): # noqa: PLW1641
|
||||
# don't modify these directly, use the add method
|
||||
# also we don't support deletion of test results elements - caution is advised
|
||||
test_results: list[FunctionTestInvocation] = []
|
||||
|
|
@ -566,6 +626,13 @@ class TestResults(BaseModel):
|
|||
@staticmethod
|
||||
def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree:
|
||||
tree = Tree(title)
|
||||
|
||||
if is_LSP_enabled():
|
||||
# Build markdown table
|
||||
markdown = report_to_markdown_table(report, title)
|
||||
lsp_log(LspMarkdownMessage(markdown=markdown))
|
||||
return tree
|
||||
|
||||
for test_type in TestType:
|
||||
if test_type is TestType.INIT_STATE_TEST:
|
||||
continue
|
||||
|
|
|
|||
22
codeflash/models/test_type.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class TestType(Enum):
|
||||
EXISTING_UNIT_TEST = 1
|
||||
INSPIRED_REGRESSION = 2
|
||||
GENERATED_REGRESSION = 3
|
||||
REPLAY_TEST = 4
|
||||
CONCOLIC_COVERAGE_TEST = 5
|
||||
INIT_STATE_TEST = 6
|
||||
|
||||
def to_name(self) -> str:
|
||||
if self is TestType.INIT_STATE_TEST:
|
||||
return ""
|
||||
names = {
|
||||
TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests",
|
||||
TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests",
|
||||
TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
|
||||
TestType.REPLAY_TEST: "⏪ Replay Tests",
|
||||
TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests",
|
||||
}
|
||||
return names[self]
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import copy
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
|
|
@ -14,6 +15,13 @@ from codeflash.cli_cmds.console import console, logger, progress_bar
|
|||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file
|
||||
from codeflash.code_utils.env_utils import get_pr_number, is_pr_draft
|
||||
from codeflash.code_utils.git_utils import check_running_in_git_repo
|
||||
from codeflash.code_utils.git_worktree_utils import (
|
||||
create_detached_worktree,
|
||||
create_diff_patch_from_worktree,
|
||||
create_worktree_snapshot_commit,
|
||||
remove_worktree,
|
||||
)
|
||||
from codeflash.either import is_successful
|
||||
from codeflash.models.models import ValidCode
|
||||
from codeflash.telemetry.posthog_cf import ph
|
||||
|
|
@ -48,6 +56,9 @@ class Optimizer:
|
|||
self.functions_checkpoint: CodeflashRunCheckpoint | None = None
|
||||
self.current_function_being_optimized: FunctionToOptimize | None = None # current only for the LSP
|
||||
self.current_function_optimizer: FunctionOptimizer | None = None
|
||||
self.current_worktree: Path | None = None
|
||||
self.original_args_and_test_cfg: tuple[Namespace, TestConfig] | None = None
|
||||
self.patch_files: list[Path] = []
|
||||
|
||||
def run_benchmarks(
|
||||
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int
|
||||
|
|
@ -181,7 +192,7 @@ class Optimizer:
|
|||
from codeflash.code_utils.code_replacer import normalize_code, normalize_node
|
||||
from codeflash.code_utils.static_analysis import analyze_imported_modules
|
||||
|
||||
logger.info(f"Examining file {original_module_path!s}…")
|
||||
logger.info(f"loading|Examining file {original_module_path!s}")
|
||||
console.rule()
|
||||
|
||||
original_module_code: str = original_module_path.read_text(encoding="utf8")
|
||||
|
|
@ -227,14 +238,15 @@ class Optimizer:
|
|||
from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
||||
|
||||
console.rule()
|
||||
start_time = time.time()
|
||||
function_to_tests, num_discovered_tests, num_discovered_replay_tests = discover_unit_tests(
|
||||
self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize
|
||||
)
|
||||
console.rule()
|
||||
logger.info(
|
||||
f"Discovered {num_discovered_tests} existing unit tests and {num_discovered_replay_tests} replay tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}"
|
||||
)
|
||||
with progress_bar("Discovering existing function tests..."):
|
||||
start_time = time.time()
|
||||
function_to_tests, num_discovered_tests, num_discovered_replay_tests = discover_unit_tests(
|
||||
self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize
|
||||
)
|
||||
console.rule()
|
||||
logger.info(
|
||||
f"Discovered {num_discovered_tests} existing unit tests and {num_discovered_replay_tests} replay tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}"
|
||||
)
|
||||
console.rule()
|
||||
ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests})
|
||||
return function_to_tests, num_discovered_tests
|
||||
|
|
@ -252,6 +264,10 @@ class Optimizer:
|
|||
if self.args.no_draft and is_pr_draft():
|
||||
logger.warning("PR is in draft mode, skipping optimization")
|
||||
return
|
||||
|
||||
if self.args.worktree:
|
||||
self.worktree_mode()
|
||||
|
||||
cleanup_paths(Optimizer.find_leftover_instrumented_test_files(self.test_cfg.tests_root))
|
||||
|
||||
function_optimizer = None
|
||||
|
|
@ -260,7 +276,6 @@ class Optimizer:
|
|||
file_to_funcs_to_optimize, num_optimizable_functions
|
||||
)
|
||||
optimizations_found: int = 0
|
||||
function_iterator_count: int = 0
|
||||
if self.args.test_framework == "pytest":
|
||||
self.test_cfg.concolic_test_root_dir = Path(
|
||||
tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
|
||||
|
|
@ -296,8 +311,8 @@ class Optimizer:
|
|||
except Exception as e:
|
||||
logger.debug(f"Could not rank functions in {original_module_path}: {e}")
|
||||
|
||||
for function_to_optimize in functions_to_optimize:
|
||||
function_iterator_count += 1
|
||||
for i, function_to_optimize in enumerate(functions_to_optimize):
|
||||
function_iterator_count = i + 1
|
||||
logger.info(
|
||||
f"Optimizing function {function_iterator_count} of {num_optimizable_functions}: "
|
||||
f"{function_to_optimize.qualified_name}"
|
||||
|
|
@ -327,15 +342,39 @@ class Optimizer:
|
|||
)
|
||||
if is_successful(best_optimization):
|
||||
optimizations_found += 1
|
||||
# create a diff patch for successful optimization
|
||||
if self.current_worktree:
|
||||
best_opt = best_optimization.unwrap()
|
||||
read_writable_code = best_opt.code_context.read_writable_code
|
||||
relative_file_paths = [
|
||||
code_string.file_path for code_string in read_writable_code.code_strings
|
||||
]
|
||||
metadata = create_diff_patch_from_worktree(
|
||||
self.current_worktree,
|
||||
relative_file_paths,
|
||||
fto_name=function_to_optimize.qualified_name,
|
||||
metadata_input={},
|
||||
)
|
||||
self.patch_files.append(metadata["patch_path"])
|
||||
if i < len(functions_to_optimize) - 1:
|
||||
create_worktree_snapshot_commit(
|
||||
self.current_worktree,
|
||||
f"Optimizing {functions_to_optimize[i + 1].qualified_name}",
|
||||
)
|
||||
else:
|
||||
logger.warning(best_optimization.failure())
|
||||
console.rule()
|
||||
continue
|
||||
finally:
|
||||
if function_optimizer is not None:
|
||||
function_optimizer.executor.shutdown(wait=True)
|
||||
function_optimizer.cleanup_generated_files()
|
||||
|
||||
ph("cli-optimize-run-finished", {"optimizations_found": optimizations_found})
|
||||
if len(self.patch_files) > 0:
|
||||
logger.info(
|
||||
f"Created {len(self.patch_files)} patch(es) ({[str(patch_path) for patch_path in self.patch_files]})"
|
||||
)
|
||||
if self.functions_checkpoint:
|
||||
self.functions_checkpoint.cleanup()
|
||||
if hasattr(self.args, "command") and self.args.command == "optimize":
|
||||
|
|
@ -381,14 +420,62 @@ class Optimizer:
|
|||
cleanup_paths([self.replay_tests_dir])
|
||||
|
||||
def cleanup_temporary_paths(self) -> None:
|
||||
if self.current_function_optimizer:
|
||||
self.current_function_optimizer.cleanup_generated_files()
|
||||
|
||||
if hasattr(get_run_tmp_file, "tmpdir"):
|
||||
get_run_tmp_file.tmpdir.cleanup()
|
||||
del get_run_tmp_file.tmpdir
|
||||
|
||||
if self.current_worktree:
|
||||
remove_worktree(self.current_worktree)
|
||||
return
|
||||
|
||||
if self.current_function_optimizer:
|
||||
self.current_function_optimizer.cleanup_generated_files()
|
||||
cleanup_paths([self.test_cfg.concolic_test_root_dir, self.replay_tests_dir])
|
||||
|
||||
def worktree_mode(self) -> None:
|
||||
if self.current_worktree:
|
||||
return
|
||||
|
||||
if check_running_in_git_repo(self.args.module_root):
|
||||
worktree_dir = create_detached_worktree(self.args.module_root)
|
||||
if worktree_dir is None:
|
||||
logger.warning("Failed to create worktree. Skipping optimization.")
|
||||
return
|
||||
self.current_worktree = worktree_dir
|
||||
self.mutate_args_for_worktree_mode(worktree_dir)
|
||||
# make sure the tests dir is created in the worktree, this can happen if the original tests dir is empty
|
||||
Path(self.args.tests_root).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def mutate_args_for_worktree_mode(self, worktree_dir: Path) -> None:
|
||||
saved_args = copy.deepcopy(self.args)
|
||||
saved_test_cfg = copy.deepcopy(self.test_cfg)
|
||||
self.original_args_and_test_cfg = (saved_args, saved_test_cfg)
|
||||
|
||||
project_root = self.args.project_root
|
||||
module_root = self.args.module_root
|
||||
relative_module_root = module_root.relative_to(project_root)
|
||||
relative_optimized_file = self.args.file.relative_to(project_root) if self.args.file else None
|
||||
relative_tests_root = self.test_cfg.tests_root.relative_to(project_root)
|
||||
relative_benchmarks_root = (
|
||||
self.args.benchmarks_root.relative_to(project_root) if self.args.benchmarks_root else None
|
||||
)
|
||||
|
||||
self.args.module_root = worktree_dir / relative_module_root
|
||||
self.args.project_root = worktree_dir
|
||||
self.args.test_project_root = worktree_dir
|
||||
self.args.tests_root = worktree_dir / relative_tests_root
|
||||
if relative_benchmarks_root:
|
||||
self.args.benchmarks_root = worktree_dir / relative_benchmarks_root
|
||||
|
||||
self.test_cfg.project_root_path = worktree_dir
|
||||
self.test_cfg.tests_project_rootdir = worktree_dir
|
||||
self.test_cfg.tests_root = worktree_dir / relative_tests_root
|
||||
if relative_benchmarks_root:
|
||||
self.test_cfg.benchmark_tests_root = worktree_dir / relative_benchmarks_root
|
||||
|
||||
if relative_optimized_file is not None:
|
||||
self.args.file = worktree_dir / relative_optimized_file
|
||||
|
||||
|
||||
def run_with_args(args: Namespace) -> None:
|
||||
optimizer = None
|
||||
|
|
|
|||
|
|
@ -10,12 +10,7 @@ from codeflash.api import cfapi
|
|||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.code_replacer import is_zero_diff
|
||||
from codeflash.code_utils.git_utils import (
|
||||
check_and_push_branch,
|
||||
get_current_branch,
|
||||
get_repo_owner_and_name,
|
||||
git_root_dir,
|
||||
)
|
||||
from codeflash.code_utils.git_utils import check_and_push_branch, get_current_branch, get_repo_owner_and_name
|
||||
from codeflash.code_utils.github_utils import github_pr_url
|
||||
from codeflash.code_utils.tabulate import tabulate
|
||||
from codeflash.code_utils.time_utils import format_perf, format_time
|
||||
|
|
@ -34,12 +29,16 @@ def existing_tests_source_for(
|
|||
test_cfg: TestConfig,
|
||||
original_runtimes_all: dict[InvocationId, list[int]],
|
||||
optimized_runtimes_all: dict[InvocationId, list[int]],
|
||||
) -> str:
|
||||
) -> tuple[str, str, str]:
|
||||
test_files = function_to_tests.get(function_qualified_name_with_modules_from_root)
|
||||
if not test_files:
|
||||
return ""
|
||||
output: str = ""
|
||||
rows = []
|
||||
return "", "", ""
|
||||
output_existing: str = ""
|
||||
output_concolic: str = ""
|
||||
output_replay: str = ""
|
||||
rows_existing = []
|
||||
rows_concolic = []
|
||||
rows_replay = []
|
||||
headers = ["Test File::Test Function", "Original ⏱️", "Optimized ⏱️", "Speedup"]
|
||||
tests_root = test_cfg.tests_root
|
||||
original_tests_to_runtimes: dict[Path, dict[str, int]] = {}
|
||||
|
|
@ -99,28 +98,79 @@ def existing_tests_source_for(
|
|||
* 100
|
||||
)
|
||||
if greater:
|
||||
rows.append(
|
||||
if "__replay_test_" in str(print_filename):
|
||||
rows_replay.append(
|
||||
[
|
||||
f"`{print_filename}::{qualified_name}`",
|
||||
f"{print_original_runtime}",
|
||||
f"{print_optimized_runtime}",
|
||||
f"{perf_gain}%⚠️",
|
||||
]
|
||||
)
|
||||
elif "codeflash_concolic" in str(print_filename):
|
||||
rows_concolic.append(
|
||||
[
|
||||
f"`{print_filename}::{qualified_name}`",
|
||||
f"{print_original_runtime}",
|
||||
f"{print_optimized_runtime}",
|
||||
f"{perf_gain}%⚠️",
|
||||
]
|
||||
)
|
||||
else:
|
||||
rows_existing.append(
|
||||
[
|
||||
f"`{print_filename}::{qualified_name}`",
|
||||
f"{print_original_runtime}",
|
||||
f"{print_optimized_runtime}",
|
||||
f"{perf_gain}%⚠️",
|
||||
]
|
||||
)
|
||||
elif "__replay_test_" in str(print_filename):
|
||||
rows_replay.append(
|
||||
[
|
||||
f"`{print_filename}::{qualified_name}`",
|
||||
f"{print_original_runtime}",
|
||||
f"{print_optimized_runtime}",
|
||||
f"⚠️{perf_gain}%",
|
||||
f"{perf_gain}%✅",
|
||||
]
|
||||
)
|
||||
elif "codeflash_concolic" in str(print_filename):
|
||||
rows_concolic.append(
|
||||
[
|
||||
f"`{print_filename}::{qualified_name}`",
|
||||
f"{print_original_runtime}",
|
||||
f"{print_optimized_runtime}",
|
||||
f"{perf_gain}%✅",
|
||||
]
|
||||
)
|
||||
else:
|
||||
rows.append(
|
||||
rows_existing.append(
|
||||
[
|
||||
f"`{print_filename}::{qualified_name}`",
|
||||
f"{print_original_runtime}",
|
||||
f"{print_optimized_runtime}",
|
||||
f"✅{perf_gain}%",
|
||||
f"{perf_gain}%✅",
|
||||
]
|
||||
)
|
||||
output += tabulate( # type: ignore[no-untyped-call]
|
||||
headers=headers, tabular_data=rows, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True
|
||||
output_existing += tabulate( # type: ignore[no-untyped-call]
|
||||
headers=headers, tabular_data=rows_existing, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True
|
||||
)
|
||||
output += "\n"
|
||||
return output
|
||||
output_existing += "\n"
|
||||
if len(rows_existing) == 0:
|
||||
output_existing = ""
|
||||
output_concolic += tabulate( # type: ignore[no-untyped-call]
|
||||
headers=headers, tabular_data=rows_concolic, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True
|
||||
)
|
||||
output_concolic += "\n"
|
||||
if len(rows_concolic) == 0:
|
||||
output_concolic = ""
|
||||
output_replay += tabulate( # type: ignore[no-untyped-call]
|
||||
headers=headers, tabular_data=rows_replay, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True
|
||||
)
|
||||
output_replay += "\n"
|
||||
if len(rows_replay) == 0:
|
||||
output_replay = ""
|
||||
return output_existing, output_replay, output_concolic
|
||||
|
||||
|
||||
def check_create_pr(
|
||||
|
|
@ -131,6 +181,9 @@ def check_create_pr(
|
|||
generated_original_test_source: str,
|
||||
function_trace_id: str,
|
||||
coverage_message: str,
|
||||
replay_tests: str,
|
||||
concolic_tests: str,
|
||||
root_dir: Path,
|
||||
git_remote: Optional[str] = None,
|
||||
) -> None:
|
||||
pr_number: Optional[int] = env_utils.get_pr_number()
|
||||
|
|
@ -139,9 +192,9 @@ def check_create_pr(
|
|||
if pr_number is not None:
|
||||
logger.info(f"Suggesting changes to PR #{pr_number} ...")
|
||||
owner, repo = get_repo_owner_and_name(git_repo)
|
||||
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()
|
||||
relative_path = explanation.file_path.relative_to(root_dir).as_posix()
|
||||
build_file_changes = {
|
||||
Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
|
||||
Path(p).relative_to(root_dir).as_posix(): FileDiffContent(
|
||||
oldContent=original_code[p], newContent=new_code[p]
|
||||
)
|
||||
for p in original_code
|
||||
|
|
@ -171,6 +224,8 @@ def check_create_pr(
|
|||
generated_tests=generated_original_test_source,
|
||||
trace_id=function_trace_id,
|
||||
coverage_message=coverage_message,
|
||||
replay_tests=replay_tests,
|
||||
concolic_tests=concolic_tests,
|
||||
)
|
||||
if response.ok:
|
||||
logger.info(f"Suggestions were successfully made to PR #{pr_number}")
|
||||
|
|
@ -188,10 +243,10 @@ def check_create_pr(
|
|||
if not check_and_push_branch(git_repo, git_remote, wait_for_push=True):
|
||||
logger.warning("⏭️ Branch is not pushed, skipping PR creation...")
|
||||
return
|
||||
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()
|
||||
relative_path = explanation.file_path.relative_to(root_dir).as_posix()
|
||||
base_branch = get_current_branch()
|
||||
build_file_changes = {
|
||||
Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
|
||||
Path(p).relative_to(root_dir).as_posix(): FileDiffContent(
|
||||
oldContent=original_code[p], newContent=new_code[p]
|
||||
)
|
||||
for p in original_code
|
||||
|
|
@ -218,6 +273,8 @@ def check_create_pr(
|
|||
generated_tests=generated_original_test_source,
|
||||
trace_id=function_trace_id,
|
||||
coverage_message=coverage_message,
|
||||
replay_tests=replay_tests,
|
||||
concolic_tests=concolic_tests,
|
||||
)
|
||||
if response.ok:
|
||||
pr_id = response.text
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils import env_utils
|
||||
|
|
@ -9,7 +9,7 @@ from codeflash.code_utils.config_consts import (
|
|||
MIN_IMPROVEMENT_THRESHOLD,
|
||||
MIN_TESTCASE_PASSED_THRESHOLD,
|
||||
)
|
||||
from codeflash.models.models import TestType
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.models.models import CoverageData, OptimizedCandidateResult, OriginalCodeBaseline
|
||||
|
|
@ -29,7 +29,8 @@ def speedup_critic(
|
|||
candidate_result: OptimizedCandidateResult,
|
||||
original_code_runtime: int,
|
||||
best_runtime_until_now: int | None,
|
||||
disable_gh_action_noise: Optional[bool] = None,
|
||||
*,
|
||||
disable_gh_action_noise: bool = False,
|
||||
) -> bool:
|
||||
"""Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.
|
||||
|
||||
|
|
@ -39,10 +40,8 @@ def speedup_critic(
|
|||
The noise floor is doubled when benchmarking on a (noisy) GitHub Action virtual instance, also we want to be more confident there.
|
||||
"""
|
||||
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
|
||||
if not disable_gh_action_noise:
|
||||
in_github_actions_mode = bool(env_utils.get_pr_number())
|
||||
if in_github_actions_mode:
|
||||
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode
|
||||
if not disable_gh_action_noise and env_utils.is_ci():
|
||||
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode
|
||||
|
||||
perf_gain = performance_gain(
|
||||
original_runtime_ns=original_code_runtime, optimized_runtime_ns=candidate_result.best_test_runtime
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from rich.console import Console
|
|||
from rich.table import Table
|
||||
|
||||
from codeflash.code_utils.time_utils import humanize_runtime
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.models.models import BenchmarkDetail, TestResults
|
||||
|
||||
|
||||
|
|
@ -40,7 +41,7 @@ class Explanation:
|
|||
def speedup_pct(self) -> str:
|
||||
return f"{self.speedup * 100:,.0f}%"
|
||||
|
||||
def to_console_string(self) -> str:
|
||||
def __str__(self) -> str:
|
||||
# TODO: After doing the best optimization, remove the test cases that errored on the new code, because they might be failing because of syntax errors and such.
|
||||
# TODO: Sometimes the explanation says something similar to "This is the code that was optimized", remove such parts
|
||||
original_runtime_human = humanize_runtime(self.original_runtime_ns)
|
||||
|
|
@ -85,6 +86,9 @@ class Explanation:
|
|||
console.print(table)
|
||||
benchmark_info = cast("StringIO", console.file).getvalue() + "\n" # Cast for mypy
|
||||
|
||||
test_report = self.winning_behavior_test_results.get_test_pass_fail_report_by_type()
|
||||
test_report_str = TestResults.report_to_string(test_report)
|
||||
|
||||
return (
|
||||
f"Optimized {self.function_name} in {self.file_path}\n"
|
||||
f"{self.perf_improvement_line}\n"
|
||||
|
|
@ -92,8 +96,13 @@ class Explanation:
|
|||
+ (benchmark_info if benchmark_info else "")
|
||||
+ self.raw_explanation_message
|
||||
+ " \n\n"
|
||||
+ "The new optimized code was tested for correctness. The results are listed below.\n"
|
||||
+ f"{TestResults.report_to_string(self.winning_behavior_test_results.get_test_pass_fail_report_by_type())}\n"
|
||||
+ (
|
||||
# in the lsp (extension) we display the test results before the optimization summary
|
||||
""
|
||||
if is_LSP_enabled()
|
||||
else "The new optimized code was tested for correctness. The results are listed below.\n"
|
||||
+ test_report_str
|
||||
)
|
||||
)
|
||||
|
||||
def explanation_message(self) -> str:
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from codeflash.cli_cmds.console import console
|
|||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
|
||||
from codeflash.code_utils.config_parser import parse_config_file
|
||||
from codeflash.tracing.pytest_parallelization import pytest_split
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
|
@ -86,51 +87,97 @@ def main(args: Namespace | None = None) -> ArgumentParser:
|
|||
config, found_config_path = parse_config_file(parsed_args.codeflash_config)
|
||||
project_root = project_root_from_module_root(Path(config["module_root"]), found_config_path)
|
||||
if len(unknown_args) > 0:
|
||||
args_dict = {
|
||||
"functions": parsed_args.only_functions,
|
||||
"disable": False,
|
||||
"project_root": str(project_root),
|
||||
"max_function_count": parsed_args.max_function_count,
|
||||
"timeout": parsed_args.tracer_timeout,
|
||||
"progname": unknown_args[0],
|
||||
"config": config,
|
||||
"module": parsed_args.module,
|
||||
}
|
||||
try:
|
||||
result_pickle_file_path = get_run_tmp_file("tracer_results_file.pkl")
|
||||
args_dict = {
|
||||
"result_pickle_file_path": str(result_pickle_file_path),
|
||||
"output": str(parsed_args.outfile),
|
||||
"functions": parsed_args.only_functions,
|
||||
"disable": False,
|
||||
"project_root": str(project_root),
|
||||
"max_function_count": parsed_args.max_function_count,
|
||||
"timeout": parsed_args.tracer_timeout,
|
||||
"command": " ".join(sys.argv),
|
||||
"progname": unknown_args[0],
|
||||
"config": config,
|
||||
"module": parsed_args.module,
|
||||
}
|
||||
pytest_splits = []
|
||||
test_paths = []
|
||||
replay_test_paths = []
|
||||
if parsed_args.module and unknown_args[0] == "pytest":
|
||||
pytest_splits, test_paths = pytest_split(unknown_args[1:])
|
||||
|
||||
subprocess.run(
|
||||
[
|
||||
SAFE_SYS_EXECUTABLE,
|
||||
Path(__file__).parent / "tracing" / "tracing_new_process.py",
|
||||
*sys.argv,
|
||||
json.dumps(args_dict),
|
||||
],
|
||||
cwd=Path.cwd(),
|
||||
check=False,
|
||||
)
|
||||
try:
|
||||
with result_pickle_file_path.open(mode="rb") as f:
|
||||
data = pickle.load(f)
|
||||
except Exception:
|
||||
console.print("❌ Failed to trace. Exiting...")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
result_pickle_file_path.unlink(missing_ok=True)
|
||||
if len(pytest_splits) > 1:
|
||||
processes = []
|
||||
test_paths_set = set(test_paths)
|
||||
result_pickle_file_paths = []
|
||||
for i, test_split in enumerate(pytest_splits, start=1):
|
||||
result_pickle_file_path = get_run_tmp_file(Path(f"tracer_results_file_{i}.pkl"))
|
||||
result_pickle_file_paths.append(result_pickle_file_path)
|
||||
args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
|
||||
outpath = parsed_args.outfile
|
||||
outpath = outpath.parent / f"{outpath.stem}_{i}{outpath.suffix}"
|
||||
args_dict["output"] = str(outpath)
|
||||
updated_sys_argv = []
|
||||
for elem in sys.argv:
|
||||
if elem in test_paths_set:
|
||||
updated_sys_argv.extend(test_split)
|
||||
else:
|
||||
updated_sys_argv.append(elem)
|
||||
args_dict["command"] = " ".join(updated_sys_argv)
|
||||
processes.append(
|
||||
subprocess.Popen(
|
||||
[
|
||||
SAFE_SYS_EXECUTABLE,
|
||||
Path(__file__).parent / "tracing" / "tracing_new_process.py",
|
||||
*updated_sys_argv,
|
||||
json.dumps(args_dict),
|
||||
],
|
||||
cwd=Path.cwd(),
|
||||
)
|
||||
)
|
||||
for process in processes:
|
||||
process.wait()
|
||||
for result_pickle_file_path in result_pickle_file_paths:
|
||||
try:
|
||||
with result_pickle_file_path.open(mode="rb") as f:
|
||||
data = pickle.load(f)
|
||||
replay_test_paths.append(str(data["replay_test_file_path"]))
|
||||
except Exception:
|
||||
console.print("❌ Failed to trace. Exiting...")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
result_pickle_file_path.unlink(missing_ok=True)
|
||||
else:
|
||||
result_pickle_file_path = get_run_tmp_file(Path("tracer_results_file.pkl"))
|
||||
args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
|
||||
args_dict["output"] = str(parsed_args.outfile)
|
||||
args_dict["command"] = " ".join(sys.argv)
|
||||
|
||||
replay_test_path = data["replay_test_file_path"]
|
||||
if not parsed_args.trace_only and replay_test_path is not None:
|
||||
subprocess.run(
|
||||
[
|
||||
SAFE_SYS_EXECUTABLE,
|
||||
Path(__file__).parent / "tracing" / "tracing_new_process.py",
|
||||
*sys.argv,
|
||||
json.dumps(args_dict),
|
||||
],
|
||||
cwd=Path.cwd(),
|
||||
check=False,
|
||||
)
|
||||
try:
|
||||
with result_pickle_file_path.open(mode="rb") as f:
|
||||
data = pickle.load(f)
|
||||
replay_test_paths.append(str(data["replay_test_file_path"]))
|
||||
except Exception:
|
||||
console.print("❌ Failed to trace. Exiting...")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
result_pickle_file_path.unlink(missing_ok=True)
|
||||
if not parsed_args.trace_only and replay_test_paths:
|
||||
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
|
||||
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO
|
||||
from codeflash.cli_cmds.console import paneled_text
|
||||
from codeflash.telemetry import posthog_cf
|
||||
from codeflash.telemetry.sentry import init_sentry
|
||||
|
||||
sys.argv = ["codeflash", "--replay-test", str(replay_test_path)]
|
||||
|
||||
sys.argv = ["codeflash", "--replay-test", *replay_test_paths]
|
||||
args = parse_args()
|
||||
paneled_text(
|
||||
CODEFLASH_LOGO,
|
||||
|
|
@ -150,8 +197,8 @@ def main(args: Namespace | None = None) -> ArgumentParser:
|
|||
# Delete the trace file and the replay test file if they exist
|
||||
if outfile:
|
||||
outfile.unlink(missing_ok=True)
|
||||
if replay_test_path:
|
||||
replay_test_path.unlink(missing_ok=True)
|
||||
for replay_test_path in replay_test_paths:
|
||||
Path(replay_test_path).unlink(missing_ok=True)
|
||||
|
||||
except BrokenPipeError as exc:
|
||||
# Prevent "Exception ignored" during interpreter shutdown.
|
||||
|
|
|
|||
84
codeflash/tracing/pytest_parallelization.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from random import shuffle
|
||||
|
||||
|
||||
def pytest_split(
|
||||
arguments: list[str], num_splits: int | None = None
|
||||
) -> tuple[list[list[str]] | None, list[str] | None]:
|
||||
"""Split pytest test files from a directory into N roughly equal groups for parallel execution.
|
||||
|
||||
Args:
|
||||
arguments: List of arguments passed to pytest
|
||||
test_directory: Path to directory containing test files
|
||||
num_splits: Number of groups to split tests into. If None, uses CPU count.
|
||||
|
||||
Returns:
|
||||
List of lists, where each inner list contains test file paths for one group.
|
||||
Returns single list with all tests if number of test files < CPU cores.
|
||||
|
||||
"""
|
||||
try:
|
||||
import pytest
|
||||
|
||||
parser = pytest.Parser()
|
||||
|
||||
pytest_args = parser.parse_known_args(arguments)
|
||||
test_paths = getattr(pytest_args, "file_or_dir", None)
|
||||
if not test_paths:
|
||||
return None, None
|
||||
|
||||
except ImportError:
|
||||
return None, None
|
||||
test_files = set()
|
||||
|
||||
# Find all test_*.py files recursively in the directory
|
||||
for test_path in test_paths:
|
||||
_test_path = Path(test_path)
|
||||
if not _test_path.exists():
|
||||
return None, None
|
||||
if _test_path.is_dir():
|
||||
# Find all test files matching the pattern test_*.py
|
||||
test_files.update(map(str, _test_path.rglob("test_*.py")))
|
||||
test_files.update(map(str, _test_path.rglob("*_test.py")))
|
||||
elif _test_path.is_file():
|
||||
test_files.add(str(_test_path))
|
||||
|
||||
if not test_files:
|
||||
return [[]], None
|
||||
|
||||
# Determine number of splits
|
||||
if num_splits is None:
|
||||
num_splits = os.cpu_count() or 4
|
||||
|
||||
# randomize to increase chances of all splits being balanced
|
||||
test_files = list(test_files)
|
||||
shuffle(test_files)
|
||||
|
||||
# Ensure each split has at least 4 test files
|
||||
# If we have fewer test files than 4 * num_splits, reduce num_splits
|
||||
max_possible_splits = len(test_files) // 4
|
||||
if max_possible_splits == 0:
|
||||
return test_files, test_paths
|
||||
|
||||
num_splits = min(num_splits, max_possible_splits)
|
||||
|
||||
# Calculate chunk size (round up to ensure all files are included)
|
||||
total_files = len(test_files)
|
||||
chunk_size = ceil(total_files / num_splits)
|
||||
|
||||
# Initialize result groups
|
||||
result_groups = [[] for _ in range(num_splits)]
|
||||
|
||||
# Distribute files across groups
|
||||
for i, test_file in enumerate(test_files):
|
||||
group_index = i // chunk_size
|
||||
# Ensure we don't exceed the number of groups (edge case handling)
|
||||
if group_index >= num_splits:
|
||||
group_index = num_splits - 1
|
||||
result_groups[group_index].append(test_file)
|
||||
|
||||
return result_groups, test_paths
|
||||
|
|
@ -13,6 +13,7 @@ import sys
|
|||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from importlib.util import find_spec
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar
|
||||
|
||||
|
|
@ -47,6 +48,17 @@ class FakeFrame:
|
|||
self.f_locals: dict = {}
|
||||
|
||||
|
||||
def patch_ap_scheduler() -> None:
|
||||
if find_spec("apscheduler"):
|
||||
import apscheduler.schedulers.background as bg
|
||||
import apscheduler.schedulers.blocking as bb
|
||||
from apscheduler.schedulers import base
|
||||
|
||||
bg.BackgroundScheduler.start = lambda _, *_a, **_k: None
|
||||
bb.BlockingScheduler.start = lambda _, *_a, **_k: None
|
||||
base.BaseScheduler.add_job = lambda _, *_a, **_k: None
|
||||
|
||||
|
||||
# Debug this file by simply adding print statements. This file is not meant to be debugged by the debugger.
|
||||
class Tracer:
|
||||
"""Use this class as a 'with' context manager to trace a function call.
|
||||
|
|
@ -820,6 +832,7 @@ class Tracer:
|
|||
if __name__ == "__main__":
|
||||
args_dict = json.loads(sys.argv[-1])
|
||||
sys.argv = sys.argv[1:-1]
|
||||
patch_ap_scheduler()
|
||||
if args_dict["module"]:
|
||||
import runpy
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import enum
|
|||
import math
|
||||
import re
|
||||
import types
|
||||
from collections import ChainMap, OrderedDict, deque
|
||||
from typing import Any
|
||||
|
||||
import sentry_sdk
|
||||
|
|
@ -70,7 +71,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
|
|||
# distinct type objects are created at runtime, even if the class code is exactly the same, so we can only compare the names
|
||||
if type_obj.__name__ != new_type_obj.__name__ or type_obj.__qualname__ != new_type_obj.__qualname__:
|
||||
return False
|
||||
if isinstance(orig, (list, tuple)):
|
||||
if isinstance(orig, (list, tuple, deque, ChainMap)):
|
||||
if len(orig) != len(new):
|
||||
return False
|
||||
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))
|
||||
|
|
@ -93,6 +94,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
|
|||
enum.Enum,
|
||||
type,
|
||||
range,
|
||||
OrderedDict,
|
||||
),
|
||||
):
|
||||
return orig == new
|
||||
|
|
@ -233,6 +235,27 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
|
|||
):
|
||||
return orig == new
|
||||
|
||||
if hasattr(orig, "__attrs_attrs__") and hasattr(new, "__attrs_attrs__"):
|
||||
orig_dict = {}
|
||||
new_dict = {}
|
||||
|
||||
for attr in orig.__attrs_attrs__:
|
||||
if attr.eq:
|
||||
attr_name = attr.name
|
||||
orig_dict[attr_name] = getattr(orig, attr_name, None)
|
||||
new_dict[attr_name] = getattr(new, attr_name, None)
|
||||
|
||||
if superset_obj:
|
||||
new_attrs_dict = {}
|
||||
for attr in new.__attrs_attrs__:
|
||||
if attr.eq:
|
||||
attr_name = attr.name
|
||||
new_attrs_dict[attr_name] = getattr(new, attr_name, None)
|
||||
return all(
|
||||
k in new_attrs_dict and comparator(v, new_attrs_dict[k], superset_obj) for k, v in orig_dict.items()
|
||||
)
|
||||
return comparator(orig_dict, new_dict, superset_obj)
|
||||
|
||||
# re.Pattern can be made better by DFA Minimization and then comparing
|
||||
if isinstance(
|
||||
orig, (datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone, re.Pattern)
|
||||
|
|
|
|||
|
|
@ -10,9 +10,9 @@ from typing import TYPE_CHECKING
|
|||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
|
||||
from codeflash.code_utils.concolic_utils import clean_concolic_tests
|
||||
from codeflash.code_utils.env_utils import is_LSP_enabled
|
||||
from codeflash.code_utils.static_analysis import has_typed_parameters
|
||||
from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.telemetry.posthog_cf import ph
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
|
|
|||
|
|
@ -271,7 +271,7 @@ def parse_test_xml(
|
|||
groups = match.groups()
|
||||
if len(groups[5].split(":")) > 1:
|
||||
iteration_id = groups[5].split(":")[0]
|
||||
groups = groups[:5] + (iteration_id,)
|
||||
groups = (*groups[:5], iteration_id)
|
||||
end_matches[groups] = match
|
||||
|
||||
if not begin_matches or not begin_matches:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from typing import TYPE_CHECKING
|
|||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file
|
||||
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
|
||||
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME
|
||||
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
|
||||
from codeflash.code_utils.coverage_utils import prepare_coverage_files
|
||||
from codeflash.models.models import TestFiles, TestType
|
||||
|
||||
|
|
@ -37,7 +37,7 @@ def run_behavioral_tests(
|
|||
pytest_timeout: int | None = None,
|
||||
pytest_cmd: str = "pytest",
|
||||
verbose: bool = False,
|
||||
pytest_target_runtime_seconds: int = TOTAL_LOOPING_TIME,
|
||||
pytest_target_runtime_seconds: int = TOTAL_LOOPING_TIME_EFFECTIVE,
|
||||
enable_coverage: bool = False,
|
||||
) -> tuple[Path, subprocess.CompletedProcess, Path | None, Path | None]:
|
||||
if test_framework == "pytest":
|
||||
|
|
@ -66,7 +66,7 @@ def run_behavioral_tests(
|
|||
"--codeflash_loops_scope=session",
|
||||
"--codeflash_min_loops=1",
|
||||
"--codeflash_max_loops=1",
|
||||
f"--codeflash_seconds={pytest_target_runtime_seconds}", # TODO : This is unnecessary, update the plugin to not ask for this
|
||||
f"--codeflash_seconds={pytest_target_runtime_seconds}",
|
||||
]
|
||||
|
||||
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
|
||||
|
|
@ -151,7 +151,7 @@ def run_line_profile_tests(
|
|||
cwd: Path,
|
||||
test_framework: str,
|
||||
*,
|
||||
pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME,
|
||||
pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME_EFFECTIVE,
|
||||
verbose: bool = False,
|
||||
pytest_timeout: int | None = None,
|
||||
pytest_min_loops: int = 5, # noqa: ARG001
|
||||
|
|
@ -237,7 +237,7 @@ def run_benchmarking_tests(
|
|||
cwd: Path,
|
||||
test_framework: str,
|
||||
*,
|
||||
pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME,
|
||||
pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME_EFFECTIVE,
|
||||
verbose: bool = False,
|
||||
pytest_timeout: int | None = None,
|
||||
pytest_min_loops: int = 5,
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
# These version placeholders will be replaced by uv-dynamic-versioning during build.
|
||||
__version__ = "0.16.0"
|
||||
__version__ = "0.16.7.post46.dev0+444ff121"
|
||||
|
|
|
|||
21
docs/.gitignore
vendored
|
|
@ -1,21 +0,0 @@
|
|||
# Dependencies
|
||||
/node_modules
|
||||
|
||||
# Production
|
||||
/build
|
||||
|
||||
# Generated files
|
||||
.docusaurus
|
||||
.cache-loader
|
||||
|
||||
# Misc
|
||||
.DS_Store
|
||||
.env
|
||||
.env.local
|
||||
.env.development.local
|
||||
.env.test.local
|
||||
.env.production.local
|
||||
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
# Website
|
||||
|
||||
This website is built using [Docusaurus](https://docusaurus.io/), a modern static website generator.
|
||||
|
||||
### Installation
|
||||
|
||||
```
|
||||
$ npm install
|
||||
```
|
||||
|
||||
### Local Development
|
||||
|
||||
```
|
||||
$ npm run start
|
||||
```
|
||||
|
||||
This command starts a local development server and opens up a browser window. Most changes are reflected live without having to restart the server.
|
||||
|
||||
### Build
|
||||
|
||||
```
|
||||
$ npm run build
|
||||
```
|
||||
|
||||
This command generates static content into the `build` directory and can be served using any static contents hosting service.
|
||||
|
||||
### Deployment
|
||||
|
||||
Deployment is done automatically via GitHub Actions when the Pull Request changes are merged to the `main` branch.
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
module.exports = {
|
||||
presets: [require.resolve('@docusaurus/core/lib/babel/preset')],
|
||||
};
|
||||
|
|
@ -1,5 +1,9 @@
|
|||
---
|
||||
sidebar_position: 2
|
||||
title: "How Codeflash Measures Code Runtime"
|
||||
description: "Learn how Codeflash accurately measures code performance using multiple runs and minimum timing"
|
||||
icon: "stopwatch"
|
||||
sidebarTitle: "Runtime Measurement"
|
||||
keywords: ["benchmarking", "performance", "timing", "measurement", "runtime", "noise reduction"]
|
||||
---
|
||||
|
||||
# How Codeflash measures code runtime
|
||||
|
|
@ -70,7 +74,7 @@ You ask the driver to repeat the race multiple times. In this scenario, since th
|
|||
|
||||
This gives us timing data (in hours) that looks like the following.
|
||||
|
||||

|
||||

|
||||
|
||||
With 100 data points (50 per train), determining the faster train becomes more complex.
|
||||
|
||||
|
|
@ -120,7 +124,7 @@ We can only measure times between adjacent stations.
|
|||
|
||||
Here is how the timing data looks like (in hours):
|
||||
|
||||

|
||||

|
||||
|
||||
With 300 data points (50 runs × 3 segments × 2 trains) and varying conditions on each segment,
|
||||
determining the faster train becomes even more challenging.
|
||||
|
|
@ -1,5 +1,9 @@
|
|||
---
|
||||
sidebar_position: 1
|
||||
title: "How Codeflash Works"
|
||||
description: "Understand Codeflash's generate-and-verify approach to code optimization and correctness verification"
|
||||
icon: "gear"
|
||||
sidebarTitle: "How It Works"
|
||||
keywords: ["architecture", "verification", "correctness", "testing", "optimization", "LLM", "benchmarking"]
|
||||
---
|
||||
# How Codeflash Works
|
||||
|
||||
|
|
@ -24,7 +28,7 @@ Codeflash currently only runs tests that directly call the target function in th
|
|||
To optimize code, Codeflash first gathers all necessary context from the codebase. It also line-profiles your code to understand where the bottlenecks might reside. It then calls our backend to generate several candidate optimizations. These are called "candidates" because their speed and correctness haven't been verified yet. Both properties will be verified in later steps.
|
||||
## Verification of correctness
|
||||
|
||||

|
||||

|
||||
|
||||
The goal of correctness verification is to ensure that when the new code replaces the original code, there are no behavioral changes in the code and the rest of the system. This means the replacement should be completely safe.
|
||||
|
||||
|
|
@ -53,7 +57,7 @@ Codeflash runs tests for the target function using either pytest or unittest fra
|
|||
|
||||
#### Performance benchmarking
|
||||
|
||||
Codeflash implements [several techniques](/codeflash-concepts/benchmarking.md) to measure code performance accurately. In particular, it runs multiple iterations of the code in a loop to determine the best performance with the minimum runtime. Codeflash compares the performance of the original code against the optimization, requiring at least a 10% speed improvement before considering it to be faster. This approach eliminates most runtime measurement variability, even on noisy CI systems and virtual machines. The final runtime Codeflash reports is the minimum total time it took to run all the test cases.
|
||||
Codeflash implements [several techniques](/codeflash-concepts/benchmarking) to measure code performance accurately. In particular, it runs multiple iterations of the code in a loop to determine the best performance with the minimum runtime. Codeflash compares the performance of the original code against the optimization, requiring at least a 10% speed improvement before considering it to be faster. This approach eliminates most runtime measurement variability, even on noisy CI systems and virtual machines. The final runtime Codeflash reports is the minimum total time it took to run all the test cases.
|
||||
|
||||
## Creating Pull Requests
|
||||
|
||||
|
Before Width: | Height: | Size: 21 KiB After Width: | Height: | Size: 21 KiB |
|
Before Width: | Height: | Size: 7.6 KiB After Width: | Height: | Size: 7.6 KiB |
|
|
@ -1,5 +1,9 @@
|
|||
---
|
||||
sidebar_position: 5
|
||||
title: "Manual Configuration"
|
||||
description: "Configure Codeflash for your project with pyproject.toml settings and advanced options"
|
||||
icon: "gear"
|
||||
sidebarTitle: "Manual Configuration"
|
||||
keywords: ["configuration", "pyproject.toml", "setup", "settings", "pytest", "formatter"]
|
||||
---
|
||||
|
||||
# Manual Configuration
|
||||
117
docs/docs.json
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
{
|
||||
"$schema": "https://mintlify.com/docs.json",
|
||||
"theme": "aspen",
|
||||
"name": "Codeflash Documentation",
|
||||
"colors": {
|
||||
"primary": "#2563EB",
|
||||
"light": "#3B82F6",
|
||||
"dark": "#1D4ED8"
|
||||
},
|
||||
"favicon": "/favicon.ico",
|
||||
"integrations": {
|
||||
"intercom": {
|
||||
"appId": "ljxo1nzr"
|
||||
}
|
||||
},
|
||||
"navigation": {
|
||||
"tabs": [
|
||||
{
|
||||
"tab": "Documentation",
|
||||
"groups": [
|
||||
{
|
||||
"group": "🏠 Overview",
|
||||
"pages": ["index"]
|
||||
},
|
||||
{
|
||||
|
||||
"group": "🚀 Quickstart",
|
||||
"pages": [
|
||||
"getting-started/local-installation"
|
||||
] },
|
||||
{
|
||||
"group": "⚡ Optimizing with Codeflash",
|
||||
"pages": [
|
||||
"optimizing-with-codeflash/one-function",
|
||||
"optimizing-with-codeflash/trace-and-optimize",
|
||||
"optimizing-with-codeflash/codeflash-all"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "✨ Continuous Optimization",
|
||||
"pages": [
|
||||
"optimizing-with-codeflash/codeflash-github-actions",
|
||||
"optimizing-with-codeflash/benchmarking",
|
||||
"optimizing-with-codeflash/review-optimizations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"group": "🧠 Core Concepts",
|
||||
"pages": ["codeflash-concepts/how-codeflash-works", "codeflash-concepts/benchmarking"]
|
||||
},
|
||||
{
|
||||
"group": "⚙️ Configuration & Best Practices",
|
||||
"pages": ["configuration", "getting-the-best-out-of-codeflash"]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"logo": {
|
||||
"light": "/images/codeflash_light.svg",
|
||||
"dark": "/images/codeflash_darkmode.svg"
|
||||
},
|
||||
"navbar": {
|
||||
"links": [
|
||||
{
|
||||
"label": "Discord",
|
||||
"href": "https://www.codeflash.ai/discord",
|
||||
"icon": "discord"
|
||||
},
|
||||
{
|
||||
"label": "GitHub",
|
||||
"href": "https://github.com/codeflash-ai/codeflash",
|
||||
"icon": "github"
|
||||
},
|
||||
{
|
||||
"label": "Blog",
|
||||
"href": "https://www.codeflash.ai/blog"
|
||||
}
|
||||
],
|
||||
"primary": {
|
||||
"type": "button",
|
||||
"label": "Try Codeflash",
|
||||
"href": "https://www.codeflash.ai"
|
||||
}
|
||||
},
|
||||
"contextual": {
|
||||
"options": ["copy"]
|
||||
},
|
||||
"redirects": [
|
||||
{
|
||||
"source": "/docs/:path*",
|
||||
"destination": "/:path*"
|
||||
}
|
||||
],
|
||||
"footer": {
|
||||
"socials": {
|
||||
"discord": "https://www.codeflash.ai/discord",
|
||||
"github": "https://github.com/codeflash-ai/codeflash",
|
||||
"linkedin": "https://www.linkedin.com/company/codeflash-ai"
|
||||
},
|
||||
"links": [
|
||||
{
|
||||
"label": "Legal",
|
||||
"items": [
|
||||
{
|
||||
"label": "Privacy Policy",
|
||||
"href": "https://www.codeflash.ai/privacy-policy"
|
||||
},
|
||||
{
|
||||
"label": "Terms of Service",
|
||||
"href": "https://www.codeflash.ai/terms-of-service"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
{
|
||||
"label": "Codeflash Concepts",
|
||||
"position": 4,
|
||||
"collapsed": false
|
||||
}
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
{
|
||||
"label": "Getting Started",
|
||||
"position": 2,
|
||||
"collapsed": false
|
||||
}
|
||||
|
|
@ -1,124 +0,0 @@
|
|||
---
|
||||
sidebar_position: 2
|
||||
---
|
||||
|
||||
# Automate Optimization of Pull Requests
|
||||
<!--- TODO: Add more pictures to guide better --->
|
||||
|
||||
Codeflash can automatically optimize your code when new pull requests are opened.
|
||||
|
||||
To be able to scan new code for performance optimizations, Codeflash requires a GitHub action workflow to
|
||||
be installed which runs the Codeflash optimization logic on every new pull request.
|
||||
If the action workflow finds an optimization, it communicates with the Codeflash GitHub
|
||||
App through our secure servers and asks it to suggest new changes to the pull request.
|
||||
|
||||
This is the most useful way of using Codeflash, where you set it up once and all your new code gets optimized.
|
||||
So setting this up is highly recommended.
|
||||
|
||||
## Prerequisites
|
||||
- You have a Codeflash API key. If you don't have one, you can generate one from the [Codeflash Webapp](https://app.codeflash.ai/). Make sure you generate the API key with the right GitHub account that has access to the repository you want to optimize.
|
||||
- You have completed [local installation](/getting-started/local-installation) steps and have a Python project with a `pyproject.toml` file that is configured with Codeflash. If you haven't configured Codeflash for your project yet, you can do so by running `codeflash init` in the root directory of your project.
|
||||
|
||||
## Add the Codeflash GitHub Actions workflow
|
||||
|
||||
### Guided setup
|
||||
|
||||
To add the Codeflash GitHub Actions workflow to your repository, you can run the following command in your project directory:
|
||||
|
||||
```bash
|
||||
codeflash init-actions
|
||||
```
|
||||
|
||||
This will walk you through the process of adding the Codeflash GitHub Actions workflow to your repository.
|
||||
|
||||
### All Set up!
|
||||
|
||||
Open a new PR to your GitHub project, and you will now see a new actions workflow for Codeflash run. If it finds an optimization,
|
||||
codeflash-ai bot will comment on your repo with the optimization suggestions.
|
||||
|
||||
### Manual Installation (optional)
|
||||
If you prefer to install the GitHub actions manually, follow the steps below -
|
||||
|
||||
#### Add the workflow file
|
||||
Create a new file in your repository at `.github/workflows/codeflash-optimize.yaml` with the following contents:
|
||||
|
||||
|
||||
```yaml title=".github/workflows/codeflash-optimize.yaml"
|
||||
name: Codeflash
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
optimize:
|
||||
name: Optimize new code in this PR
|
||||
if: ${{ github.actor != 'codeflash-ai[bot]' }}
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
# TODO: Replace the following with your project's Python installation method
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
# TODO: Replace the following with your project's dependency installation method
|
||||
- name: Install Project Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
# TODO: Replace the following with your project setup method
|
||||
pip install -r requirements.txt
|
||||
pip install codeflash
|
||||
- name: Run Codeflash to optimize code
|
||||
id: optimize_code
|
||||
run: |
|
||||
codeflash
|
||||
```
|
||||
You would need to fill in the `#TODO`s in the file above to make it work. Please commit this file to your repository.
|
||||
If you use a particular Python package manager like Poetry or uv, some helpful configurations are provided below.
|
||||
|
||||
#### Config with different Python package managers
|
||||
|
||||
The yaml config above is a basic template. Here is how you can run Codeflash with the different Python package managers:
|
||||
|
||||
1. Poetry
|
||||
|
||||
```yaml
|
||||
- name: Install Project Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install poetry
|
||||
poetry install --with dev
|
||||
- name: Run Codeflash to optimize code
|
||||
id: optimize_code
|
||||
run: |
|
||||
poetry env use python
|
||||
poetry run codeflash
|
||||
```
|
||||
This assumes that you install poetry with pip and have Codeflash dependency in the `dev` section of your `pyproject.toml` file.
|
||||
|
||||
2. uv
|
||||
|
||||
```yaml
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
- run: uv sync --group=dev
|
||||
- name: Run Codeflash to optimize code
|
||||
run: uv run codeflash
|
||||
```
|
||||
|
||||
#### Add your API key to your repository secrets
|
||||
|
||||
Go to your GitHub repository, click **Settings**, and click on **Secrets and
|
||||
Variables** -> **Actions** on the left sidebar.
|
||||
|
||||
Add the following secret:
|
||||
|
||||
- `CODEFLASH_API_KEY`: The API key you got from https://app.codeflash.ai/app/apikeys
|
||||
|
||||
|
|
@ -1,173 +0,0 @@
|
|||
---
|
||||
sidebar_position: 1
|
||||
---
|
||||
|
||||
# Local Installation
|
||||
|
||||
Codeflash is installed and configured on a per-project basis.
|
||||
|
||||
You can install Codeflash locally for a project by running the following command in the project's virtual environment:
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Before installing Codeflash, ensure you have:
|
||||
|
||||
1. **Python 3.9 or above** installed
|
||||
2. **A Python project** with a virtual environment
|
||||
3. **Project dependencies installed** in your virtual environment
|
||||
4. **Tests** (optional) for your code (Codeflash uses tests to verify optimizations)
|
||||
|
||||
:::important[Virtual Environment Required]
|
||||
Always install Codeflash in your project's virtual environment, not globally. Make sure your virtual environment is activated before proceeding.
|
||||
|
||||
```bash
|
||||
# Example: Activate your virtual environment
|
||||
source venv/bin/activate # On Linux/Mac
|
||||
# or
|
||||
venv\Scripts\activate  # On Windows
|
||||
```
|
||||
:::
|
||||
### Step 1: Install Codeflash
|
||||
```bash
|
||||
pip install codeflash
|
||||
```
|
||||
|
||||
:::tip[Codeflash is a Development Dependency]
|
||||
We recommend installing Codeflash as a development dependency.
|
||||
It doesn't need to be installed as part of your package requirements.
|
||||
Codeflash is intended to be used locally and as part of development workflows such as CI.
|
||||
If using pyproject.toml:
|
||||
```toml
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
codeflash = "^latest"
|
||||
```
|
||||
Or with pip:
|
||||
```bash
|
||||
pip install codeflash
|
||||
```
|
||||
:::
|
||||
|
||||
### Step 2: Generate a Codeflash API Key
|
||||
|
||||
Codeflash uses cloud-hosted AI models to optimize your code. You'll need an API key to use it.
|
||||
|
||||
1. Visit the [Codeflash Web App](https://app.codeflash.ai/)
|
||||
2. Sign up with your GitHub account (free)
|
||||
3. Navigate to the [API Key](https://app.codeflash.ai/app/apikeys) page to generate your API key
|
||||
<!--- TODO: Do we ask for access to specific repositories here? --->
|
||||
|
||||
:::note[Free Tier Available]
|
||||
Codeflash offers a **free tier** with a limited number of optimizations per month. Perfect for trying it out or small projects!
|
||||
:::
|
||||
|
||||
### Step 3: Automatic Configuration
|
||||
|
||||
Navigate to your project's root directory (where your `pyproject.toml` file is or should be) and run:
|
||||
|
||||
```bash
|
||||
# Make sure you're in your project root
|
||||
cd /path/to/your/project
|
||||
|
||||
# Run the initialization
|
||||
codeflash init
|
||||
```
|
||||
|
||||
If you don't have a pyproject.toml file yet, the codeflash init command will ask you to create one
|
||||
|
||||
:::tip[What's pyproject.toml?]
|
||||
`pyproject.toml` is a configuration file that is used to specify build tool settings for Python projects.
|
||||
pyproject.toml is the modern replacement for setup.py and requirements.txt files.
|
||||
It's the new standard for Python package metadata.
|
||||
:::
|
||||
|
||||
When running `codeflash init`, you will see the following prompts:
|
||||
|
||||
```text
|
||||
1. Enter your Codeflash API key:
|
||||
2. Which Python module do you want me to optimize going forward? (e.g. my_module)
|
||||
3. Where are your tests located? (e.g. tests/)
|
||||
4. Which test framework do you use? (pytest/unittest)
|
||||
```
|
||||
|
||||
After you have answered these questions, Codeflash will be configured for your project.
|
||||
The configuration will be saved in the `pyproject.toml` file in the root directory of your project.
|
||||
To understand the configuration options, and set more advanced options, see the [Configuration](/configuration) page.
|
||||
|
||||
### Step 4: Install the Codeflash GitHub App
|
||||
|
||||
<!--- TODO: Justify to users Why we need the user to install Github App even in local Installation or local optimization? --->
|
||||
Finally, if you have not done so already, Codeflash will ask you to install the Github App in your repository. The Codeflash GitHub App allows access to your repository to the codeflash-ai bot to open PRs, review code, and provide optimization suggestions.
|
||||
|
||||
Please [install the Codeflash GitHub
|
||||
app](https://github.com/apps/codeflash-ai/installations/select_target) by choosing the repository you want to install
|
||||
Codeflash on.
|
||||
##
|
||||
|
||||
## Try It Out!
|
||||
|
||||
Once configured, you can start optimizing your code:
|
||||
|
||||
```bash
|
||||
# Optimize a specific function
|
||||
codeflash --file path/to/your/file.py --function function_name
|
||||
|
||||
# Or if want to optimize only locally without creating a PR
|
||||
codeflash --file path/to/your/file.py --function function_name --no-pr
|
||||
```
|
||||
|
||||
### Example Project
|
||||
|
||||
Want to see Codeflash in action? Check out our example repository:
|
||||
|
||||
🔗 [github.com/codeflash-ai/optimize-me](https://github.com/codeflash-ai/optimize-me)
|
||||
|
||||
This repo includes:
|
||||
- Sample Python code with performance issues
|
||||
- Tests for verification
|
||||
- Pre-configured `pyproject.toml`
|
||||
- Before/after optimization examples in PRs
|
||||
|
||||
Clone it and try running:
|
||||
```bash
|
||||
git clone https://github.com/codeflash-ai/optimize-me.git
|
||||
cd optimize-me
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate # or venv\Scripts\activate on Windows
|
||||
pip install -r requirements.txt
|
||||
pip install codeflash
|
||||
codeflash init # Use your own API key
|
||||
codeflash --all # optimize the entire repo
|
||||
```
|
||||
|
||||
### 🔧 Troubleshooting
|
||||
|
||||
#### 📦 "Module not found" errors
|
||||
Make sure:
|
||||
- ✅ Your virtual environment is activated
|
||||
- ✅ All project dependencies are installed
|
||||
|
||||
#### 🧪 "No optimizations found" or debugging issues
|
||||
Use the `--verbose` flag for detailed output:
|
||||
```bash
|
||||
codeflash optimize --verbose
|
||||
```
|
||||
|
||||
This will show:
|
||||
- 🔍 Which functions are being analyzed
|
||||
- 🚫 Why certain functions were skipped
|
||||
- ⚠️ Detailed error messages
|
||||
- 📊 Performance analysis results
|
||||
|
||||
#### 🔍 "No tests found" errors
|
||||
Verify:
|
||||
- 📁 Your test directory path is correct in `pyproject.toml`
|
||||
- 🔍 Tests are discoverable by your test framework
|
||||
- 📝 Test files follow naming conventions (`test_*.py` for pytest)
|
||||
|
||||
|
||||
### Next Steps
|
||||
|
||||
- Learn about [Codeflash Concepts](/codeflash-concepts/how-codeflash-works)
|
||||
- Explore [Optimization workflows](/optimizing-with-codeflash/one-function)
|
||||
- Set up [GitHub Actions integration](/getting-started/codeflash-github-actions)
|
||||
- Read [configuration options](/configuration) for advanced setups
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
---
|
||||
sidebar_position: 5
|
||||
---
|
||||
|
||||
# Getting the best out of Codeflash
|
||||
|
||||
Codeflash is a powerful tool; here are our recommendations, tips and tricks on getting the best out of it. We do these ourselves, so we hope you will too!
|
||||
|
||||
### Install the Github App and actions workflow
|
||||
|
||||
After you install Codeflash on an actively developed project, [installing the GitHub App](getting-started/codeflash-github-actions) and setting up the
|
||||
GitHub Actions workflow will automatically optimize your code whenever new pull requests are opened. This ensures you get the best version of any changes you make to your code without any extra effort. We find that PRs are also the best time to review these changes, because the code is fresh in your mind.
|
||||
|
||||
### Find optimizations on your whole codebase with `codeflash --all`
|
||||
|
||||
If you have a lot of existing code, run [`codeflash --all`](optimizing-with-codeflash/codeflash-all) to discover and fix any
|
||||
slow code in your project. Codeflash will open new pull requests for any optimizations it finds, and you can review and merge them at your own pace.
|
||||
|
||||
### Find and optimize bottlenecks with the Codeflash Tracer
|
||||
|
||||
Find the best results by running [Codeflash Tracer](optimizing-with-codeflash/trace-and-optimize) on the entry point of your script before optimizing it. The Codeflash Tracer will generate a trace file and a Replay Test file that will help Codeflash understand the behavior & inputs of your functions and generate the highest quality optimizations.
|
||||
|
||||
### Review the PRs Codeflash opens
|
||||
|
||||
We're constantly improving Codeflash and the underlying AI models it uses. The state of the art changes weekly, and you can be confident the optimizer will always use the best performing LLMs to find optimizations for your code. That said, because Codeflash uses generative AI, it's still possible that the optimized code may actually have different behavior than the original code under certain conditions. Please review all the PRs that Codeflash opens to ensure that the optimized code is correct, just as you would review any other PR opened by a team member. And don't forget to send us feedback on how we can improve Codeflash - we're always listening!
|
||||
|
|
@ -1,45 +0,0 @@
|
|||
---
|
||||
sidebar_position: 1
|
||||
slug: /
|
||||
---
|
||||
# What is Codeflash?
|
||||
|
||||
Welcome! Codeflash is an AI performance optimizer for Python code.
|
||||
Codeflash speeds up Python code by figuring out the best way to rewrite your code while verifying that the behavior of the code is unchanged.
|
||||
|
||||
The optimizations Codeflash finds are generally better algorithms, opportunities to remove wasteful compute, better logic, and utilization of more efficient library methods. Codeflash
|
||||
does not modify the architecture of your code, but it tries to find the most efficient implementation of that architecture.
|
||||
|
||||
### How does Codeflash verify correctness?
|
||||
|
||||
Codeflash verifies the correctness of the optimizations it finds by generating and running new regression tests, as well as any existing tests you may already have. Codeflash tries to ensure that your
|
||||
code behaves the same way before and after the optimization.
|
||||
This offers high confidence that the behavior of your code remains unchanged.
|
||||
|
||||
### Continuous Optimization
|
||||
|
||||
Because Codeflash is an automated process, you can install it as a GitHub action and have it optimize the new code on every pull request.
|
||||
When Codeflash finds an optimization, it will ask you to review it. It will write a detailed explanation of the changes it made, and include all relevant info like % speed increase and proofs of correctness.
|
||||
|
||||
This is a great way to ensure that your code, your team's code and your AI Agent's code are optimized for performance before it causes a performance regression. We call this *Continuous Optimization*.
|
||||
|
||||
### Features
|
||||
|
||||
<!--- TODO: Add links to the relevant sections of the documentation and style the table --->
|
||||
|
||||
| Feature | Usage | Description |
|
||||
|------------------------------------------------------------------------------|---------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| [Optimize a single function](optimizing-with-codeflash/one-function) | `codeflash --file path.py --function my_function` | Basic unit of optimization by asking Codeflash to optimize a particular function |
|
||||
| [Optimize an entire workflow](optimizing-with-codeflash/trace-and-optimize) | `codeflash optimize myscript.py` | End to end optimization for all the functions called in a workflow, by tracing to collect real inputs to ensure correctness and e2e performance optimization |
|
||||
| [Optimize all code in a repo](optimizing-with-codeflash/codeflash-all) | `codeflash --all` | Codeflash discovers all functions in a repo and optimizes all of them! |
|
||||
| [Optimize every new pull request](optimizing-with-codeflash/optimize-prs) | `codeflash init-actions` | Codeflash runs as a GitHub action and GitHub app and suggests optimizations for all new code in a Pull Request. |
|
||||
|
||||
|
||||
## How to use these docs
|
||||
|
||||
On the left side of the screen, you'll find the docs navigation bar.
|
||||
Start by installing Codeflash, then explore the different ways of using it to optimize your code.
|
||||
|
||||
## Questions or Feedback?
|
||||
|
||||
Your feedback will help us make codeflash better, faster. If you have any questions or feedback, use the Intercom button in the lower right, join our [Discord](https://www.codeflash.ai/discord), or drop us a note at [contact@codeflash.ai](mailto:contact@codeflash.ai) - we read every message!
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
{
|
||||
"label": "Using Codeflash",
|
||||
"position": 3,
|
||||
"collapsed": false
|
||||
}
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
---
|
||||
sidebar_position: 3
|
||||
---
|
||||
|
||||
# Optimize Pull Requests
|
||||
|
||||
Codeflash can optimize your pull requests by analyzing the changes in the pull request
|
||||
and generating optimized versions of the functions that have changed.
|
||||
|
||||
## How to optimize a pull request
|
||||
After following the setup steps in the [Automate Code Optimization with GitHub Actions](/getting-started/codeflash-github-actions) guide,
|
||||
Codeflash will automatically optimize your pull requests when they are opened.
|
||||
|
||||
If Codeflash finds any successful optimizations, it will comment on the pull request asking you to review the changes.
|
||||
|
||||

|
||||
|
||||
Codeflash can ask you to review the changes in two ways:
|
||||
### Opening a dependent pull request
|
||||
Codeflash will open a new pull request with the optimized code.
|
||||
You can review the changes in this pull request, make changes if you want, and merge it if you are satisfied with the optimizations.
|
||||
The changes will be merged back into the original pull request as a new commit.
|
||||
|
||||

|
||||
### Reviewing the changes in the original pull request
|
||||
If the suggested changes are small and only affect the modified lines, Codeflash will suggest the changes in the original pull request itself.
|
||||
You can choose to accept or reject the changes directly in the original pull request.
|
||||
The changes can be added to a batch of changes in the original pull request as a new commit.
|
||||
|
||||

|
||||
|
|
@ -1,171 +0,0 @@
|
|||
import { themes as prismThemes } from "prism-react-renderer"
|
||||
import type { Config } from "@docusaurus/types"
|
||||
import type * as Preset from "@docusaurus/preset-classic"
|
||||
|
||||
const config: Config = {
|
||||
title: "Codeflash Docs",
|
||||
tagline: "Code optimization is cool",
|
||||
favicon: "img/favicon.ico",
|
||||
|
||||
// Set the production url of your site here
|
||||
url: "https://docs.codeflash.ai",
|
||||
// Set the /<baseUrl>/ pathname under which your site is served
|
||||
// For GitHub pages deployment, it is often '/<projectName>/'
|
||||
baseUrl: "/",
|
||||
|
||||
// GitHub pages deployment config.
|
||||
// If you aren't using GitHub pages, you don't need these.
|
||||
// organizationName: 'facebook', // Usually your GitHub org/user name.
|
||||
// projectName: 'docusaurus', // Usually your repo name.
|
||||
|
||||
onBrokenLinks: "throw",
|
||||
onBrokenMarkdownLinks: "warn",
|
||||
|
||||
// Even if you don't use internationalization, you can use this field to set
|
||||
// useful metadata like html lang. For example, if your site is Chinese, you
|
||||
// may want to replace "en" with "zh-Hans".
|
||||
i18n: {
|
||||
defaultLocale: "en",
|
||||
locales: ["en"],
|
||||
},
|
||||
|
||||
scripts: [
|
||||
{
|
||||
src: "https://widget.intercom.io/widget/ljxo1nzr",
|
||||
async: true,
|
||||
onLoad: `window.Intercom('boot', {
|
||||
app_id: "ljxo1nzr"
|
||||
});`,
|
||||
},
|
||||
{
|
||||
src: "https://app.posthog.com/static/array.js",
|
||||
strategy: "afterInteractive",
|
||||
onLoad: `window.posthog = window.posthog || [];
|
||||
window.posthog.init("phc_aUO790jHd7z1SXwsYCz8dRApxueplZlZWeDSpKc5hol", {
|
||||
api_host: "https://us.i.posthog.com",
|
||||
});`,
|
||||
},
|
||||
],
|
||||
presets: [
|
||||
[
|
||||
"classic",
|
||||
{
|
||||
docs: {
|
||||
sidebarPath: "./sidebars.ts",
|
||||
routeBasePath: "/",
|
||||
// Please change this to your repo.
|
||||
},
|
||||
blog: false,
|
||||
theme: {
|
||||
customCss: "./src/css/custom.css",
|
||||
},
|
||||
} satisfies Preset.Options,
|
||||
],
|
||||
],
|
||||
|
||||
themeConfig: {
|
||||
// Replace with your project's social card
|
||||
colorMode: {
|
||||
defaultMode: "dark",
|
||||
disableSwitch: false,
|
||||
respectPrefersColorScheme: false,
|
||||
},
|
||||
image: "img/codeflash_social_card.jpg",
|
||||
navbar: {
|
||||
// title: 'My Site',
|
||||
logo: {
|
||||
href: "https://codeflash.ai/",
|
||||
alt: "Codeflash Logo",
|
||||
src: "img/codeflash_light.svg",
|
||||
srcDark: "img/codeflash_darkmode.svg",
|
||||
},
|
||||
items: [
|
||||
{
|
||||
type: "docSidebar",
|
||||
sidebarId: "tutorialSidebar",
|
||||
position: "left",
|
||||
label: "Docs",
|
||||
},
|
||||
// {to: '/blog', label: 'Blog', position: 'left'},
|
||||
{
|
||||
href: "https://app.codeflash.ai/",
|
||||
label: "Get Started",
|
||||
position: "right",
|
||||
},
|
||||
],
|
||||
},
|
||||
docs : {
|
||||
sidebar: {
|
||||
autoCollapseCategories: false,
|
||||
hideable: false,
|
||||
}
|
||||
},
|
||||
footer: {
|
||||
style: "dark",
|
||||
links: [
|
||||
{
|
||||
title: "Navigation",
|
||||
items: [
|
||||
{
|
||||
label: "Home Page",
|
||||
to: "https://codeflash.ai/",
|
||||
},
|
||||
{
|
||||
label: "PyPI",
|
||||
to: "https://pypi.org/project/codeflash/",
|
||||
},
|
||||
{
|
||||
label: "Get Started",
|
||||
to: "https://app.codeflash.ai/",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
title: "Get in Touch",
|
||||
items: [
|
||||
{
|
||||
label: "Careers",
|
||||
to: "mailto:careers@codeflash.ai",
|
||||
},
|
||||
{
|
||||
label: "contact@codeflash.ai",
|
||||
href: "mailto:contact@codeflash.ai",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
copyright: `©2024 CodeFlash Inc.`,
|
||||
},
|
||||
prism: {
|
||||
theme: prismThemes.github,
|
||||
darkTheme: prismThemes.dracula,
|
||||
additionalLanguages: ["bash", "toml"],
|
||||
},
|
||||
algolia: {
|
||||
// The application ID provided by Algolia
|
||||
appId: 'Y1C10T0Z7E',
|
||||
|
||||
// Public API key: it is safe to commit it
|
||||
apiKey: '4d1d294b58eb97edec121c9c1c079c23',
|
||||
|
||||
indexName: 'codeflash',
|
||||
|
||||
// Optional: see doc section below
|
||||
contextualSearch: true,
|
||||
|
||||
// Optional: Algolia search parameters
|
||||
searchParameters: {},
|
||||
|
||||
// Optional: path for search page that enabled by default (`false` to disable it)
|
||||
searchPagePath: 'search',
|
||||
|
||||
// Optional: whether the insights feature is enabled or not on Docsearch (`false` by default)
|
||||
insights: true,
|
||||
|
||||
//... other Algolia params
|
||||
},
|
||||
|
||||
} satisfies Preset.ThemeConfig,
|
||||
}
|
||||
|
||||
export default config
|
||||
|
Before Width: | Height: | Size: 1.1 KiB After Width: | Height: | Size: 1.1 KiB |
225
docs/getting-started/local-installation.mdx
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
---
|
||||
title: "Local Installation"
|
||||
description: "Install and configure Codeflash for your Python project in minutes"
|
||||
icon: "download"
|
||||
---
|
||||
|
||||
Codeflash is installed and configured on a per-project basis.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Before installing Codeflash, ensure you have:
|
||||
|
||||
1. **Python 3.9 or above** installed
|
||||
2. **A Python project** with a virtual environment
|
||||
3. **Project dependencies installed** in your virtual environment
|
||||
|
||||
Good to have (optional):
|
||||
1. **Unit Tests** that Codeflash uses to ensure correctness of the optimizations
|
||||
|
||||
<Warning>
|
||||
**Virtual Environment Required**
|
||||
|
||||
Always install Codeflash in your project's virtual environment, not globally. Make sure your virtual environment is activated before proceeding.
|
||||
|
||||
```bash
|
||||
source venv/bin/activate # On Linux/Mac
|
||||
# or
|
||||
venv\Scripts\activate # On Windows
|
||||
```
|
||||
</Warning>
|
||||
<Steps>
|
||||
<Step title="Install Codeflash">
|
||||
|
||||
You can install Codeflash locally for a project by running the following command in the project's virtual environment:
|
||||
```bash
|
||||
pip install codeflash
|
||||
```
|
||||
|
||||
<Tip>
|
||||
**Codeflash is a Development Dependency**
|
||||
|
||||
We recommend installing Codeflash as a development dependency.
|
||||
Codeflash is intended to be used in development workflows locally and as part of CI.
|
||||
Try to always use the latest version of Codeflash as it improves quickly.
|
||||
|
||||
<CodeGroup>
|
||||
```bash uv
|
||||
uv add --dev codeflash
|
||||
```
|
||||
|
||||
```bash poetry
|
||||
poetry add codeflash@latest --group dev
|
||||
```
|
||||
</CodeGroup>
|
||||
</Tip>
|
||||
</Step>
|
||||
|
||||
<Step title="Generate a Codeflash API Key">
|
||||
Codeflash uses cloud-hosted AI models and integrations with GitHub. You'll need an API key to authorize your access.
|
||||
|
||||
1. Visit the [Codeflash Web App](https://app.codeflash.ai/)
|
||||
2. Sign up with your GitHub account (free)
|
||||
3. Navigate to the [API Key](https://app.codeflash.ai/app/apikeys) page to generate your API key
|
||||
|
||||
<Note>
|
||||
**Free Tier Available**
|
||||
|
||||
Codeflash offers a **free tier** with a limited number of optimizations. Perfect for trying it out on small projects!
|
||||
</Note>
|
||||
</Step>
|
||||
|
||||
<Step title="Run Automatic Configuration">
|
||||
Navigate to your project's root directory (where your `pyproject.toml` file is or should be) and run:
|
||||
|
||||
```bash
|
||||
codeflash init
|
||||
```
|
||||
|
||||
If you don't have a `pyproject.toml` file yet, the `codeflash init` command will ask you to create one.
|
||||
|
||||
<Info>
|
||||
**What's pyproject.toml?**
|
||||
|
||||
`pyproject.toml` is a configuration file that is used to specify build and tool settings for Python projects.
|
||||
`pyproject.toml` is the modern replacement for setup.py and requirements.txt files.
|
||||
</Info>
|
||||
|
||||
When running `codeflash init`, you will see the following prompts:
|
||||
|
||||
```text
|
||||
1. Enter your Codeflash API key:
|
||||
2. Install the GitHub app.
|
||||
3. Which Python module do you want me to optimize going forward? (e.g. my_module)
|
||||
4. Where are your tests located? (e.g. tests/)
|
||||
5. Which test framework do you use? (pytest/unittest)
|
||||
6. Install GitHub actions for Continuous optimization?
|
||||
```
|
||||
</Step>
|
||||
</Steps>
|
||||
|
||||
After you have answered these questions, the Codeflash configuration will be saved in the `pyproject.toml` file.
|
||||
To understand the configuration options, and set more advanced options, see the [Manual Configuration](/configuration) page.
|
||||
|
||||
### Step 4: Install the Codeflash GitHub App
|
||||
|
||||
{/* TODO: Justify to users Why we need the user to install Github App even in local Installation or local optimization? */}
|
||||
Finally, if you have not done so already, Codeflash will ask you to install the GitHub App in your repository.
|
||||
The Codeflash GitHub App allows access to your repository to the codeflash-ai bot to open PRs, review code, and provide optimization suggestions.
|
||||
|
||||
Please [install the Codeflash GitHub
|
||||
app](https://github.com/apps/codeflash-ai/installations/select_target) by choosing the repository you want to install
|
||||
Codeflash on.
|
||||
|
||||
## Try It Out!
|
||||
|
||||
<Tabs>
|
||||
<Tab title="Quick Start">
|
||||
Once configured, you can start optimizing your code immediately:
|
||||
|
||||
```bash
|
||||
# Optimize a specific function
|
||||
codeflash --file path/to/your/file.py --function function_name
|
||||
|
||||
# Or optimize all functions in your codebase
|
||||
codeflash --all
|
||||
```
|
||||
|
||||
</Tab>
|
||||
|
||||
<Tab title="Optimize Example Project">
|
||||
<Card title="🚀 Try optimizing our example repository" icon="github" href="https://github.com/codeflash-ai/optimize-me">
|
||||
Want to see Codeflash in action and don't know what code to optimize? Check out our **optimize-me** repository with code ready to optimize.
|
||||
|
||||
**What's included:**
|
||||
- Sample Python code with performance issues
|
||||
- Tests for verification
|
||||
- Pre-configured `pyproject.toml`
|
||||
</Card>
|
||||
|
||||
<Steps>
|
||||
<Step title="Fork the Repository">
|
||||
Fork the [optimize-me](https://github.com/codeflash-ai/optimize-me) repo to your GitHub account by clicking "Fork" on the top of the page. This allows Codeflash to open Pull Requests with the optimizations it found on your forked repo.
|
||||
</Step>
|
||||
<Step title="Clone the Forked Repository">
|
||||
```bash
|
||||
git clone https://github.com/your_github_username/optimize-me.git
|
||||
cd optimize-me
|
||||
```
|
||||
</Step>
|
||||
|
||||
<Step title="Set Up Environment">
|
||||
```bash
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
pip install codeflash
|
||||
```
|
||||
</Step>
|
||||
|
||||
<Step title="Run Codeflash">
|
||||
```bash
|
||||
codeflash init # Use your own API key
|
||||
codeflash --all # optimize the entire repo
|
||||
```
|
||||
</Step>
|
||||
</Steps>
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
<AccordionGroup>
|
||||
<Accordion title="📦 Module not found errors">
|
||||
Make sure:
|
||||
- ✅ Your virtual environment is activated
|
||||
- ✅ All project dependencies are installed
|
||||
|
||||
```bash
|
||||
# Verify your virtual environment is active
|
||||
which python # Should show path to your venv
|
||||
|
||||
# Install missing dependencies
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="🧪 No optimizations found or debugging issues">
|
||||
Note that not all functions can be optimized, as no optimization opportunities may exist for them. This is fine and expected.
|
||||
|
||||
To investigate further, use the `--verbose` flag for detailed output:
|
||||
```bash
|
||||
codeflash optimize --verbose
|
||||
```
|
||||
|
||||
This will show:
|
||||
- 🔍 Which functions are being analyzed
|
||||
- 🚫 Why certain functions were skipped
|
||||
- ⚠️ Detailed error messages
|
||||
- 📊 Performance analysis results
|
||||
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="🔍 No tests found errors">
|
||||
Verify:
|
||||
- 📁 Your test directory path is correct in `pyproject.toml`
|
||||
- 🔍 Tests are discoverable by your test framework
|
||||
- 📝 Test files follow naming conventions (`test_*.py` for pytest)
|
||||
|
||||
```bash
|
||||
# Test if pytest can discover your tests
|
||||
pytest --collect-only
|
||||
|
||||
# Check your pyproject.toml configuration
|
||||
cat pyproject.toml | grep -A 8 "\[tool.codeflash\]"
|
||||
```
|
||||
</Accordion>
|
||||
</AccordionGroup>
|
||||
|
||||
|
||||
### Next Steps
|
||||
|
||||
- Learn about [Codeflash Concepts](/codeflash-concepts/how-codeflash-works)
|
||||
- Explore [Optimization workflows](/optimizing-with-codeflash/one-function)
|
||||
- Set up [Pull Request Optimization](/optimizing-with-codeflash/codeflash-github-actions)
|
||||
- Read [configuration options](/configuration) for advanced setups
|
||||
36
docs/getting-the-best-out-of-codeflash.mdx
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
---
|
||||
title: "Getting the Best Out of Codeflash"
|
||||
description: "Tips, recommendations, and best practices for maximizing Codeflash's optimization capabilities"
|
||||
icon: "lightbulb"
|
||||
sidebarTitle: "Best Practices"
|
||||
keywords: ["best practices", "tips", "github actions", "tracer", "optimization", "workflow"]
|
||||
---
|
||||
|
||||
Codeflash is a powerful tool; here are our recommendations based on how the Codeflash team and our customers use Codeflash.
|
||||
|
||||
### Install the GitHub App and actions workflow
|
||||
|
||||
After you install Codeflash on an actively developed project, [installing the GitHub Actions](optimizing-with-codeflash/codeflash-github-actions) will automatically optimize your code whenever new pull requests are opened. This ensures you get the best version of any changes you make to your code without any extra effort. We find that PRs are also the best time to review these changes, because the code is fresh in your mind.
|
||||
|
||||
### Find and optimize entire scripts with the Codeflash Tracer
|
||||
|
||||
Get the best results by running [Codeflash Optimize](optimizing-with-codeflash/trace-and-optimize) on your script.
|
||||
This internally runs a profiler, captures inputs to all the functions your script calls, and uses those inputs to create Replay tests and benchmarks.
|
||||
The optimizations you get with this method show you how much faster your workflow will get, plus guarantee that your workflow won't break if you merge in the optimizations.
|
||||
|
||||
|
||||
### Find optimizations on your whole codebase with `codeflash --all`
|
||||
|
||||
If you have a lot of existing code, run [`codeflash --all`](optimizing-with-codeflash/codeflash-all) to discover and fix any
|
||||
slow code in your project. Codeflash will open new pull requests for any optimizations it finds, and you can review and merge them at your own pace.
|
||||
|
||||
To achieve higher-quality optimizations with this approach, it is recommended to first trace your tests:
|
||||
|
||||
```bash
|
||||
codeflash optimize --trace-only -m pytest tests/ ; codeflash --all
|
||||
```
|
||||
|
||||
|
||||
### Review the PRs Codeflash opens
|
||||
|
||||
We're constantly improving Codeflash and the underlying AI models it uses. The state of the art changes weekly, and you can be confident the optimizer will always use the best performing LLMs to find optimizations for your code. That said, because Codeflash uses generative AI, it's still possible that the optimized code may actually have different behavior than the original code under certain conditions. Please review all the PRs that Codeflash opens to ensure that the optimized code is correct, just as you would review any other PR opened by a team member. And don't forget to send us feedback on how we can improve Codeflash - we're always listening!
|
||||
|
Before Width: | Height: | Size: 251 KiB After Width: | Height: | Size: 251 KiB |
|
Before Width: | Height: | Size: 3.4 MiB After Width: | Height: | Size: 3.4 MiB |
|
Before Width: | Height: | Size: 4.9 KiB After Width: | Height: | Size: 4.9 KiB |
|
Before Width: | Height: | Size: 8.8 KiB After Width: | Height: | Size: 8.8 KiB |
|
Before Width: | Height: | Size: 5 KiB After Width: | Height: | Size: 5 KiB |
BIN
docs/images/codeflash_pr_suggestion_1.png
Normal file
|
After Width: | Height: | Size: 527 KiB |
|
Before Width: | Height: | Size: 387 KiB After Width: | Height: | Size: 387 KiB |
BIN
docs/images/edited-code.png
Normal file
|
After Width: | Height: | Size: 321 KiB |
BIN
docs/images/editor.png
Normal file
|
After Width: | Height: | Size: 370 KiB |
|
Before Width: | Height: | Size: 317 KiB After Width: | Height: | Size: 317 KiB |
BIN
docs/images/review-optimizations.png
Normal file
|
After Width: | Height: | Size: 297 KiB |
|
Before Width: | Height: | Size: 24 KiB After Width: | Height: | Size: 24 KiB |
63
docs/index.mdx
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
---
|
||||
title: "Codeflash is an AI performance optimizer for Python code"
|
||||
icon: "rocket"
|
||||
sidebarTitle: "Overview"
|
||||
keywords: ["python", "performance", "optimization", "AI", "code analysis", "benchmarking"]
|
||||
---
|
||||
|
||||
Codeflash speeds up any Python code by figuring out the best way to rewrite it while verifying that the behavior of the code is unchanged, and verifying real speed
|
||||
gains through performance benchmarking.
|
||||
|
||||
The optimizations Codeflash finds are generally better algorithms, opportunities to remove wasteful compute, better logic, utilizing caching and utilization of more efficient library methods. Codeflash
|
||||
does not modify the system architecture of your code, but it tries to find the most efficient implementation of your current architecture.
|
||||
|
||||
### How to use Codeflash
|
||||
|
||||
<CardGroup cols={1}>
|
||||
<Card title="Optimize a Single Function" icon="bullseye" href="/optimizing-with-codeflash/one-function">
|
||||
Target and optimize individual Python functions for maximum performance gains.
|
||||
```bash
|
||||
codeflash --file path.py --function my_function
|
||||
```
|
||||
</Card>
|
||||
|
||||
<Card title="Optimize Pull Requests" icon="code-pull-request" href="/optimizing-with-codeflash/codeflash-github-actions">
|
||||
Automatically find optimizations for Pull Requests with GitHub Actions integration.
|
||||
```bash
|
||||
codeflash init-actions
|
||||
```
|
||||
</Card>
|
||||
|
||||
<Card title="Optimize Workflows with Tracing" icon="route" href="/optimizing-with-codeflash/trace-and-optimize">
|
||||
End-to-end optimization of entire Python workflows with execution tracing.
|
||||
```bash
|
||||
codeflash optimize myscript.py
|
||||
```
|
||||
</Card>
|
||||
|
||||
<Card title="Optimize Your Entire Codebase" icon="globe" href="/optimizing-with-codeflash/codeflash-all">
|
||||
Automatically optimize all functions in your project with comprehensive analysis.
|
||||
```bash
|
||||
codeflash --all
|
||||
```
|
||||
</Card>
|
||||
|
||||
|
||||
</CardGroup>
|
||||
|
||||
### How does Codeflash verify correctness?
|
||||
|
||||
Codeflash verifies the correctness of the optimizations it finds by generating and running new regression tests, as well as any existing tests you may already have. Codeflash tries to ensure that your
|
||||
code behaves the same way before and after the optimization.
|
||||
This offers high confidence that the behavior of your code remains unchanged.
|
||||
|
||||
### Continuous Optimization
|
||||
|
||||
Because Codeflash is an automated process, the main way to use it is by installing it as a GitHub Action and having it optimize the new code on every pull request.
|
||||
When Codeflash finds an optimization, it will ask you to review it. It will write a detailed explanation of the changes it made, and include all relevant info like % speed increase and proofs of correctness.
|
||||
|
||||
This is a great way to ensure that your code, your team's code and your AI Agent's code are optimized for performance before it causes a performance regression. We call this *Continuous Optimization*.
|
||||
|
||||
## Questions or Feedback?
|
||||
|
||||
Your feedback will help us make codeflash better, faster. If you have any questions or feedback, use the Intercom button in the lower right, join our [Discord](https://www.codeflash.ai/discord), or drop us a note at [contact@codeflash.ai](mailto:contact@codeflash.ai) - we read every message!
|
||||
|
|
@ -1,17 +1,21 @@
|
|||
---
|
||||
sidebar_position: 5
|
||||
title: "Optimize Performance Benchmarks with every Pull Request"
|
||||
description: "Configure and use pytest-benchmark integration for performance-critical code optimization"
|
||||
icon: "chart-line"
|
||||
sidebarTitle: Setup Benchmarks to Optimize
|
||||
keywords: ["benchmarks", "CI", "pytest-benchmark", "performance testing", "github actions", "benchmark mode"]
|
||||
---
|
||||
# Using Benchmarks in CI
|
||||
<Info>
|
||||
**Performance-critical optimization** - Define benchmarks for your most important code sections and let Codeflash optimize and measure the real-world impact of every optimization on your performance metrics.
|
||||
</Info>
|
||||
|
||||
Codeflash is able to determine the impact of an optimization on predefined benchmarks, when used in benchmark mode.
|
||||
|
||||
Benchmark mode is an easy way for users to define workflows that are performance-critical and need to be optimized.
|
||||
For example, if a user has an important function that requires minimal latency, the user can define a benchmark for that function.
|
||||
Codeflash will then calculate the impact (if any) of any optimization on the performance of that function.
|
||||
Benchmark mode is an easy way to define workflows that are performance-critical and need to be optimized and run fast.
|
||||
Codeflash will run the benchmark to understand how the current code change in the Pull Request affects it.
|
||||
It will then try to optimize the new code for the benchmark and calculate the impact of any optimization on the speed of that benchmark.
|
||||
|
||||
## Using Codeflash in Benchmark Mode
|
||||
|
||||
1. **Create a benchmarks root**
|
||||
1. **Create a benchmarks root:**
|
||||
|
||||
Create a directory for benchmarks if it does not already exist.
|
||||
|
||||
|
|
@ -27,7 +31,7 @@ Codeflash will then calculate the impact (if any) of any optimization on the per
|
|||
formatter-cmds = ["disabled"]
|
||||
```
|
||||
|
||||
2. **Define your benchmarks**
|
||||
2. **Define your benchmarks:**
|
||||
|
||||
Currently, Codeflash only supports benchmarks written as pytest-benchmarks. Check out the [pytest-benchmark](https://pytest-benchmark.readthedocs.io/en/stable/index.html) documentation for more information on syntax.
|
||||
|
||||
|
|
@ -46,7 +50,7 @@ Codeflash will then calculate the impact (if any) of any optimization on the per
|
|||
The pytest-benchmark format is simply used as an interface. The plugin is actually not used - Codeflash will run these benchmarks with its own pytest plugin
|
||||
|
||||
|
||||
3. **Run Codeflash**
|
||||
3. **Run and Test Codeflash:**
|
||||
|
||||
Run Codeflash with the `--benchmark` flag. Note that benchmark mode cannot be used with `--all`.
|
||||
|
||||
|
|
@ -61,13 +65,15 @@ Codeflash will then calculate the impact (if any) of any optimization on the per
|
|||
```
|
||||
|
||||
|
||||
4. **Run Codeflash in CI**
|
||||
4. **Run Codeflash:**
|
||||
|
||||
Benchmark mode is best used together with Codeflash as a Github Action. This way, with every PR, you will know the impact of Codeflash's optimizations on your benchmarks.
|
||||
Benchmark mode is best used together with Codeflash as a GitHub Action. This way,
|
||||
Codeflash will trace through your benchmark and optimize the functions modified in your Pull Request to speed up the benchmark.
|
||||
It will also report the impact of Codeflash's optimizations on your benchmarks.
|
||||
|
||||
Use `codeflash init` for an easy way to set up Codeflash as a Github Action (with the option to enable benchmark mode).
|
||||
Use `codeflash init` for an easy way to set up Codeflash as a GitHub Action.
|
||||
|
||||
Otherwise, you can run the following command in your Codeflash GitHub Action:
|
||||
After that, you can add the `--benchmark` argument to codeflash to enable benchmark optimization.
|
||||
|
||||
```bash
|
||||
codeflash --benchmark
|
||||
|
|
@ -80,7 +86,7 @@ Codeflash will then calculate the impact (if any) of any optimization on the per
|
|||
1. Codeflash identifies benchmarks in the benchmarks-root directory.
|
||||
|
||||
|
||||
2. The benchmarks are run so that runtime statistics and information can be recorded.
|
||||
2. The benchmarks are run so that runtime statistics and inputs can be recorded.
|
||||
|
||||
|
||||
3. Replay tests are generated so the performance of optimization candidates on the exact inputs used in the benchmarks can be measured.
|
||||
|
|
@ -93,5 +99,3 @@ Codeflash will then calculate the impact (if any) of any optimization on the per
|
|||
|
||||
|
||||
Using Codeflash with benchmarks is a great way to find optimizations that really matter.
|
||||
|
||||
Codeflash is actively working on this feature and will be adding new capabilities in the near future!
|
||||
|
|
@ -1,5 +1,9 @@
|
|||
---
|
||||
sidebar_position: 2
|
||||
title: "Optimize Your Entire Codebase"
|
||||
description: "Automatically optimize all codepaths in your project with Codeflash's comprehensive analysis"
|
||||
icon: "database"
|
||||
sidebarTitle: "Optimize Entire Codebase"
|
||||
keywords: ["codebase optimization", "all functions", "batch optimization", "github app", "checkpoint", "recovery"]
|
||||
---
|
||||
|
||||
# Optimize your entire codebase
|
||||
|
|
@ -15,13 +19,34 @@ codeflash --all
|
|||
|
||||
This requires the Codeflash GitHub App to be installed in your repository.
|
||||
|
||||
This is a powerful feature that can help you optimize your entire codebase in one go.
|
||||
This is a powerful feature that can help you optimize your entire codebase in one go. It also discovers and runs any unit tests covering the function under optimization.
|
||||
|
||||
Since it runs on all the functions in your codebase, it can take some time to complete, please be patient.
|
||||
As this runs you will see Codeflash opening pull requests for each function it successfully optimizes.
|
||||
|
||||
If you only want to optimize a subdirectory you can run:
|
||||
```bash
|
||||
codeflash --all path/to/dir
|
||||
```
|
||||
|
||||
<Tip>
|
||||
If your project has a good number of unit tests, we can trace those to achieve higher quality results.
|
||||
The following approach is recommended instead:
|
||||
```bash
|
||||
codeflash optimize --trace-only -m pytest tests/ ; codeflash --all
|
||||
```
|
||||
This will run your test suite, trace all the code covered by your tests, ensuring higher correctness guarantees
|
||||
and better performance benchmarking, and help create optimizations for code where the LLMs struggle to generate and run tests.
|
||||
|
||||
Even though `codeflash --all` discovers any existing unit tests, it currently can only discover tests that directly call the
|
||||
function under optimization. Tracing all the tests helps ensure correctness for code that may be indirectly called by your tests.
|
||||
|
||||
</Tip>
|
||||
## Important considerations
|
||||
- **Dedicated Optimization Machine:** Optimizing the entire codebase may require considerable time—up to one day. It's recommended to allocate a dedicated machine specifically for this long-running optimization task.
|
||||
|
||||
- **Minimize Background Processes:** To achieve optimal results, avoid running other processes on the optimization machine. Additional processes can introduce noise into Codeflash's runtime measurements, reducing the quality of the optimizations. Although Codeflash tolerates some runtime fluctuations, minimizing noise ensures the highest optimization quality.
|
||||
|
||||
- **Checkpoint and Recovery:** Codeflash automatically creates checkpoints as it identifies optimizations. If the optimization process is interrupted or encounters issues, you can resume the process by re-running `codeflash --all`. The command will prompt you to continue from the most recent checkpoint.
|
||||
|
||||
|
||||
172
docs/optimizing-with-codeflash/codeflash-github-actions.mdx
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
---
|
||||
title: "Auto Optimize Pull Requests"
|
||||
description: "Automatically optimize new code in pull requests with Codeflash GitHub Actions workflow"
|
||||
icon: "github"
|
||||
---
|
||||
|
||||
|
||||
Optimizing new code in Pull Requests is the best way to ensure that all code you and your team ship is performant
|
||||
in the future. Automating optimization in the Pull Request stage is how most teams use Codeflash to
|
||||
continuously find optimizations for their new code.
|
||||
|
||||
To scan new code for performance optimizations, Codeflash uses a GitHub Action workflow which runs
|
||||
the Codeflash optimization logic on the new code in every pull request.
|
||||
If the action workflow finds an optimization, it communicates with the Codeflash GitHub
|
||||
App and asks it to suggest new changes to the pull request.
|
||||
|
||||
This is the most useful way of using Codeflash, where you set it up once and all your new code gets optimized.
|
||||
So setting this up is highly recommended.
|
||||
|
||||
## Pull Request Optimization 30 seconds demo
|
||||
<iframe width="640" height="400" src="https://www.youtube.com/embed/nqa-uewizkU?si=H1wb1RvPp-JqvKPh" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>
|
||||
|
||||
## Prerequisites
|
||||
|
||||
<Warning>
|
||||
**Before you begin, make sure you have:**
|
||||
|
||||
✅ A Codeflash API key from the [Codeflash Web App](https://app.codeflash.ai/)
|
||||
|
||||
✅ Completed [local installation](/getting-started/local-installation) with `codeflash init`
|
||||
|
||||
✅ A Python project with a configured `pyproject.toml` file
|
||||
</Warning>
|
||||
|
||||
## Setup Options
|
||||
|
||||
<Tabs>
|
||||
<Tab title="Automated Setup (Recommended)">
|
||||
<Steps>
|
||||
<Step title="Run the Setup Command">
|
||||
```bash
|
||||
codeflash init-actions
|
||||
```
|
||||
This command will automatically create the GitHub Actions workflow file and guide you through the setup process.
|
||||
|
||||
Alternatively, running `codeflash init` also offers to set up the GitHub Actions workflow.
|
||||
</Step>
|
||||
|
||||
<Step title="Customize and Test Your Setup">
|
||||
Open a new pull request to your GitHub project. You'll see:
|
||||
- ✅ A new Codeflash workflow running in GitHub Actions
|
||||
- 🤖 The codeflash-ai bot commenting with optimization suggestions (if any are found)
|
||||
|
||||
Ensure that your Python environment installation works correctly and codeflash is able to run.
|
||||
</Step>
|
||||
</Steps>
|
||||
</Tab>
|
||||
|
||||
<Tab title="Manual Setup">
|
||||
|
||||
<Steps>
|
||||
<Step title="Create Workflow File">
|
||||
Create `.github/workflows/codeflash-optimize.yaml` in your repository:
|
||||
|
||||
|
||||
```yaml title=".github/workflows/codeflash-optimize.yaml"
|
||||
name: Codeflash
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
optimize:
|
||||
name: Optimize new code in this PR
|
||||
if: ${{ github.actor != 'codeflash-ai[bot]' }}
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
# TODO: Replace the following with your project's Python installation method
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
# TODO: Replace the following with your project's dependency installation method
|
||||
- name: Install Project Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
# TODO: Replace the following with your project setup method
|
||||
pip install -r requirements.txt
|
||||
pip install codeflash
|
||||
- name: Run Codeflash to optimize code
|
||||
id: optimize_code
|
||||
run: |
|
||||
codeflash
|
||||
```
|
||||
|
||||
<Warning>
|
||||
**Replace the TODOs** in the workflow file above with your project's specific setup commands.
|
||||
</Warning>
|
||||
</Step>
|
||||
|
||||
<Step title="Choose Your Package Manager">
|
||||
Customize the dependency installation based on your Python package manager:
|
||||
|
||||
<CodeGroup>
|
||||
```yaml Poetry
|
||||
- name: Install Project Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install poetry
|
||||
poetry install --with dev
|
||||
- name: Run Codeflash to optimize code
|
||||
run: |
|
||||
poetry env use python
|
||||
poetry run codeflash
|
||||
```
|
||||
|
||||
```yaml uv
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
- run: uv sync --group=dev
|
||||
- name: Run Codeflash to optimize code
|
||||
run: uv run codeflash
|
||||
```
|
||||
|
||||
```yaml pip
|
||||
- name: Install Project Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install codeflash
|
||||
- name: Run Codeflash to optimize code
|
||||
run: codeflash
|
||||
```
|
||||
</CodeGroup>
|
||||
</Step>
|
||||
|
||||
<Step title="Add Repository Secret">
|
||||
1. Go to your GitHub repository settings
|
||||
2. Navigate to **Secrets and Variables** → **Actions**
|
||||
3. Click **New repository secret**
|
||||
4. Add:
|
||||
- **Name**: `CODEFLASH_API_KEY`
|
||||
- **Value**: Your API key from [app.codeflash.ai/app/apikeys](https://app.codeflash.ai/app/apikeys)
|
||||
|
||||
<Tip>
|
||||
**Security Note**: Never commit your API key directly to your code. Always use GitHub repository secrets.
|
||||
</Tip>
|
||||
</Step>
|
||||
</Steps>
|
||||
</Tab>
|
||||
</Tabs>
|
||||
## How the Pull Request Optimization Suggestion looks
|
||||
|
||||
Codeflash creates a new dependent Pull Request for you to review with the reported speedups, helpful explanation for the optimization
|
||||
and the proof of correctness. The pull request has the code change for you to review and accept.
|
||||
|
||||

|
||||
|
||||
|
||||
Sometimes it also makes an inline suggestion with the optimization.
|
||||
|
||||

|
||||
|
||||
We hope you enjoy the performance unlock the Pull Request optimization enables.
|
||||
|
|
@ -1,7 +1,10 @@
|
|||
---
|
||||
sidebar_position: 1
|
||||
title: "Optimize a Single Function"
|
||||
description: "Target and optimize individual Python functions for maximum performance gains"
|
||||
icon: "bullseye"
|
||||
sidebarTitle: "Optimize Single Function"
|
||||
keywords: ["function optimization", "single function", "class methods", "performance", "targeted optimization"]
|
||||
---
|
||||
# Optimize a function
|
||||
|
||||
Codeflash is essentially a function optimizer. When asked to optimize a function, Codeflash will analyze the function,
|
||||
any helper functions it calls, and the imports it uses. It will then generate multiple new versions of the function and its helper functions, that are
|
||||
61
docs/optimizing-with-codeflash/review-optimizations.mdx
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
---
|
||||
title: "Staging Review for Optimizations"
|
||||
description: "Review and manage optimizations before creating pull requests with staging mode"
|
||||
icon: "layer-group"
|
||||
sidebarTitle: "Review Optimizations"
|
||||
keywords:
|
||||
[
|
||||
"staging review",
|
||||
"optimization preview",
|
||||
"pro feature",
|
||||
"batch review",
|
||||
"pull request management",
|
||||
]
|
||||
---
|
||||
|
||||
# Staging Review for Optimizations
|
||||
|
||||
<Note>
|
||||
This is a Pro feature available exclusively to Codeflash Pro users. [Upgrade to
|
||||
Pro](https://app.codeflash.ai/billing) to access it.
|
||||
</Note>
|
||||
Staging Review allows you to preview and evaluate all optimizations before creating pull requests. This
|
||||
feature provides a centralized review interface where you can examine proposed changes and selectively
|
||||
create pull requests for approved optimizations.
|
||||
|
||||
## Benefits of Staging Review
|
||||
|
||||
- **Preview without PRs:** Review all optimization suggestions without cluttering your repository with multiple pull requests
|
||||
- **Batch review:** Examine all optimizations in one centralized location
|
||||
- **Selective PR creation:** Choose which optimizations to convert into pull requests
|
||||
- **Reduced noise:** Keep your repository's PR list clean while evaluating changes
|
||||
|
||||
## Using Staging Review
|
||||
|
||||
To optimize your codebase with staging review enabled, run:
|
||||
|
||||
```bash
|
||||
codeflash --all --staging-review
|
||||
```
|
||||
|
||||
This command will:
|
||||
|
||||
|
||||
1. Start to optimize your project, and if it finds any optimizations, it will save them to [Review Optimizations Page](https://app.codeflash.ai/review-optimizations) instead of creating PRs immediately
|
||||
|
||||

|
||||
|
||||
2. Provide you with a staging interface to review all proposed changes
|
||||
|
||||

|
||||
|
||||
## Managing Staged Optimizations
|
||||
|
||||
Once optimizations are staged, you can:
|
||||
- Review each optimization individually
|
||||
- Compare original and optimized code side-by-side
|
||||
- Edit the code of staged optimizations
|
||||
- Create pull requests for selected optimizations directly from the staging interface
|
||||
|
||||

|
||||
|
||||
|
|
@ -1,38 +1,44 @@
|
|||
---
|
||||
sidebar_position: 4
|
||||
title: "Trace & Optimize E2E Workflows"
|
||||
description: "End-to-end optimization of entire Python workflows with execution tracing"
|
||||
icon: "route"
|
||||
sidebarTitle: "Optimize E2E Workflows"
|
||||
keywords: ["tracing", "workflow optimization", "replay tests", "end-to-end", "script optimization", "context manager"]
|
||||
---
|
||||
# Optimize Workflows End-to-End
|
||||
|
||||
Codeflash supports optimizing an entire Python script end-to-end by tracing the script's execution and generating Replay Tests. Tracing follows the execution of a script, profiles it and captures inputs to all called functions, allowing them to be replayed during optimization. Codeflash uses these Replay Tests to optimize all functions called in the script, starting from the most important ones.
|
||||
Codeflash supports optimizing an entire Python script end-to-end by tracing the script's execution and generating Replay Tests.
|
||||
Tracing follows the execution of a script, profiles it and captures inputs to all functions it called, allowing them to be replayed during optimization.
|
||||
Codeflash uses these Replay Tests to optimize all functions called in the script, starting from the most important ones.
|
||||
|
||||
To optimize a script, `python myscript.py`, replace `python` with `codeflash optimize` and run the following command:
|
||||
To optimize a script, `python myscript.py`, simply replace `python` with `codeflash optimize` and run the following command:
|
||||
|
||||
```bash
|
||||
codeflash optimize myscript.py
|
||||
```
|
||||
|
||||
To optimize code called by pytest tests that you could normally run like `python -m pytest tests/`, use this command:
|
||||
You can also optimize code called by pytest tests that you could normally run like `python -m pytest tests/`, this provides for a good workload to optimize. Run this command:
|
||||
|
||||
```bash
|
||||
codeflash optimize -m pytest tests/
|
||||
```
|
||||
|
||||
This powerful command creates high-quality optimizations, making it ideal when you need to optimize a workflow or script. The initial tracing process can be slow, so try to limit your script's runtime to under 1 minute for best results. If your workflow is longer, consider tracing it into smaller sections by using the Codeflash tracer as a context manager (point 3 below).
|
||||
The powerful `codeflash optimize` command creates high-quality optimizations, making it ideal when you need to optimize a workflow or script. The initial tracing process can be slow, so try to limit your script's runtime to under 1 minute for best results. If your workflow is longer, consider tracing it into smaller sections by using the Codeflash tracer as a context manager (point 3 below).
|
||||
|
||||
The generated replay tests and the trace file are for the immediate optimization use, don't add them to git.
|
||||
|
||||
## Codeflash optimize demo (1 min)
|
||||
## Codeflash optimize 1 min demo
|
||||
|
||||
<iframe width="750" height="460" src="https://www.youtube.com/embed/_nwliGzRIug" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>
|
||||
<iframe width="640" height="400" src="https://www.youtube.com/embed/_nwliGzRIug" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>
|
||||
|
||||
|
||||
## What is the codeflash optimize command?
|
||||
|
||||
`codeflash optimize` tries to do everything that an expert engineer would do while optimizing a workflow. It profiles your code, traces the execution of your workflow and generates a set of test cases that are derived from how your code is actually run.
|
||||
Codeflash Tracer works by recording the inputs of your functions as they are called in your codebase. These inputs are then used to generate test cases that are representative of the real-world usage of your functions.
|
||||
Codeflash Tracer works by recording the inputs of your functions as they are called in your codebase, and generating
|
||||
regression tests with those inputs.
|
||||
We call these generated test cases "Replay Tests" because they replay the inputs that were recorded during the tracing phase.
|
||||
These replay tests are representative of the real-world usage of your functions.
|
||||
|
||||
Then, Codeflash Optimizer can use these replay tests to verify correctness and calculate accurate performance gains for the optimized functions.
|
||||
Using Replay Tests, Codeflash can verify that the optimized functions produce the same output as the original function and also measure the performance gains of the optimized function on the real-world inputs.
|
||||
This way you can be *sure* that the optimized function causes no changes of behavior for the traced workflow and also, that it is faster than the original function. To get more confidence on the correctness of the code, we also generate several LLM generated test cases and discover any existing unit cases you may have.
|
||||
|
||||
|
|
@ -54,15 +60,16 @@ Codeflash script optimizer can be used in three ways:
|
|||
codeflash optimize path/to/your/file.py --your_options
|
||||
```
|
||||
|
||||
The above command should suffice in most situations. You can add a argument like `codeflash optimize -o trace_file_path.trace` if you want to customize the trace file location. Otherwise, it defaults to `codeflash.trace` in the current working directory.
|
||||
The above command should suffice in most situations.
|
||||
To customize the trace file location you can specify it like `codeflash optimize -o trace_file_path.trace`. Otherwise, it defaults to `codeflash.trace` in the current working directory.
|
||||
|
||||
2. **Trace and optimize as two separate steps**
|
||||
|
||||
If you want more control over the tracing and optimization process. You can trace first and then optimize with the replay tests later. Each replay test is associated with a trace file.
|
||||
|
||||
To first create just the trace file, run
|
||||
To create just the trace file first, run
|
||||
|
||||
```python
|
||||
```bash
|
||||
codeflash optimize -o trace_file.trace --trace-only path/to/your/file.py --your_options
|
||||
```
|
||||
|
||||
|
|
@ -76,7 +83,7 @@ Codeflash script optimizer can be used in three ways:
|
|||
- `--tracer-timeout`: The maximum time in seconds to trace the entire workflow. Default is indefinite. This is useful while tracing really long workflows.
|
||||
3. **As a Context Manager -**
|
||||
|
||||
To trace only very specific sections of your codeflash, You can also use the Codeflash Tracer as a context manager.
|
||||
To trace only specific sections of your code, You can also use the Codeflash Tracer as a context manager.
|
||||
You can wrap the code you want to trace in a `with` statement as follows -
|
||||
|
||||
```python
|
||||
|
|
@ -86,7 +93,7 @@ Codeflash script optimizer can be used in three ways:
|
|||
model.predict() # Your code here
|
||||
```
|
||||
|
||||
This is much faster than tracing the whole script. Sometimes, if tracing the whole script fails, then the Context Manager can also be used to trace the code sections.
|
||||
This is much faster than tracing the whole script. It can also help if tracing the whole script fails.
|
||||
|
||||
After this finishes, you can optimize using the generated replay tests.
|
||||
|
||||
|
|
@ -94,7 +101,7 @@ Codeflash script optimizer can be used in three ways:
|
|||
codeflash --replay-test /path/to/test_replay_test_0.py
|
||||
```
|
||||
|
||||
More Options for the Tracer:
|
||||
More Options for the Tracer Context Manager:
|
||||
|
||||
- `disable`: If set to `True`, the tracer will not trace the code. Default is `False`.
|
||||
- `max_function_count`: The maximum number of times to trace a single function. More calls to a function will not be traced. Default is 100.
|
||||
|
|
@ -1,47 +0,0 @@
|
|||
{
|
||||
"name": "cf-docs",
|
||||
"version": "0.0.0",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"docusaurus": "docusaurus",
|
||||
"start": "docusaurus start",
|
||||
"build": "docusaurus build",
|
||||
"swizzle": "docusaurus swizzle",
|
||||
"deploy": "docusaurus deploy",
|
||||
"clear": "docusaurus clear",
|
||||
"serve": "docusaurus serve",
|
||||
"write-translations": "docusaurus write-translations",
|
||||
"write-heading-ids": "docusaurus write-heading-ids",
|
||||
"typecheck": "tsc"
|
||||
},
|
||||
"dependencies": {
|
||||
"@docusaurus/core": "^3.6.3",
|
||||
"@docusaurus/preset-classic": "^3.6.3",
|
||||
"@mdx-js/react": "^3.0.0",
|
||||
"clsx": "^2.0.0",
|
||||
"prism-react-renderer": "^2.3.0",
|
||||
"react": "^18.0.0",
|
||||
"react-dom": "^18.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@docusaurus/module-type-aliases": "^3.6.3",
|
||||
"@docusaurus/tsconfig": "^3.6.3",
|
||||
"@docusaurus/types": "^3.6.3",
|
||||
"typescript": "~5.2.2"
|
||||
},
|
||||
"browserslist": {
|
||||
"production": [
|
||||
">0.5%",
|
||||
"not dead",
|
||||
"not op_mini all"
|
||||
],
|
||||
"development": [
|
||||
"last 3 chrome version",
|
||||
"last 3 firefox version",
|
||||
"last 5 safari version"
|
||||
]
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18.0"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,31 +0,0 @@
|
|||
import type {SidebarsConfig} from '@docusaurus/plugin-content-docs';
|
||||
|
||||
/**
|
||||
* Creating a sidebar enables you to:
|
||||
- create an ordered group of docs
|
||||
- render a sidebar for each doc of that group
|
||||
- provide next/previous navigation
|
||||
|
||||
The sidebars can be generated from the filesystem, or explicitly defined here.
|
||||
|
||||
Create as many sidebars as you want.
|
||||
*/
|
||||
const sidebars: SidebarsConfig = {
|
||||
// By default, Docusaurus generates a sidebar from the docs folder structure
|
||||
tutorialSidebar: [{type: 'autogenerated', dirName: '.'}],
|
||||
|
||||
// But you can create a sidebar manually
|
||||
/*
|
||||
tutorialSidebar: [
|
||||
'intro',
|
||||
'hello',
|
||||
{
|
||||
type: 'category',
|
||||
label: 'Tutorial',
|
||||
items: ['tutorial-basics/create-a-document'],
|
||||
},
|
||||
],
|
||||
*/
|
||||
};
|
||||
|
||||
export default sidebars;
|
||||
|
|
@ -1,70 +0,0 @@
|
|||
import clsx from 'clsx';
|
||||
import Heading from '@theme/Heading';
|
||||
import styles from './styles.module.css';
|
||||
|
||||
type FeatureItem = {
|
||||
title: string;
|
||||
Svg: React.ComponentType<React.ComponentProps<'svg'>>;
|
||||
description: JSX.Element;
|
||||
};
|
||||
|
||||
const FeatureList: FeatureItem[] = [
|
||||
{
|
||||
title: 'Easy to Use',
|
||||
Svg: require('@site/static/img/undraw_docusaurus_mountain.svg').default,
|
||||
description: (
|
||||
<>
|
||||
Docusaurus was designed from the ground up to be easily installed and
|
||||
used to get your website up and running quickly.
|
||||
</>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: 'Focus on What Matters',
|
||||
Svg: require('@site/static/img/undraw_docusaurus_tree.svg').default,
|
||||
description: (
|
||||
<>
|
||||
Docusaurus lets you focus on your docs, and we'll do the chores. Go
|
||||
ahead and move your docs into the <code>docs</code> directory.
|
||||
</>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: 'Powered by React',
|
||||
Svg: require('@site/static/img/undraw_docusaurus_react.svg').default,
|
||||
description: (
|
||||
<>
|
||||
Extend or customize your website layout by reusing React. Docusaurus can
|
||||
be extended while reusing the same header and footer.
|
||||
</>
|
||||
),
|
||||
},
|
||||
];
|
||||
|
||||
function Feature({title, Svg, description}: FeatureItem) {
|
||||
return (
|
||||
<div className={clsx('col col--4')}>
|
||||
<div className="text--center">
|
||||
<Svg className={styles.featureSvg} role="img" />
|
||||
</div>
|
||||
<div className="text--center padding-horiz--md">
|
||||
<Heading as="h3">{title}</Heading>
|
||||
<p>{description}</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default function HomepageFeatures(): JSX.Element {
|
||||
return (
|
||||
<section className={styles.features}>
|
||||
<div className="container">
|
||||
<div className="row">
|
||||
{FeatureList.map((props, idx) => (
|
||||
<Feature key={idx} {...props} />
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
);
|
||||
}
|
||||