Add optimize endpoint: context, pipeline, router, prompt templates
Faithful port of the Python optimization pipeline from Django aiservice: - schemas.py: Pydantic request/response models (OptimizeRequest, OptimizeResponse) - _markdown.py: markdown code block extraction, splitting, grouping - _context.py: BaseOptimizerContext with Single/Multi variants for prompt assembly, LLM response extraction, and postprocessing - _pipeline.py: parallel LLM orchestration with model distribution (GPT-5-mini + Claude Sonnet 4.5), diversity via line profiler toggling - _router.py: POST /ai/optimize with auth, rate limiting, usage tracking - 11 prompt templates copied verbatim from Django reference - LLM client wired into app lifespan
This commit is contained in:
parent
3e62f502e7
commit
6c04324e25
19 changed files with 2312 additions and 0 deletions
|
|
@ -12,6 +12,7 @@ from sentry_sdk.integrations.fastapi import FastApiIntegration
|
|||
from codeflash_api._config import settings
|
||||
from codeflash_api.db._engine import create_pool
|
||||
from codeflash_api.db._queries import Queries
|
||||
from codeflash_api.llm._client import LLMClient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator
|
||||
|
|
@ -28,6 +29,7 @@ async def _lifespan(app: FastAPI) -> AsyncIterator[dict[str, object]]:
|
|||
|
||||
app.state.queries = queries
|
||||
app.state.rate_limit_cache = {}
|
||||
app.state.llm_client = LLMClient()
|
||||
|
||||
state: dict[str, object] = {}
|
||||
yield state
|
||||
|
|
@ -65,6 +67,10 @@ def create_app() -> FastAPI:
|
|||
|
||||
|
||||
def _register_routes(app: FastAPI) -> None:
|
||||
from codeflash_api.optimize._router import router as optimize_router
|
||||
|
||||
app.include_router(optimize_router)
|
||||
|
||||
@app.get("/healthcheck")
|
||||
async def healthcheck() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,135 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
MARKDOWN_CODE_BLOCK_PATTERN = re.compile(
|
||||
r"```python(?::[^\n]*)?\n(.*?)```", re.DOTALL
|
||||
)
|
||||
|
||||
FIRST_CODE_BLOCK_PATTERN = re.compile(
|
||||
r"^```python(?::[^\n]*)?\s*\n(.*)\n```[ \t]*$",
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
|
||||
FIRST_CODE_BLOCK_FALLBACK_PATTERN = re.compile(
|
||||
r"^```python(?::[^\n]*)?\s*\n(.*)",
|
||||
re.MULTILINE | re.DOTALL,
|
||||
)
|
||||
|
||||
_PATHOLOGICAL_PATTERN = re.compile(r"(\\['\"]{10,})")
|
||||
|
||||
|
||||
def truncate_pathological_output(code: str) -> str:
    """Cut *code* short at the first pathological run of escaped quotes.

    LLMs occasionally emit degenerate output such as a backslash followed
    by dozens of quote characters; everything from that point onward is
    treated as noise and dropped.
    """
    degenerate = re.compile(r"(\\['\"]{10,})").search(code)
    if degenerate is None:
        return code
    return code[: degenerate.start()].rstrip()
|
||||
|
||||
|
||||
def extract_all_code_from_markdown(markdown: str) -> str:
    """Concatenate the contents of every ```python fenced block in *markdown*.

    The language tag may carry an optional ``:path`` suffix.  Blocks are
    joined with a blank line between them; an input with no fenced blocks
    yields the empty string.
    """
    fence = re.compile(r"```python(?::[^\n]*)?\n(.*?)```", re.DOTALL)
    blocks = fence.findall(markdown)
    return "\n\n".join(blocks)
|
||||
|
||||
|
||||
def extract_code_block(markdown: str) -> str | None:
|
||||
"""
|
||||
Extract the first code block from markdown.
|
||||
"""
|
||||
if match := FIRST_CODE_BLOCK_PATTERN.search(markdown):
|
||||
return truncate_pathological_output(match.group(1))
|
||||
if match := FIRST_CODE_BLOCK_FALLBACK_PATTERN.search(markdown):
|
||||
return truncate_pathological_output(match.group(1))
|
||||
return None
|
||||
|
||||
|
||||
def split_markdown_code(
    markdown: str, language: str = "python"
) -> dict[str, str]:
    """Map file paths to code for every ```<lang>:<path> block in *markdown*.

    ``javascript``/``js`` and ``typescript``/``ts`` are treated as aliases
    of each other; any other language tag is matched literally (escaped).
    When the same path appears more than once, the first occurrence wins.
    """
    aliases = {
        "javascript": r"(?:javascript|js)",
        "js": r"(?:javascript|js)",
        "typescript": r"(?:typescript|ts)",
        "ts": r"(?:typescript|ts)",
    }
    lang_pattern = aliases.get(language, re.escape(language))
    block_re = re.compile(
        rf"```{lang_pattern}:([^\n]+)\n(.*?)\n```", re.DOTALL
    )
    files: dict[str, str] = {}
    for raw_path, code in block_re.findall(markdown):
        # setdefault keeps the first block seen for a given path.
        files.setdefault(raw_path.strip(), code)
    return files
|
||||
|
||||
|
||||
def extract_code_block_with_context(
|
||||
text: str, language: str = "python"
|
||||
) -> tuple[str, str, str] | None:
|
||||
"""
|
||||
Extract a code block and its surrounding context.
|
||||
"""
|
||||
pattern_with_path = (
|
||||
rf"(.*?)```{language}:[^\n]+(?:\n|\\n)(.*?)```(.*)"
|
||||
)
|
||||
if match := re.match(
|
||||
pattern_with_path, text, re.DOTALL | re.MULTILINE
|
||||
):
|
||||
return (
|
||||
match.group(1).strip(),
|
||||
match.group(2),
|
||||
match.group(3).strip(),
|
||||
)
|
||||
|
||||
pattern = (
|
||||
rf"(.*?)```{language}(?::[^\n]*)?(?:\n|\\n)(.*?)```(.*)"
|
||||
)
|
||||
if match := re.match(pattern, text, re.DOTALL | re.MULTILINE):
|
||||
return (
|
||||
match.group(1).strip(),
|
||||
match.group(2),
|
||||
match.group(3).strip(),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def wrap_code_in_markdown(
    code: str, language: str = "python"
) -> str:
    """Enclose *code* in a fenced markdown block tagged with *language*."""
    opening_fence = f"```{language}"
    return "\n".join((opening_fence, code, "```"))
|
||||
|
||||
|
||||
def group_code(
    file_to_code: dict[str, str], language: str = "python"
) -> str:
    """Render a ``{path: code}`` mapping as consecutive fenced blocks.

    Each block opens with ```<language>:<path>; the code is newline-
    terminated before the closing fence.  Blocks are emitted in the
    mapping's iteration order, joined by single newlines.
    """
    blocks: list[str] = []
    for file_path, code in file_to_code.items():
        # Ensure the closing fence lands on its own line.
        if not code.endswith("\n"):
            code += "\n"
        blocks.append(f"```{language}:{file_path}\n{code}```")
    return "\n".join(blocks)
|
||||
|
||||
|
||||
def is_multi_context(code: str, language: str = "python") -> bool:
    """Report whether *code* is in multi-file markdown format.

    Multi-file payloads start with a ```<language>:<path> fence, while a
    plain ```<language> fence (no path) indicates single-file content.
    *language* defaults to ``python`` for backward compatibility, but lets
    callers reuse the check for the other language tags this module
    already handles (e.g. ``js``, ``ts``).
    """
    return code.strip().startswith(f"```{language}:")
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
You are a professional computer programmer who specializes in writing high-performance **asynchronous** Python code. Your goal is to optimize the runtime and memory efficiency of the provided **async** code through safe and meaningful rewrites that would pass senior-level code review.
|
||||
|
||||
**CRITICAL: ASYNC CODE REQUIREMENTS**
|
||||
- The code contains **async functions** that must remain async
|
||||
- ALL async functions must maintain their `async def` signature
|
||||
- ALL `await` expressions must be preserved where they exist
|
||||
- Do NOT convert async functions to synchronous functions
|
||||
- Do NOT remove `await` keywords unless replacing with functionally equivalent async operations
|
||||
- Preserve concurrency patterns and async context manager usage
|
||||
- Maintain proper async/await flow and error handling in async contexts
|
||||
|
||||
**Behavioral Preservation (CRITICAL)**
|
||||
- Do NOT rename functions or change their signatures.
|
||||
- You MUST NOT change the behavior, return values, side effects, printed/logged output, or raised exceptions - they MUST remain exactly the same.
|
||||
- Do NOT mutate inputs in a different way than the original implementation.
|
||||
- The same exception types should be raised in the same circumstances.
|
||||
- Preserve existing type annotations - all function parameters, return types, and variable annotations must be preserved exactly as written.
|
||||
- **Preserve the original code style**: Keep existing variable names unless the logic fundamentally changes
|
||||
- Preserve ALL existing comments exactly as written, unless the corresponding code logic is changed or the comment becomes factually incorrect
|
||||
- Avoid excessive inline comments - only add new comments for significant or non-obvious logic changes
|
||||
|
||||
**Async-Specific Optimization Focus**
|
||||
- Optimize async patterns such as concurrent execution with `asyncio.gather()` or `asyncio.create_task()`
|
||||
- Consider using `asyncio.as_completed()` for better performance when appropriate
|
||||
- Identify and replace blocking operations with async equivalents (e.g., `time.sleep()` → `asyncio.sleep()`, sync file I/O → `aiofiles`, blocking network calls → async libraries)
|
||||
- Optimize async context managers and async iterators
|
||||
- Improve async I/O operations and resource management
|
||||
- Consider async comprehensions where they provide performance benefits
|
||||
- Use `asyncio.to_thread()` for CPU-intensive tasks that would block the event loop
|
||||
- Maintain proper async exception handling and cleanup
|
||||
|
||||
**Code Style & Structure**
|
||||
- Do NOT replace walrus operators (`:=`) for optimization purposes.
|
||||
- Keep `assert` statements as-is - do NOT convert them to `if/raise AssertionError` patterns, it doesn't improve the performance.
|
||||
- **DO NOT convert `isinstance()` checks to `type()` checks**. `isinstance()` correctly handles inheritance and subclasses, while `type()` checks are incorrect for subclass instances and represent a micro-optimization that should be avoided.
|
||||
- You may write new async helper functions that do not already exist in the codebase.
|
||||
- Avoid purely stylistic changes unless they result in noticeable performance improvements
|
||||
- Ensure all new async code follows proper async patterns and conventions
|
||||
|
||||
**Optimization Focus**
|
||||
- Create production-ready async code that professional programmers would merge without further edits
|
||||
- Prioritize changes that provide measurable runtime or memory efficiency gains in async contexts
|
||||
- Consider async-specific performance patterns like batching operations or reducing context switching
|
||||
|
||||
**Code Quality Standards**
|
||||
- Ensure all async optimizations are safe and would pass senior-level code review
|
||||
- Maintain code readability and maintainability alongside performance improvements
|
||||
- Verify that async operations are properly awaited and handled
|
||||
|
||||
**Response Format (REQUIRED)**
|
||||
- ALWAYS start your response with a brief explanation (2-4 sentences) of what optimization you made and why it improves performance
|
||||
- Then provide the optimized code in a markdown code block
|
||||
- Example format:
|
||||
```
|
||||
**Optimization Explanation:**
|
||||
[Your explanation here describing the optimization technique and expected performance improvement]
|
||||
|
||||
```python:filename.py
|
||||
[optimized code]
|
||||
```
|
||||
```
|
||||
|
||||
The current Python version is {python_version_str}
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
Rewrite this **asynchronous** Python program to run faster while preserving all async behavior.
|
||||
|
||||
**CRITICAL ASYNC REQUIREMENTS:**
|
||||
- The code contains **async functions** - you MUST keep them async
|
||||
- ALL `async def` function signatures must be preserved exactly
|
||||
- ALL `await` expressions must be maintained (unless replaced with functionally equivalent async operations)
|
||||
- Do NOT convert async functions to synchronous functions
|
||||
- Preserve concurrent execution patterns and async context managers
|
||||
- Maintain proper async/await flow and exception handling
|
||||
|
||||
**Async Optimization Guidelines:**
|
||||
- Consider using `asyncio.gather()` for concurrent execution when beneficial
|
||||
- Replace blocking operations with async equivalents (e.g., `time.sleep()` with `asyncio.sleep()`, sync I/O with async libraries)
|
||||
- Optimize async I/O operations and batching where appropriate
|
||||
- Use `asyncio.to_thread()` for CPU-intensive tasks to avoid blocking the event loop
|
||||
- Optimize async context managers and async iterators
|
||||
- Maintain async exception handling and resource cleanup
|
||||
|
||||
{source_code}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
|
||||
|
||||
Here are some READ ONLY Python module code snippets for understanding internal codebase dependencies of the Python program you will be optimizing, provided for your reference.
|
||||
Do not edit these Python modules, treat them as *read-only*.
|
||||
{dependency_code}
|
||||
|
|
@ -0,0 +1 @@
|
|||
When making modifications to the `__init__` function, do not change the values and behavior of any existing instance attributes. You may only add new attributes.
|
||||
|
|
@ -0,0 +1,146 @@
|
|||
# JIT Compilation Optimization Guidelines
|
||||
|
||||
## Decision Logic (Framework Router)
|
||||
Before rewriting, analyze the imports and data structures to select the correct backend:
|
||||
- IF the code uses `torch.nn` or `torch.Tensor` → Use **PyTorch**.
|
||||
- IF the code uses `tensorflow`, `keras`, or `tf.Tensor` → Use **TensorFlow**.
|
||||
- IF the code uses `jax.numpy` or `jax.grad` → Use **JAX**.
|
||||
- IF the code uses standard `numpy` arrays and Python loops → Use **Numba**.
|
||||
- IF the code relies heavily on un-compilable Python features (pandas, dynamic dicts, strings) → **Return Original Code**.
|
||||
|
||||
## When NOT to Use JIT Compilation
|
||||
|
||||
Do NOT generate JIT-compiled candidates in the following scenarios. Instead, focus on other optimization strategies:
|
||||
|
||||
### 1. I/O-Bound Functions
|
||||
Functions that primarily perform file operations, network requests, database queries, or other I/O operations will not benefit from JIT compilation. The bottleneck is external latency, not CPU computation.
|
||||
|
||||
### 2. Functions Called Infrequently
|
||||
JIT compilation has an upfront compilation cost. If a function is only called once or very rarely, the compilation overhead will exceed any runtime savings. JIT is best for hot paths called many times.
|
||||
|
||||
### 3. Heavy Use of Unsupported Python Features
|
||||
Numba's nopython mode does not support:
|
||||
- Dictionary comprehensions or complex dict operations
|
||||
- String manipulation (except basic operations)
|
||||
- List comprehensions with complex logic
|
||||
- Most Python standard library functions
|
||||
- Custom Python classes and objects
|
||||
- Exception handling with complex try/except blocks
|
||||
- Generators and iterators (limited support)
|
||||
- Recursion with varying depths (can cause issues)
|
||||
|
||||
If the function relies heavily on these features, JIT compilation will either fail or fall back to slow object mode.
|
||||
|
||||
### 4. Functions with Dynamic or Heterogeneous Types
|
||||
JIT compilers optimize for specific type signatures. Functions that receive different types on each call will trigger recompilation, negating performance benefits.
|
||||
|
||||
### 5. Small/Trivial Functions
|
||||
Functions with minimal computation (e.g., simple getters, one-line calculations, basic conditionals) have negligible runtime. JIT overhead is not justified.
|
||||
|
||||
### 6. Code Using Incompatible Libraries
|
||||
Many Python libraries are not JIT-compatible:
|
||||
- pandas operations (use vectorized pandas methods instead)
|
||||
- Most third-party libraries
|
||||
- Complex numpy operations with object dtypes
|
||||
- Symbolic math libraries (sympy, etc.)
|
||||
|
||||
### 7. Functions with Heavy Object Creation
|
||||
Code that creates many Python objects, uses class instances extensively, or relies on Python's dynamic nature will not JIT well.
|
||||
|
||||
## JIT Output Format (CRITICAL)
|
||||
|
||||
When you determine that JIT compilation is viable for the given code:
|
||||
- Apply the JIT decorator **directly** to the output function (e.g., `@numba.njit`, `@torch.compile`, `@tf.function`, `@jax.jit`).
|
||||
- Add the necessary import (e.g., `import numba`) at the top of the code.
|
||||
- Do **NOT** create conditional fallback paths. Never wrap JIT usage in `try/except ImportError`, `if HAS_NUMBA`, or any similar if/else branching that falls back to a non-JIT version.
|
||||
- Do **NOT** create a separate "fast" helper function alongside the original. The output must be a single, clean function with the JIT decorator applied.
|
||||
- If JIT is not viable for this code, optimize it using other strategies without any JIT decorators — do not include a JIT path at all.
|
||||
|
||||
## Guidelines for Numba (`@njit`)
|
||||
Use Numba when the code:
|
||||
- Performs numerical computations with NumPy arrays
|
||||
- Contains loops over array elements
|
||||
- Uses basic NumPy operations (arithmetic, indexing, slicing)
|
||||
- Does not use unsupported Python features (classes, dictionaries with non-scalar keys, etc.)
|
||||
|
||||
### Numba Best Practices:
|
||||
- Add `@numba.njit` or `@numba.njit(cache=True)` decorator
|
||||
- Use `numba.prange` instead of `range` for parallel loops when appropriate
|
||||
- Avoid Python objects inside JIT functions (use primitive types and arrays)
|
||||
- Pre-allocate output arrays instead of using list.append()
|
||||
- Do not use `fastmath=True`
|
||||
- Ensure all array dtypes are consistent
|
||||
|
||||
### Numba Limitations to Avoid:
|
||||
- No dictionary comprehensions or set operations
|
||||
- No string operations
|
||||
- No class instances (use structured arrays instead)
|
||||
- Limited support for keyword arguments in nested calls
|
||||
|
||||
## Guidelines for PyTorch (`torch.compile`)
|
||||
Use torch.compile when the code:
|
||||
- Defines or uses PyTorch neural network modules (nn.Module)
|
||||
- Performs tensor operations with torch tensors
|
||||
- Contains forward passes or training loops
|
||||
|
||||
### torch.compile Best Practices:
|
||||
- Apply `@torch.compile()` to the model or specific functions
|
||||
- Use `mode="reduce-overhead"` for small models with many calls
|
||||
- Use `mode="max-autotune"` for maximum performance when compile time is acceptable
|
||||
- Consider `fullgraph=True` to ensure the entire function is compiled
|
||||
- Use `dynamic=True` if input shapes vary
|
||||
|
||||
### torch.compile Limitations to Avoid:
|
||||
- No Python print statements inside compiled functions (causes graph breaks)
|
||||
- Avoid data-dependent control flow (if statements based on tensor values)
|
||||
- No unsupported Python built-ins (input(), eval(), exec())
|
||||
- Avoid in-place operations on input tensors when using certain backends
|
||||
- No operations with data-dependent output shapes (use `dynamic=True` if needed)
|
||||
- Avoid calling non-PyTorch functions that can't be traced (use `torch.compiler.disable` to skip them)
|
||||
- No Python generators or iterators over tensors
|
||||
- Avoid global variable modifications inside compiled functions
|
||||
- Limited support for custom autograd functions (may cause graph breaks)
|
||||
- No third-party libraries that aren't torch-compatible inside the compiled region
|
||||
|
||||
## Guidelines for TensorFlow (`tf.function`)
|
||||
Use tf.function when the code:
|
||||
- Uses TensorFlow tensors and operations
|
||||
- Defines or uses Keras models (tf.keras.Model)
|
||||
- Contains training loops or inference pipelines
|
||||
- Performs tensor computations with tf.* operations
|
||||
|
||||
### tf.function Best Practices:
|
||||
- Add `@tf.function` decorator to functions performing tensor operations
|
||||
- Use `jit_compile=True` for XLA compilation: `@tf.function(jit_compile=True)`
|
||||
- Specify `input_signature` for fixed input shapes to avoid retracing
|
||||
- Use `experimental_relax_shapes=True` if input shapes vary slightly
|
||||
- Avoid Python side effects inside tf.function (print, list.append, etc.)
|
||||
- Use tf.print() instead of print() for debugging inside traced functions
|
||||
|
||||
### tf.function Limitations to Avoid:
|
||||
- No Python control flow that depends on tensor values (use tf.cond, tf.while_loop instead)
|
||||
- Avoid creating variables inside tf.function (create them outside or in __init__)
|
||||
- No unsupported Python operations (file I/O, random.random(), etc.)
|
||||
- Minimize Python object creation inside the function
|
||||
|
||||
## Guidelines for JAX (`jax.jit`)
|
||||
Use jax.jit when the code:
|
||||
- Uses JAX arrays (jax.numpy operations)
|
||||
- Performs functional numerical computations
|
||||
- Contains pure functions without side effects
|
||||
- Uses JAX transformations (grad, vmap, pmap)
|
||||
|
||||
### jax.jit Best Practices:
|
||||
- Add `@jax.jit` decorator to pure functions
|
||||
- Use `static_argnums` for arguments that should not be traced (e.g., axis parameters, shapes)
|
||||
- Use `static_argnames` for keyword arguments that should be static
|
||||
- Prefer `jax.numpy` over `numpy` for array operations inside jitted functions
|
||||
- Use `donate_argnums` to allow JAX to reuse input buffers for outputs
|
||||
- Combine with other JAX transforms: `jax.jit(jax.vmap(fn))` for batched operations
|
||||
|
||||
### jax.jit Limitations to Avoid:
|
||||
- No in-place mutations (JAX arrays are immutable)
|
||||
- No Python side effects (print, file I/O, global variable modifications)
|
||||
- No data-dependent control flow with dynamic shapes (use jax.lax.cond, jax.lax.fori_loop instead)
|
||||
- Avoid Python loops over traced values (use jax.lax.scan or jax.vmap)
|
||||
- No NumPy arrays inside jitted functions (convert to JAX arrays first)
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
Here are the results of the line profiling of the Python program you will be optimizing, provided for your reference.
|
||||
{line_profiler_results}
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
### Code Input format
|
||||
You will receive Markdown code blocks containing multiple Python file segments. Each file is represented as a separate Markdown code block with the format ```python:<path_of_the_file>.
|
||||
|
||||
Example:
|
||||
```python:main/app1.py
|
||||
def say_hello_from_app1():
|
||||
print("Hello")
|
||||
```
|
||||
```python:main/app2.py
|
||||
def say_hello_from_app2():
|
||||
print("Hello")
|
||||
```
|
||||
|
||||
**IMPORTANT RULES: (CRITICAL)**
|
||||
- Each Markdown code block represents a **separate Python module**
|
||||
- The **first file** contains the function to optimize. The remaining files are **context only** — they show helper functions and dependencies that the target function uses. Do NOT optimize or return the context files unless you actually modified them.
|
||||
- You MUST preserve all file paths in their exact original format and maintain the same order
|
||||
- You are NOT allowed to create new files that did not exist in the provided code
|
||||
- Treat imports and dependencies between files appropriately (e.g., if app2.py imports from app1.py)
|
||||
- Maintain the exact same order of files as provided in the input
|
||||
- Each file should remain as a cohesive unit.
|
||||
- Only return files you actually modified — you do NOT need to return unchanged context files
|
||||
|
||||
### Explanation Style:
|
||||
Keep explanations **developer-focused and concise**. Focus on:
|
||||
- **What** specific optimizations were applied.
|
||||
- **Performance impact** when significant (e.g., "Reduced time complexity from O(n²) to O(n)")
|
||||
- **Key changes** that affect behavior or dependencies
|
||||
- Avoid mentioning obvious preservation details (file structure, imports, signatures) unless they were specifically modified
|
||||
|
||||
Your response **MUST** contain:
|
||||
1. A **brief, technical explanation** of optimizations applied
|
||||
2. **One or more Markdown code block(s)** for the files you modified, preserving their original file paths
|
||||
|
||||
**Any deviation from this format is incorrect. {EXPLANATION_THEN_CODE}**
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
|
||||
The current measured runtime of this function is {baseline_runtime_ns} nanoseconds ({baseline_runtime_human}) over {loop_count} benchmark loops. Focus your optimization on changes that will meaningfully reduce this runtime.
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
You are a professional computer programmer who specializes in writing high-performance Python code. Your goal is to optimize the runtime and memory efficiency of the provided code through safe and meaningful rewrites that would pass senior-level code review.
|
||||
|
||||
**Behavioral Preservation (CRITICAL)**
|
||||
- Do NOT rename functions or change their signatures.
|
||||
- You MUST NOT change the behavior, return values, side effects, printed/logged output, or raised exceptions - they MUST remain exactly the same.
|
||||
- Do NOT mutate inputs in a different way than the original implementation.
|
||||
- The same exception types should be raised in the same circumstances.
|
||||
- Preserve existing type annotations - all function parameters, return types, and variable annotations must be preserved exactly as written.
|
||||
- **Preserve the original code style**: Keep existing variable names unless the logic fundamentally changes
|
||||
- Preserve ALL existing comments exactly as written, unless the corresponding code logic is changed or the comment becomes factually incorrect
|
||||
- Avoid excessive inline comments - only add new comments for significant or non-obvious logic changes
|
||||
{critical_instructions}
|
||||
|
||||
**Code Style & Structure**
|
||||
- Do NOT replace walrus operators (`:=`) for optimization purposes.
|
||||
- DO NOT introduce attribute lookup optimizations. The performance improvements are minimal and come at a substantial cost to readability.
|
||||
- Keep `assert` statements as-is - do NOT convert them to `if/raise AssertionError` patterns, it doesn't improve the performance.
|
||||
- **DO NOT convert `isinstance()` checks to `type()` checks**. `isinstance()` correctly handles inheritance and subclasses, while `type()` checks are incorrect for subclass instances and represent a micro-optimization that should be avoided.
|
||||
- You may write new helper functions that do not already exist in the codebase.
|
||||
- Avoid purely stylistic changes unless they result in noticeable performance improvements
|
||||
|
||||
**Optimization Focus**
|
||||
- Create production-ready code that professional programmers would merge without further edits
|
||||
- Prioritize changes that provide measurable runtime or memory efficiency gains
|
||||
|
||||
**Code Quality Standards**
|
||||
- Ensure all optimizations are safe and would pass senior-level code review
|
||||
- Maintain code readability and maintainability alongside performance improvements
|
||||
|
||||
**Response Format (REQUIRED)**
|
||||
- ALWAYS start your response with a brief explanation (2-4 sentences) of what optimization you made and why it improves performance
|
||||
- Then provide the optimized code in a markdown code block
|
||||
- Example format:
|
||||
```
|
||||
**Optimization Explanation:**
|
||||
[Your explanation here describing the optimization technique and expected performance improvement]
|
||||
|
||||
```python:filename.py
|
||||
[optimized code]
|
||||
```
|
||||
```
|
||||
|
||||
The current Python version is {python_version_str}
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
|
||||
Here are tests that define the expected behavior of the function you are optimizing. Your optimization MUST produce identical results for all these test cases.
|
||||
|
||||
Pay special attention to hand-written unit tests — they encode the developer's explicit behavioral expectations and edge cases. Any optimization that changes the output for these inputs is incorrect.
|
||||
|
||||
{test_input_examples}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
Rewrite this python program to run faster.
|
||||
|
||||
{source_code}
|
||||
775
packages/codeflash-api/src/codeflash_api/optimize/_context.py
Normal file
775
packages/codeflash-api/src/codeflash_api/optimize/_context.py
Normal file
|
|
@ -0,0 +1,775 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import libcst as cst
|
||||
|
||||
from codeflash_api.diff._base import DiffMethod
|
||||
from codeflash_api.diff._search_replace import SearchAndReplaceDiff
|
||||
from codeflash_api.diff._v4a import V4ADiff
|
||||
from codeflash_api.languages.python._cst_utils import (
|
||||
find_init,
|
||||
parse_module_to_cst,
|
||||
)
|
||||
from codeflash_api.languages.python._markdown import (
|
||||
extract_code_block_with_context,
|
||||
group_code,
|
||||
is_multi_context,
|
||||
split_markdown_code,
|
||||
wrap_code_in_markdown,
|
||||
)
|
||||
from codeflash_api.languages.python._postprocess import (
|
||||
OptimizationCandidate,
|
||||
optimizations_postprocessing_pipeline,
|
||||
)
|
||||
from codeflash_api.optimize.schemas import OptimizeResponseItem
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash_api.diff._base import Diff
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_PROMPTS_DIR = Path(__file__).parent.parent / "languages" / "python" / "prompts"
|
||||
|
||||
EXPLANATION_THEN_CODE = (
|
||||
"Begin your response with a short explanation of the changes"
|
||||
" (without any title or heading like 'Explanation' or 'Changes'),"
|
||||
" then output the code."
|
||||
)
|
||||
FULL_CODE_PROMPT_INSTRUCTIONS = (
|
||||
"- Always provide the FULL, updated content of the artifact."
|
||||
" This means:\n"
|
||||
" - Include ALL updated code, even if some parts are"
|
||||
" unchanged\n"
|
||||
' - NEVER use placeholders like "// rest of the code'
|
||||
' remains the same..." or "<- leave original code here ->"\n'
|
||||
)
|
||||
|
||||
MARKDOWN_CONTEXT_PROMPT = (
|
||||
(_PROMPTS_DIR / "multi_file_code_format.md")
|
||||
.read_text()
|
||||
.format(EXPLANATION_THEN_CODE=EXPLANATION_THEN_CODE)
|
||||
)
|
||||
DEPS_CONTEXT_PROMPT = (
|
||||
(_PROMPTS_DIR / "dependency_context_prompt.md").read_text()
|
||||
)
|
||||
INIT_OPTIMIZATION_PROMPT = (
|
||||
(_PROMPTS_DIR / "init_optimization_prompt.md").read_text()
|
||||
)
|
||||
LINE_PROF_CONTEXT_PROMPT = (
|
||||
(_PROMPTS_DIR / "lineprof_context_prompt.md").read_text()
|
||||
)
|
||||
RUNTIME_CONTEXT_PROMPT = (
|
||||
(_PROMPTS_DIR / "runtime_context_prompt.md").read_text()
|
||||
)
|
||||
TEST_EXAMPLES_PROMPT = (
|
||||
(_PROMPTS_DIR / "test_examples_prompt.md").read_text()
|
||||
)
|
||||
|
||||
V4A_DIFF_FORMAT_PROMPT = """Describe the changes using the V4A diff format, enclosed within `*** Begin Patch` and `*** End Patch` markers.
|
||||
# V4A Diff Format Rules:
|
||||
|
||||
Your entire response containing the patch MUST start with `*** Begin Patch` on a line by itself.
|
||||
Your entire response containing the patch MUST end with `*** End Patch` on a line by itself.
|
||||
Use the *FULL* file path, as shown to you by the user.
|
||||
|
||||
For each file you need to modify, start with a marker line:
|
||||
*** Update File: [path/to/file]
|
||||
|
||||
You ONLY update existing files, do not add or remove files.
|
||||
|
||||
Each file MUST appear only once in the patch.
|
||||
Consolidate all changes for that file into the same block.
|
||||
|
||||
For `Update` actions, describe each snippet of code that needs to be changed using the following format:
|
||||
1. Context lines: Include 3 lines of context *before* the change. These lines MUST start with a single space ` `.
|
||||
2. Lines to remove: Precede each line to be removed with a minus sign `-`.
|
||||
3. Lines to add: Precede each line to be added with a plus sign `+`.
|
||||
4. Context lines: Include 3 lines of context *after* the change. These lines MUST start with a single space ` `.
|
||||
|
||||
Context lines MUST exactly match the existing file content, character for character, including indentation.
|
||||
If a change is near the beginning or end of the file, include fewer than 3 context lines as appropriate.
|
||||
If 3 lines of context is insufficient to uniquely identify the snippet, use `@@ [CLASS_OR_FUNCTION_NAME]` markers on their own lines *before* the context lines to specify the scope.
|
||||
Do not include line numbers.
|
||||
|
||||
Only create patches for files that the user has provided.
|
||||
|
||||
ONLY EVER RETURN CODE IN THE SPECIFIED V4A DIFF FORMAT, followed by a short explanation of the changes.
|
||||
"""
|
||||
|
||||
SEARCH_AND_REPLACE_FORMAT_PROMPT = """Describe the changes using the search/replace diff format with xml-style tags, like:
|
||||
|
||||
<replace_in_file>
|
||||
<path>src/main.py</path>
|
||||
<diff>
|
||||
<<<<<<< SEARCH
|
||||
a = 2
|
||||
=======
|
||||
a = 3
|
||||
>>>>>>> REPLACE
|
||||
</diff>
|
||||
</replace_in_file>
|
||||
|
||||
Description: Request to replace sections of content in an existing file using SEARCH/REPLACE blocks that define exact changes to specific parts of the file.
|
||||
Parameters:
|
||||
- path: (required) The path of the file to modify
|
||||
- diff: (required) One or more SEARCH/REPLACE blocks following this exact format:
|
||||
```
|
||||
<<<<<<< SEARCH
|
||||
[exact content to find]
|
||||
=======
|
||||
[new content to replace with]
|
||||
>>>>>>> REPLACE
|
||||
```
|
||||
Critical rules:
|
||||
1. SEARCH content must match the associated file section to find EXACTLY.
|
||||
2. SEARCH/REPLACE blocks will ONLY replace the first match occurrence.
|
||||
3. Keep SEARCH/REPLACE blocks concise.
|
||||
4. Special operations:
|
||||
* To move code: Use two SEARCH/REPLACE blocks
|
||||
* To delete code: Use empty REPLACE section
|
||||
"""
|
||||
|
||||
|
||||
def _humanize_ns(ns: int) -> str:
|
||||
"""
|
||||
Convert nanoseconds to a human-readable string.
|
||||
"""
|
||||
if ns < 1_000:
|
||||
return f"{ns}ns"
|
||||
if ns < 1_000_000:
|
||||
return f"{ns / 1_000:.1f}us"
|
||||
if ns < 1_000_000_000:
|
||||
return f"{ns / 1_000_000:.1f}ms"
|
||||
return f"{ns / 1_000_000_000:.2f}s"
|
||||
|
||||
|
||||
def parse_python_version(version: str | None) -> tuple[int, int, int]:
|
||||
"""
|
||||
Parse a version string like '3.12.9' into a tuple.
|
||||
"""
|
||||
if not version:
|
||||
msg = "Python version is required"
|
||||
raise ValueError(msg)
|
||||
if len(version) > 30:
|
||||
msg = "Invalid version format"
|
||||
raise ValueError(msg)
|
||||
parts = version.split(".")
|
||||
if len(parts) != 3:
|
||||
msg = "Invalid version format"
|
||||
raise ValueError(msg)
|
||||
patch_match = re.match(r"\d+", parts[2])
|
||||
if not patch_match:
|
||||
msg = "Invalid patch version"
|
||||
raise ValueError(msg)
|
||||
major = int(parts[0])
|
||||
minor = int(parts[1])
|
||||
patch = int(patch_match.group())
|
||||
if major != 3:
|
||||
msg = "Only Python 3 is supported"
|
||||
raise ValueError(msg)
|
||||
if minor < 9 or minor > 15:
|
||||
msg = "Only Python 3.9 and above is supported"
|
||||
raise ValueError(msg)
|
||||
if patch < 0 or patch >= 100:
|
||||
msg = "Invalid version format"
|
||||
raise ValueError(msg)
|
||||
return (major, minor, patch)
|
||||
|
||||
|
||||
def validate_trace_id(trace_id: str) -> bool:
    """Return True when *trace_id* is a canonical UUIDv4 string.

    Experiment suffixes 'EXP0'/'EXP1' are tolerated: the final four
    characters are swapped for '0000' before validation.
    """
    candidate = trace_id
    if candidate[-4:] in ("EXP0", "EXP1"):
        candidate = candidate[:-4] + "0000"
    try:
        parsed = uuid.UUID(candidate, version=4)
    except (ValueError, AttributeError):
        return False
    # Round-trip check: uuid.UUID forces version/variant bits and lowercases,
    # so only a genuinely canonical v4 string survives unchanged.
    return str(parsed) == candidate
|
||||
|
||||
|
||||
class CodeStrAndExplanation:
    """Extracted code and explanation from LLM output.

    Lightweight value holder; ``__slots__`` keeps per-instance memory low
    since one of these is created per optimization candidate.
    """

    __slots__ = ("code", "explanation")

    def __init__(self, code: str, explanation: str) -> None:
        # code: the extracted source text (possibly a markdown bundle for
        # multi-file contexts); explanation: prose describing the change.
        self.code = code
        self.explanation = explanation

    def __repr__(self) -> str:
        # Added for debuggability; truncate potentially very long fields.
        return (
            f"{type(self).__name__}(code={self.code[:40]!r}..., "
            f"explanation={self.explanation[:40]!r}...)"
        )
|
||||
|
||||
|
||||
class BaseOptimizerContext:
    """Base context for Python optimization.

    Holds the prompt templates and source code for one optimization run and
    defines the hooks subclasses implement: prompt assembly, LLM-output
    extraction, and candidate postprocessing. Use the
    :meth:`get_dynamic_context` factory rather than instantiating directly.
    """

    def __init__(
        self,
        base_system_prompt: str,
        base_user_prompt: str,
        source_code: str,
    ) -> None:
        self.base_system_prompt = base_system_prompt
        self.base_user_prompt = base_user_prompt
        self.source_code = source_code
        # Set by extract_code_and_explanation_from_llm_res(); None until then.
        self.extracted_code_and_expl: CodeStrAndExplanation | None = None
        # Keyed by optimization id: the code/explanation captured *before*
        # the postprocessing pipeline runs, for telemetry/debugging.
        self.code_and_explanation_before_post_processing: dict[
            str, CodeStrAndExplanation
        ] = {}

    @staticmethod
    def get_dynamic_context(
        system_prompt: str,
        user_prompt: str,
        source_code: str,
        diff_method: DiffMethod = DiffMethod.NO_DIFF,
    ) -> BaseOptimizerContext:
        """Factory: choose Single or Multi based on source format.

        Markdown-wrapped sources become Multi contexts, except a single-file
        markdown bundle with NO_DIFF, which is unwrapped into a Single
        context (its file name is remembered so the response can be
        re-wrapped). Plain sources always become Single contexts.
        """
        if is_multi_context(source_code):
            file_to_code = split_markdown_code(source_code)
            files = list(file_to_code.keys())
            if len(files) == 1 and diff_method == DiffMethod.NO_DIFF:
                file_name = files[0]
                return SingleOptimizerContext(
                    system_prompt,
                    user_prompt,
                    source_code=file_to_code[file_name],
                    file_name=file_name,
                )
            return MultiOptimizerContext(
                system_prompt,
                user_prompt,
                source_code,
                diff_method=diff_method,
            )
        return SingleOptimizerContext(
            system_prompt, user_prompt, source_code
        )

    def get_system_prompt(self, python_version_str: str) -> str:
        """Return the formatted system prompt.

        Base implementation returns the template untouched; subclasses
        interpolate version and format instructions.
        """
        return self.base_system_prompt

    def get_user_prompt(
        self,
        dependency_code: str,
        line_profiler_results: str | None,
        *,
        baseline_runtime_ns: int | None = None,
        loop_count: int | None = None,
        test_input_examples: str | None = None,
    ) -> str:
        """Return the formatted user prompt.

        Base implementation returns the template untouched; subclasses weave
        in dependency code, runtime/profiling context, and test examples.
        """
        return self.base_user_prompt

    def extract_code_and_explanation_from_llm_res(
        self, content: str
    ) -> CodeStrAndExplanation:
        """Parse LLM output into code and explanation (subclass hook)."""
        raise NotImplementedError

    def parse_and_generate_candidate_schema(
        self,
    ) -> OptimizeResponseItem | None:
        """Post-process extracted code into a response item (subclass hook)."""
        raise NotImplementedError

    def is_valid_code(self) -> bool:
        """Check that extraction ran and produced non-blank code."""
        if self.extracted_code_and_expl is None:
            return False
        code = self.extracted_code_and_expl.code
        # `code` is typed str, but guard against None defensively.
        return code is not None and code.strip() != ""

    def validate_and_parse_source_code(
        self,
        code: str,
        feature_version: tuple[int, ...],
    ) -> None:
        """Validate that *code* (or this context's source) is valid Python.

        Raises SyntaxError for unparseable or empty modules; callers map
        this to a 400 response.
        """
        final_code = code or self.source_code
        parsed = ast.parse(final_code, feature_version=feature_version)
        if not parsed.body:
            msg = "Source code is empty or invalid"
            raise SyntaxError(msg)
|
||||
|
||||
|
||||
class SingleOptimizerContext(BaseOptimizerContext):
    """Context for single-file optimizations.

    The LLM is asked for full rewritten code (no diff format). If
    ``file_name`` is set, the source originally arrived as a one-file
    markdown bundle and the response is re-wrapped under that name.
    """

    def __init__(
        self,
        base_system_prompt: str,
        base_user_prompt: str,
        source_code: str,
        file_name: str | None = None,
    ) -> None:
        super().__init__(
            base_system_prompt, base_user_prompt, source_code
        )
        # Non-None only when unwrapped from a single-file markdown bundle.
        self.file_name = file_name

    def get_system_prompt(
        self, python_version_str: str
    ) -> str:
        """Format the system prompt with full-code (non-diff) instructions."""
        return (
            self.base_system_prompt.format(
                python_version_str=python_version_str,
                critical_instructions=FULL_CODE_PROMPT_INSTRUCTIONS,
            )
            + "\n"
            + EXPLANATION_THEN_CODE
        )

    def get_user_prompt(
        self,
        dependency_code: str,
        line_profiler_results: str | None,
        *,
        baseline_runtime_ns: int | None = None,
        loop_count: int | None = None,
        test_input_examples: str | None = None,
    ) -> str:
        """Build the user prompt from optional context sections.

        Section order: dependencies, source (with optional __init__
        optimization note), runtime baseline, line-profiler output, test
        examples. Sections with no data are omitted entirely.
        """
        markdown_source_code = wrap_code_in_markdown(self.source_code)
        runtime_part = ""
        # Runtime context needs both numbers; either alone is omitted.
        if (
            baseline_runtime_ns is not None
            and loop_count is not None
        ):
            runtime_part = RUNTIME_CONTEXT_PROMPT.format(
                baseline_runtime_ns=baseline_runtime_ns,
                baseline_runtime_human=_humanize_ns(
                    baseline_runtime_ns
                ),
                loop_count=loop_count,
            )
        test_examples_part = ""
        if test_input_examples:
            test_examples_part = TEST_EXAMPLES_PROMPT.format(
                test_input_examples=test_input_examples
            )
        # Presence of an __init__ toggles extra constructor-specific guidance.
        has_init = find_init(ast.parse(self.source_code))
        return (
            f"{DEPS_CONTEXT_PROMPT.format(dependency_code=dependency_code) if dependency_code else ''}"
            f"{self.base_user_prompt.format(source_code=markdown_source_code, init_optimization_prompt=INIT_OPTIMIZATION_PROMPT if has_init else '')}"
            f"{runtime_part}"
            f"{LINE_PROF_CONTEXT_PROMPT.format(line_profiler_results=line_profiler_results) if line_profiler_results else ''}"
            f"{test_examples_part}"
        )

    def extract_code_and_explanation_from_llm_res(
        self, content: str
    ) -> CodeStrAndExplanation:
        """Extract the first python code block and surrounding explanation.

        If no code block is found, both fields stay empty — is_valid_code()
        later rejects the candidate.
        """
        code = ""
        explanation = ""
        if result := extract_code_block_with_context(
            content, language="python"
        ):
            explanation, code, _ = result
        extracted = CodeStrAndExplanation(
            code=code, explanation=explanation
        )
        self.extracted_code_and_expl = extracted
        return extracted

    def parse_and_generate_candidate_schema(
        self,
    ) -> OptimizeResponseItem | None:
        """Post-process extracted code into a response item.

        Returns None if extraction was empty, the postprocessing pipeline
        rejected the candidate, or parsing failed.

        Raises:
            AttributeError: when called before
                extract_code_and_explanation_from_llm_res().
        """
        if self.extracted_code_and_expl is None:
            msg = "Call extract_code_and_explanation_from_llm_res first"
            raise AttributeError(msg)
        extracted = self.extracted_code_and_expl
        if extracted.code == "":
            return None
        try:
            op_id = str(uuid.uuid4())
            cst_module = parse_module_to_cst(extracted.code)
            # Snapshot pre-postprocessing output for telemetry/debugging.
            self.code_and_explanation_before_post_processing[
                op_id
            ] = CodeStrAndExplanation(
                code=cst_module.code,
                explanation=extracted.explanation,
            )
            original_cst_module = parse_module_to_cst(
                self.source_code
            )
            postprocessed_list = optimizations_postprocessing_pipeline(
                original_cst_module,
                [
                    OptimizationCandidate(
                        cst_module=cst_module,
                        id=op_id,
                        explanation=extracted.explanation,
                    )
                ],
            )

            # Pipeline may drop the candidate (e.g. judged a no-op).
            if len(postprocessed_list) == 0:
                return None

            postprocessed = postprocessed_list[0]
            # Re-wrap in markdown under the original file name when the
            # request arrived as a one-file bundle.
            source = (
                group_code(
                    {self.file_name: postprocessed.cst_module.code}
                )
                if self.file_name is not None
                else postprocessed.cst_module.code
            )
            return OptimizeResponseItem(
                explanation=postprocessed.explanation,
                optimization_id=postprocessed.id,
                source_code=source,
            )

        except (
            ValueError,
            cst.ParserSyntaxError,
        ):
            # Unparseable LLM output is expected occasionally; log and drop.
            log.warning("Error parsing optimization result", exc_info=True)
            return None

    def is_valid_code(self) -> bool:
        """Check if extracted code is valid (base non-blank check only)."""
        return super().is_valid_code()

    def validate_and_parse_source_code(
        self,
        code: str,
        feature_version: tuple[int, ...],
    ) -> None:
        """Validate single-file source code.

        NOTE(review): the *code* argument is intentionally ignored — for an
        unwrapped single-file bundle the caller's raw payload is markdown,
        so the stored (unwrapped) ``self.source_code`` is validated instead.
        Compiles first for a stricter syntax check, then runs the base
        ast.parse/empty-module check.
        """
        compile(self.source_code, "source_code", "exec")
        super().validate_and_parse_source_code(
            self.source_code, feature_version
        )
|
||||
|
||||
|
||||
class MultiOptimizerContext(BaseOptimizerContext):
    """Context for multi-file optimizations.

    ``source_code`` is a markdown bundle of per-file code blocks. Depending
    on ``diff_method``, the LLM replies with full rewritten files (NO_DIFF),
    a V4A patch, or search/replace blocks; extraction reconstructs the full
    per-file code either way.
    """

    def __init__(
        self,
        base_system_prompt: str,
        base_user_prompt: str,
        source_code: str,
        diff_method: DiffMethod,
    ) -> None:
        # Controls both prompt format instructions and response parsing.
        self.diff_method = diff_method
        super().__init__(
            base_system_prompt, base_user_prompt, source_code
        )
        # Mapping of file path -> code, parsed once from the bundle.
        self.original_file_to_code = split_markdown_code(source_code)
        # Cached for membership checks in is_valid_code().
        self._original_files_set = set(
            self.original_file_to_code.keys()
        )

    def get_system_prompt(
        self, python_version_str: str
    ) -> str:
        """Format the system prompt based on the configured diff method."""
        critical_instructions = ""
        code_format_instructions = ""
        if self.diff_method == DiffMethod.V4A:
            code_format_instructions = V4A_DIFF_FORMAT_PROMPT
            critical_instructions = (
                "Begin your response with the diff followed by a"
                " short explanation of the changes (without any"
                " title or heading like 'Explanation' or"
                " 'Changes')."
            )
        elif self.diff_method == DiffMethod.SEARCH_AND_REPLACE:
            code_format_instructions = (
                SEARCH_AND_REPLACE_FORMAT_PROMPT
            )
            critical_instructions = (
                "Begin your response with the diff followed by a"
                " short explanation of the changes (without any"
                " title or heading like 'Explanation' or"
                " 'Changes')."
            )
        elif self.diff_method == DiffMethod.NO_DIFF:
            code_format_instructions = MARKDOWN_CONTEXT_PROMPT
            critical_instructions = FULL_CODE_PROMPT_INSTRUCTIONS
        return (
            self.base_system_prompt.format(
                python_version_str=python_version_str,
                critical_instructions=critical_instructions,
            )
            + "\n"
            + code_format_instructions
        )

    def get_user_prompt(
        self,
        dependency_code: str,
        line_profiler_results: str | None,
        *,
        baseline_runtime_ns: int | None = None,
        loop_count: int | None = None,
        test_input_examples: str | None = None,
    ) -> str:
        """Build the user prompt for a multi-file context.

        Same section layout as the single-file variant, but the source is
        passed through as the original markdown bundle, and the __init__
        note is added when any file defines one.
        """
        has_init = any(
            find_init(ast.parse(code))
            for code in self.original_file_to_code.values()
        )
        runtime_part = ""
        # Runtime context needs both numbers; either alone is omitted.
        if (
            baseline_runtime_ns is not None
            and loop_count is not None
        ):
            runtime_part = RUNTIME_CONTEXT_PROMPT.format(
                baseline_runtime_ns=baseline_runtime_ns,
                baseline_runtime_human=_humanize_ns(
                    baseline_runtime_ns
                ),
                loop_count=loop_count,
            )
        test_examples_part = ""
        if test_input_examples:
            test_examples_part = TEST_EXAMPLES_PROMPT.format(
                test_input_examples=test_input_examples
            )
        return (
            f"{DEPS_CONTEXT_PROMPT.format(dependency_code=dependency_code) if dependency_code else ''}"
            f"{self.base_user_prompt.format(source_code=self.source_code, init_optimization_prompt=INIT_OPTIMIZATION_PROMPT if has_init else '')}"
            f"{runtime_part}"
            f"{LINE_PROF_CONTEXT_PROMPT.format(line_profiler_results=line_profiler_results) if line_profiler_results else ''}"
            f"{test_examples_part}"
        )

    def extract_code_and_explanation_from_llm_res(
        self, content: str
    ) -> CodeStrAndExplanation:
        """Extract per-file code from the LLM response per diff method.

        NO_DIFF: split the markdown bundle directly; the explanation is the
        text before the first ```python: block. V4A / SEARCH_AND_REPLACE:
        apply the diff to the original files and take the explanation as
        the text after the final end marker. NOTE(review): if the end marker
        is missing, str.rfind returns -1 and the explanation slice starts
        inside the content — presumably tolerated; confirm against the
        Django reference.
        """
        markdown_code = ""
        explanation = ""
        if self.diff_method == DiffMethod.NO_DIFF:
            explanation = (
                content.split("```python:", maxsplit=1)[
                    0
                ].strip()
            )
            markdown_code = group_code(
                split_markdown_code(content)
            )
        elif self.diff_method == DiffMethod.V4A:
            diff: Diff = V4ADiff(
                content=content,
                source_code=self.original_file_to_code,
            )
            new_file_to_code = diff.run()
            markdown_code = group_code(new_file_to_code)
            patch_end_marker = "*** End Patch"
            index_of_end = content.rfind(patch_end_marker)
            explanation = content[
                index_of_end + len(patch_end_marker) :
            ].strip()
        elif self.diff_method == DiffMethod.SEARCH_AND_REPLACE:
            diff = SearchAndReplaceDiff(
                content=content,
                source_code=self.original_file_to_code,
            )
            new_file_to_code = diff.run()
            end_tag = "</replace_in_file>"
            index_of_end = content.rfind(end_tag)
            explanation = content[
                index_of_end + len(end_tag) :
            ].strip()
            markdown_code = group_code(new_file_to_code)

        result = CodeStrAndExplanation(
            code=markdown_code, explanation=explanation
        )
        self.extracted_code_and_expl = result
        return result

    def parse_and_generate_candidate_schema(
        self,
    ) -> OptimizeResponseItem | None:
        """Post-process each extracted file and reassemble the bundle.

        Files the pipeline drops keep their pre-postprocessing code; the
        candidate is rejected (None) only when *every* file was dropped, a
        file failed to parse, or no extracted file matches an original.

        Raises:
            AttributeError: when called before
                extract_code_and_explanation_from_llm_res().
        """
        if self.extracted_code_and_expl is None:
            msg = "Call extract_code_and_explanation_from_llm_res first"
            raise AttributeError(msg)

        extracted = self.extracted_code_and_expl
        if extracted.code == "":
            return None

        original_code_files = dict(
            self.original_file_to_code.items()
        )
        extracted_code_files = dict(
            split_markdown_code(extracted.code).items()
        )

        op_id = str(uuid.uuid4())
        new_post_processed_code: dict[str, str] = {}
        # First non-empty postprocessed explanation wins for the bundle.
        new_explanation: str = ""

        # Snapshot pre-postprocessing output for telemetry/debugging.
        self.code_and_explanation_before_post_processing[
            op_id
        ] = CodeStrAndExplanation(
            code=extracted.code, explanation=extracted.explanation
        )

        # Ignore any hallucinated files the LLM invented.
        valid_extracted_files = {
            f: c
            for f, c in extracted_code_files.items()
            if f in original_code_files
        }
        if not valid_extracted_files:
            log.warning(
                "No matching files. Original: %s, Extracted: %s",
                list(original_code_files.keys()),
                list(extracted_code_files.keys()),
            )
            return None

        # Stays True only if the pipeline dropped every file's change.
        optimization_ignored = True

        for original_file, new_code in valid_extracted_files.items():
            original_code = original_code_files[original_file]
            try:
                new_cst_module = parse_module_to_cst(new_code)
                original_cst_module = parse_module_to_cst(
                    original_code
                )

                code_and_explanation = OptimizationCandidate(
                    cst_module=new_cst_module,
                    id=f"{original_file}:post-processing",
                    explanation=extracted.explanation,
                )
                postprocessed_list = (
                    optimizations_postprocessing_pipeline(
                        original_cst_module,
                        [code_and_explanation],
                    )
                )
                if len(postprocessed_list) == 0:
                    # Dropped by the pipeline: keep the raw extracted code.
                    new_post_processed_code[original_file] = (
                        new_code
                    )
                    continue

                optimization_ignored = False
                postprocessed = postprocessed_list[0]

                if new_explanation == "":
                    new_explanation = postprocessed.explanation

                new_post_processed_code[original_file] = (
                    postprocessed.cst_module.code
                )

            except (
                ValueError,
                cst.ParserSyntaxError,
            ):
                # Any unparseable file invalidates the whole candidate.
                log.warning(
                    "Error processing file %s",
                    original_file,
                    exc_info=True,
                )
                return None

        if optimization_ignored:
            return None
        return OptimizeResponseItem(
            optimization_id=op_id,
            explanation=new_explanation,
            source_code=group_code(new_post_processed_code),
        )

    def is_valid_code(self) -> bool:
        """Validate that extracted files are a subset of the originals."""
        if self.extracted_code_and_expl is None:
            return False
        extracted_files = set(
            split_markdown_code(
                self.extracted_code_and_expl.code
            ).keys()
        )

        # Reject responses that invent files not in the request.
        invalid_files = extracted_files - self._original_files_set
        if invalid_files:
            log.warning(
                "LLM returned files not in original: %s",
                invalid_files,
            )
            return False

        return super().is_valid_code()

    def validate_and_parse_source_code(
        self,
        code: str,
        feature_version: tuple[int, ...],
    ) -> None:
        """Validate each file in the markdown bundle independently."""
        final_code = code or self.source_code
        code_blocks = split_markdown_code(final_code).values()
        for code_block in code_blocks:
            compile(code_block, "source_code", "exec")
            super().validate_and_parse_source_code(
                code_block, feature_version
            )
|
||||
212
packages/codeflash-api/src/codeflash_api/optimize/_pipeline.py
Normal file
212
packages/codeflash-api/src/codeflash_api/optimize/_pipeline.py
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import libcst as cst
|
||||
|
||||
from codeflash_api.llm._client import LLMOutputUnparseableError
|
||||
from codeflash_api.llm._models import (
|
||||
ANTHROPIC_CLAUDE_SONNET_4_5,
|
||||
LLM,
|
||||
OPENAI_GPT_5_MINI,
|
||||
)
|
||||
from codeflash_api.optimize._context import BaseOptimizerContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash_api.llm._client import LLMClient
|
||||
from codeflash_api.optimize.schemas import OptimizeResponseItem
|
||||
|
||||
log = logging.getLogger(__name__)

# Model pool used for candidate generation; get_model_distribution() splits
# the requested calls between these two.
OPENAI_MODEL: LLM = OPENAI_GPT_5_MINI
ANTHROPIC_MODEL: LLM = ANTHROPIC_CLAUDE_SONNET_4_5

# Hard cap on parallel LLM calls per request, regardless of n_candidates.
MAX_OPTIMIZER_CALLS = 6
|
||||
|
||||
|
||||
def get_model_distribution(
    total_calls: int, max_calls: int
) -> list[tuple[LLM, int]]:
    """Split *total_calls* (capped at *max_calls*) between the two models.

    The OpenAI model gets the majority: for a final total of N, Claude
    receives (N - 1) // 2 calls and GPT the remainder.

    Returns:
        [(OPENAI_MODEL, gpt_calls), (ANTHROPIC_MODEL, claude_calls)].
    """
    final_total = min(total_calls, max_calls)
    if final_total <= 0:
        # Fix: previously a non-positive total floored (final_total - 1) // 2
        # to -1, yielding a negative Claude count and a phantom GPT call.
        return [(OPENAI_MODEL, 0), (ANTHROPIC_MODEL, 0)]
    claude_calls = (final_total - 1) // 2
    gpt_calls = final_total - claude_calls
    return [
        (OPENAI_MODEL, gpt_calls),
        (ANTHROPIC_MODEL, claude_calls),
    ]
|
||||
|
||||
|
||||
async def generate_optimization_candidate(
    llm_client: LLMClient,
    user_id: str,
    ctx: BaseOptimizerContext,
    trace_id: str,
    *,
    dependency_code: str | None = None,
    optimize_model: LLM = OPENAI_MODEL,
    python_version: tuple[int, int, int] = (3, 12, 9),
    call_sequence: int | None = None,
    baseline_runtime_ns: int | None = None,
    loop_count: int | None = None,
    line_profiler_results: str | None = None,
    test_input_examples: str | None = None,
) -> tuple[OptimizeResponseItem | None, float | None, str]:
    """Generate a single optimization candidate via one LLM call.

    Returns a (candidate, cost, model_name) triple; the candidate is None
    on any failure (LLM error, unparseable output, rejected candidate),
    with whatever cost could still be recovered. Never raises — this is
    designed to run inside a TaskGroup where one failure must not cancel
    siblings.

    NOTE(review): *user_id* and *call_sequence* are accepted but unused in
    this body — presumably kept for parity with the Django reference /
    future telemetry; confirm before removing.
    """
    log.info("Generating optimization candidate")
    python_version_str = ".".join(str(x) for x in python_version)

    system_prompt = ctx.get_system_prompt(python_version_str)
    user_prompt = ctx.get_user_prompt(
        dependency_code or "",
        line_profiler_results,
        baseline_runtime_ns=baseline_runtime_ns,
        loop_count=loop_count,
        test_input_examples=test_input_examples,
    )

    messages: list[dict[str, Any]] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    try:
        output = await llm_client.call(
            llm=optimize_model,
            messages=messages,
            call_type="optimization",
            trace_id=trace_id,
        )
    except LLMOutputUnparseableError as e:
        # The call was billed even though the output was unusable;
        # surface its cost so totals stay accurate.
        return None, e.cost, optimize_model.name
    except Exception:
        # Deliberate broad boundary catch: any provider failure must not
        # propagate into the TaskGroup and cancel sibling candidates.
        log.exception("Failed to generate optimization")
        return None, None, optimize_model.name

    llm_cost = output.cost

    ctx.extract_code_and_explanation_from_llm_res(output.content)
    try:
        # CST parsing/postprocessing is CPU-bound; run off the event loop.
        res = await asyncio.to_thread(
            ctx.parse_and_generate_candidate_schema
        )
        if res is not None and ctx.is_valid_code():
            return res, llm_cost, optimize_model.name
    except (ValueError, cst.ParserSyntaxError):
        log.warning(
            "Error parsing optimization result", exc_info=True
        )

    return None, llm_cost, optimize_model.name
|
||||
|
||||
|
||||
async def optimize_python_code(
    llm_client: LLMClient,
    user_id: str,
    ctx: BaseOptimizerContext,
    trace_id: str,
    original_source_code: str,
    *,
    dependency_code: str | None = None,
    python_version: tuple[int, int, int] = (3, 12, 9),
    n_candidates: int = 0,
    baseline_runtime_ns: int | None = None,
    loop_count: int | None = None,
    line_profiler_results: str | None = None,
    test_input_examples: str | None = None,
) -> tuple[
    list[OptimizeResponseItem],
    float,
    dict[str, dict[str, str]],
    dict[str, str],
]:
    """Run parallel optimization calls across the model pool.

    Fans out up to MAX_OPTIMIZER_CALLS LLM calls (distribution per
    get_model_distribution), each with its own fresh context so parallel
    extraction does not share mutable state. For diversity, line-profiler
    results are only included on odd-numbered calls.

    Returns:
        (candidates, total LLM cost, pre-postprocessing code/explanations
        keyed by optimization id, model name keyed by optimization id).
    """
    tasks: list[
        tuple[
            asyncio.Task[
                tuple[
                    OptimizeResponseItem | None,
                    float | None,
                    str,
                ]
            ],
            BaseOptimizerContext,
        ]
    ] = []
    call_sequence = 1

    if n_candidates == 0:
        return [], 0.0, {}, {}

    async with asyncio.TaskGroup() as tg:
        for model, num_calls in get_model_distribution(
            n_candidates, MAX_OPTIMIZER_CALLS
        ):
            for _ in range(num_calls):
                # Fresh context per call: contexts accumulate per-call
                # extraction state and must not be shared.
                task_ctx = BaseOptimizerContext.get_dynamic_context(
                    ctx.base_system_prompt,
                    ctx.base_user_prompt,
                    original_source_code,
                )
                # Diversity: only odd-numbered calls see profiler data.
                lp_for_this_call = (
                    line_profiler_results
                    if call_sequence % 2 == 1
                    else None
                )
                task = tg.create_task(
                    generate_optimization_candidate(
                        llm_client,
                        user_id=user_id,
                        ctx=task_ctx,
                        trace_id=trace_id,
                        dependency_code=dependency_code,
                        optimize_model=model,
                        python_version=python_version,
                        call_sequence=call_sequence,
                        baseline_runtime_ns=baseline_runtime_ns,
                        loop_count=loop_count,
                        line_profiler_results=lp_for_this_call,
                        test_input_examples=test_input_examples,
                    )
                )
                tasks.append((task, task_ctx))
                call_sequence += 1

    optimization_results: list[OptimizeResponseItem] = []
    total_cost = 0.0
    code_and_explanations: dict[str, dict[str, str]] = {}
    optimization_models: dict[str, str] = {}

    # TaskGroup has exited: all tasks are done, .result() never blocks.
    for task, task_ctx in tasks:
        result, cost, model_name = task.result()
        if cost:
            total_cost += cost
        if result is not None:
            optimization_results.append(result)
            optimization_models[result.optimization_id] = (
                model_name
            )
        for (
            op_id,
            cei,
        ) in (
            task_ctx.code_and_explanation_before_post_processing.items()
        ):
            code_and_explanations[op_id] = {
                "code": cei.code,
                "explanation": cei.explanation,
            }

    return (
        optimization_results,
        total_cost,
        code_and_explanations,
        optimization_models,
    )
|
||||
194
packages/codeflash-api/src/codeflash_api/optimize/_router.py
Normal file
194
packages/codeflash-api/src/codeflash_api/optimize/_router.py
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
|
||||
from codeflash_api.auth._deps import (
|
||||
check_rate_limit,
|
||||
require_auth,
|
||||
track_usage,
|
||||
)
|
||||
from codeflash_api.diff._base import DiffMethod
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash_api.auth.models import AuthenticatedUser
|
||||
from codeflash_api.optimize._context import (
|
||||
BaseOptimizerContext,
|
||||
parse_python_version,
|
||||
validate_trace_id,
|
||||
)
|
||||
from codeflash_api.optimize._pipeline import optimize_python_code
|
||||
from codeflash_api.optimize.schemas import (
|
||||
OptimizeErrorResponse,
|
||||
OptimizeRequest,
|
||||
OptimizeResponse,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Prompt templates live alongside the language-specific code. They are read
# once at import time, so requests pay no file-IO cost; a missing file fails
# fast at startup.
_PROMPTS_DIR = (
    Path(__file__).parent.parent
    / "languages"
    / "python"
    / "prompts"
)
SYSTEM_PROMPT = (_PROMPTS_DIR / "system_prompt.md").read_text()
USER_PROMPT = (_PROMPTS_DIR / "user_prompt.md").read_text()
# Variants used when the request flags the function under test as async.
ASYNC_SYSTEM_PROMPT = (
    (_PROMPTS_DIR / "async_system_prompt.md").read_text()
)
ASYNC_USER_PROMPT = (
    (_PROMPTS_DIR / "async_user_prompt.md").read_text()
)
# Appended to the system prompt for numerical code (JIT guidance).
JIT_INSTRUCTIONS = (
    (_PROMPTS_DIR / "jit_instructions.md").read_text()
)
|
||||
|
||||
|
||||
def _validate_request(
    data: OptimizeRequest,
    ctx: BaseOptimizerContext,
) -> tuple[int, int, int]:
    """Validate request fields and parse the source code.

    Returns the parsed (major, minor, patch) Python version.

    Raises:
        HTTPException: 400 with a user-facing message for empty source,
            invalid trace id, missing/unsupported Python version, or
            syntactically invalid source code.
    """
    if not data.source_code:
        raise HTTPException(
            status_code=400,
            detail="Source code cannot be empty.",
        )
    if not validate_trace_id(data.trace_id):
        raise HTTPException(
            status_code=400,
            detail="Invalid trace ID. Please provide a valid UUIDv4.",
        )
    if not data.python_version:
        raise HTTPException(
            status_code=400,
            detail="Python version is required.",
        )
    try:
        python_version = parse_python_version(
            data.python_version
        )
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail=(
                "Invalid Python version, it should look like"
                " 3.x.x. We only support Python 3.9 and above."
            ),
        ) from e

    try:
        # feature_version=(3, minor) tells ast.parse which grammar to use.
        ctx.validate_and_parse_source_code(
            data.source_code,
            feature_version=python_version[:2],
        )
    except SyntaxError as e:
        raise HTTPException(
            status_code=400,
            detail=(
                "Invalid source code. It is not valid Python"
                " code. Please check syntax of your code."
            ),
        ) from e

    return python_version
|
||||
|
||||
|
||||
@router.post(
    "/ai/optimize",
    response_model=OptimizeResponse,
    responses={
        400: {"model": OptimizeErrorResponse},
        422: {"model": OptimizeErrorResponse},
        500: {"model": OptimizeErrorResponse},
    },
)
async def optimize(
    request: Request,
    data: OptimizeRequest,
    user: Annotated[
        AuthenticatedUser, Depends(require_auth)
    ],
    _rate: Annotated[None, Depends(check_rate_limit)],
    _usage: Annotated[None, Depends(track_usage)],
) -> OptimizeResponse:
    """Optimize Python code for performance using LLMs.

    Flow: pick prompt templates (async/JIT variants), build the optimizer
    context, validate the request, fan out parallel LLM calls, and return
    the surviving candidates. Responds 400 for invalid input and 422 when
    no candidate could be produced.
    """
    system_prompt = (
        ASYNC_SYSTEM_PROMPT if data.is_async else SYSTEM_PROMPT
    )
    user_prompt = (
        ASYNC_USER_PROMPT if data.is_async else USER_PROMPT
    )
    if data.is_numerical_code:
        # Numerical code gets extra JIT-oriented guidance.
        system_prompt += f"\n{JIT_INSTRUCTIONS}\n"

    ctx = BaseOptimizerContext.get_dynamic_context(
        system_prompt,
        user_prompt,
        data.source_code,
        DiffMethod.NO_DIFF,
    )

    try:
        python_version = _validate_request(data, ctx)
    except HTTPException:
        # Already a clean 400; pass it through untouched.
        raise
    except Exception as e:
        # Unexpected validator crash: log with trace id, return a generic 400.
        log.exception(
            "Validation error: trace_id=%s", data.trace_id
        )
        raise HTTPException(
            status_code=400,
            detail="Request validation failed.",
        ) from e

    # Shared client created in the app lifespan.
    llm_client = request.app.state.llm_client

    (
        optimization_results,
        _llm_cost,
        _code_and_explanations,
        _optimization_models,
    ) = await optimize_python_code(
        llm_client,
        user_id=user.user_id,
        ctx=ctx,
        trace_id=data.trace_id,
        original_source_code=data.source_code,
        dependency_code=data.dependency_code,
        python_version=python_version,
        n_candidates=data.n_candidates,
        baseline_runtime_ns=data.baseline_runtime_ns,
        loop_count=data.loop_count,
        line_profiler_results=data.line_profiler_results,
        test_input_examples=data.test_input_examples,
    )

    if len(optimization_results) == 0:
        log.warning(
            "No optimizations found: trace_id=%s, repo=%s/%s,"
            " n_candidates=%d",
            data.trace_id,
            data.repo_owner,
            data.repo_name,
            data.n_candidates,
        )
        raise HTTPException(
            status_code=422,
            detail=(
                "Could not generate any optimizations."
                " Please try again."
            ),
        )

    return OptimizeResponse(optimizations=optimization_results)
|
||||
86
packages/codeflash-api/src/codeflash_api/optimize/schemas.py
Normal file
86
packages/codeflash-api/src/codeflash_api/optimize/schemas.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
||||
class OptimizedCandidateSource(str, enum.Enum):
    """
    Source of an optimized candidate.

    Subclasses ``str`` so members compare equal to, and serialize as,
    their plain string values.
    """

    OPTIMIZE = "OPTIMIZE"
    # NOTE(review): "LP" presumably stands for line profiler (the pipeline
    # toggles line-profiler context for diversity) — confirm in _pipeline.py.
    OPTIMIZE_LP = "OPTIMIZE_LP"
    REFINE = "REFINE"
    REPAIR = "REPAIR"
    ADAPTIVE = "ADAPTIVE"
    JIT_REWRITE = "JIT_REWRITE"
|
||||
|
||||
|
||||
class OptimizeRequest(BaseModel):
    """
    Request body for POST /ai/optimize.

    Carries the source code to optimize plus optional repository and
    runtime metadata. ``python_version`` may be given directly or derived
    from the generic ``language_version`` field (see
    ``resolve_python_version`` below).
    """

    source_code: str
    dependency_code: str | None = None
    trace_id: str
    python_version: str | None = None
    language: str = "python"
    language_version: str | None = None
    experiment_metadata: dict[str, str] | None = None
    codeflash_version: str | None = None
    current_username: str | None = None
    repo_owner: str | None = None
    repo_name: str | None = None
    is_async: bool | None = False
    n_candidates: int = 5
    is_numerical_code: bool | None = None
    rerun_trace_id: str | None = None
    baseline_runtime_ns: int | None = None
    loop_count: int | None = None
    line_profiler_results: str | None = None
    test_input_examples: str | None = None

    @model_validator(mode="after")
    def resolve_python_version(self) -> Self:
        """
        Resolve python_version from language_version for backward compat.
        """
        # Only backfill when the caller supplied the newer generic field,
        # did not set python_version explicitly, and the language is python.
        should_backfill = (
            self.python_version is None
            and self.language_version is not None
            and self.language == "python"
        )
        if should_backfill:
            self.python_version = self.language_version
        return self
|
||||
|
||||
|
||||
class OptimizeResponseItem(BaseModel):
    """
    A single optimization candidate in the response.
    """

    # Full optimized source for the target code (same format as the
    # request's source_code field).
    source_code: str
    # Rationale text produced alongside the candidate.
    explanation: str
    # Unique identifier for this candidate.
    optimization_id: str
    # Identifier of the candidate this one was derived from, if any.
    parent_id: str | None = None
    # NOTE(review): presumably references a usage/tracking event record —
    # confirm against the router's usage-tracking code.
    optimization_event_id: str | None = None
|
||||
|
||||
|
||||
class OptimizeResponse(BaseModel):
    """
    Successful response from POST /ai/optimize.

    The endpoint raises a 422 instead of returning this model when no
    candidates could be generated, so the list is non-empty in practice.
    """

    # Candidates in the order produced by the optimization pipeline.
    optimizations: list[OptimizeResponseItem]
|
||||
|
||||
|
||||
class OptimizeErrorResponse(BaseModel):
    """
    Error response from POST /ai/optimize.
    """

    # Human-readable description of why the request failed.
    error: str
|
||||
570
packages/codeflash-api/tests/test_optimize.py
Normal file
570
packages/codeflash-api/tests/test_optimize.py
Normal file
|
|
@ -0,0 +1,570 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash_api.diff._base import DiffMethod
|
||||
from codeflash_api.languages.python._markdown import (
|
||||
extract_all_code_from_markdown,
|
||||
extract_code_block,
|
||||
extract_code_block_with_context,
|
||||
group_code,
|
||||
is_multi_context,
|
||||
split_markdown_code,
|
||||
truncate_pathological_output,
|
||||
wrap_code_in_markdown,
|
||||
)
|
||||
from codeflash_api.optimize._context import (
|
||||
BaseOptimizerContext,
|
||||
CodeStrAndExplanation,
|
||||
MultiOptimizerContext,
|
||||
SingleOptimizerContext,
|
||||
_humanize_ns,
|
||||
parse_python_version,
|
||||
validate_trace_id,
|
||||
)
|
||||
from codeflash_api.optimize._pipeline import (
|
||||
MAX_OPTIMIZER_CALLS,
|
||||
get_model_distribution,
|
||||
)
|
||||
from codeflash_api.optimize.schemas import (
|
||||
OptimizedCandidateSource,
|
||||
OptimizeErrorResponse,
|
||||
OptimizeRequest,
|
||||
OptimizeResponse,
|
||||
OptimizeResponseItem,
|
||||
)
|
||||
|
||||
|
||||
class TestTruncatePathologicalOutput:
    """Tests for truncate_pathological_output."""

    def test_normal_code(self) -> None:
        """Normal code is returned unchanged."""
        snippet = "x = 1\ny = 2\n"
        assert truncate_pathological_output(snippet) == snippet

    def test_pathological(self) -> None:
        """Repeated escape sequences are truncated."""
        truncated = truncate_pathological_output("x = 1\n\\''''''''''''more")
        assert "x = 1" in truncated
        assert "more" not in truncated
|
||||
|
||||
|
||||
class TestExtractAllCodeFromMarkdown:
    """Tests for extract_all_code_from_markdown."""

    def test_single_block(self) -> None:
        """Single code block is extracted."""
        assert extract_all_code_from_markdown("```python\nx = 1\n```") == "x = 1\n"

    def test_multiple_blocks(self) -> None:
        """Multiple code blocks are joined."""
        markdown = "```python\nx = 1\n```\ntext\n```python\ny = 2\n```"
        joined = extract_all_code_from_markdown(markdown)
        assert "x = 1" in joined
        assert "y = 2" in joined

    def test_no_blocks(self) -> None:
        """No code blocks returns empty string."""
        assert extract_all_code_from_markdown("no code here") == ""
|
||||
|
||||
|
||||
class TestExtractCodeBlock:
    """Tests for extract_code_block."""

    def test_basic(self) -> None:
        """First code block is extracted."""
        assert extract_code_block("```python\nx = 1\n```") == "x = 1"

    def test_with_filepath(self) -> None:
        """Code block with filepath is extracted."""
        assert extract_code_block("```python:main.py\nx = 1\n```") == "x = 1"

    def test_no_block(self) -> None:
        """No code block returns None."""
        assert extract_code_block("no code") is None
|
||||
|
||||
|
||||
class TestSplitMarkdownCode:
    """Tests for split_markdown_code."""

    def test_basic(self) -> None:
        """Filepath blocks are parsed into a dict."""
        parsed = split_markdown_code(
            "```python:main.py\nx = 1\n```\n```python:util.py\ny = 2\n```"
        )
        assert "main.py" in parsed
        assert "util.py" in parsed
        assert parsed["main.py"] == "x = 1"

    def test_plain_blocks_ignored(self) -> None:
        """Plain python blocks without filepath are ignored."""
        assert split_markdown_code("```python\nx = 1\n```") == {}

    def test_duplicate_paths(self) -> None:
        """First occurrence wins for duplicate paths."""
        parsed = split_markdown_code(
            "```python:a.py\nfirst\n```\n```python:a.py\nsecond\n```"
        )
        assert parsed["a.py"] == "first"
|
||||
|
||||
|
||||
class TestExtractCodeBlockWithContext:
    """Tests for extract_code_block_with_context."""

    def test_basic(self) -> None:
        """Code block and context are extracted."""
        extracted = extract_code_block_with_context(
            "explanation\n```python\nx = 1\n```\nafter"
        )
        assert extracted is not None
        before, code, trailing = extracted
        assert before == "explanation"
        assert code == "x = 1\n"
        assert trailing == "after"

    def test_with_filepath(self) -> None:
        """Filepath variant is preferred."""
        extracted = extract_code_block_with_context(
            "before\n```python:file.py\ny = 2\n```\nend"
        )
        assert extracted is not None
        assert extracted[1] == "y = 2\n"

    def test_no_block(self) -> None:
        """No code block returns None."""
        assert extract_code_block_with_context("no code") is None
|
||||
|
||||
|
||||
class TestWrapCodeInMarkdown:
    """Tests for wrap_code_in_markdown."""

    def test_basic(self) -> None:
        """Code is wrapped in a markdown block."""
        assert wrap_code_in_markdown("x = 1") == "```python\nx = 1\n```"
|
||||
|
||||
|
||||
class TestGroupCode:
    """Tests for group_code."""

    def test_single_file(self) -> None:
        """Single file is formatted as markdown block."""
        grouped = group_code({"main.py": "x = 1\n"})
        assert "```python:main.py" in grouped
        assert "x = 1" in grouped

    def test_multiple_files(self) -> None:
        """Multiple files are joined."""
        grouped = group_code({"a.py": "x = 1\n", "b.py": "y = 2\n"})
        assert "a.py" in grouped
        assert "b.py" in grouped
|
||||
|
||||
|
||||
class TestIsMultiContext:
    """Tests for is_multi_context."""

    def test_multi(self) -> None:
        """Markdown with filepath header is multi."""
        multi_source = "```python:file.py\nx=1\n```"
        assert is_multi_context(multi_source)

    def test_single(self) -> None:
        """Plain code is not multi."""
        plain_source = "x = 1\n"
        assert not is_multi_context(plain_source)
|
||||
|
||||
|
||||
class TestHumanizeNs:
    """Tests for _humanize_ns."""

    def test_nanoseconds(self) -> None:
        """Sub-microsecond values show ns."""
        assert _humanize_ns(500) == "500ns"

    def test_microseconds(self) -> None:
        """Microsecond range shows us."""
        assert _humanize_ns(1500) == "1.5us"

    def test_milliseconds(self) -> None:
        """Millisecond range shows ms."""
        assert _humanize_ns(2_500_000) == "2.5ms"

    def test_seconds(self) -> None:
        """Second range shows s."""
        assert _humanize_ns(1_000_000_000) == "1.00s"
|
||||
|
||||
|
||||
class TestParsePythonVersion:
    """Tests for parse_python_version."""

    def test_valid(self) -> None:
        """Valid version is parsed to a tuple."""
        assert parse_python_version("3.12.9") == (3, 12, 9)

    def test_empty(self) -> None:
        """Empty version raises ValueError."""
        with pytest.raises(ValueError, match="required"):
            parse_python_version("")

    def test_wrong_format(self) -> None:
        """Missing parts raises ValueError."""
        with pytest.raises(ValueError, match="format"):
            parse_python_version("3.12")

    def test_too_old(self) -> None:
        """Python 3.8 is rejected."""
        with pytest.raises(ValueError, match="3.9"):
            parse_python_version("3.8.0")

    def test_python2(self) -> None:
        """Python 2 is rejected."""
        with pytest.raises(ValueError, match="Python 3"):
            parse_python_version("2.7.18")
|
||||
|
||||
|
||||
class TestValidateTraceId:
    """Tests for validate_trace_id."""

    def test_valid(self) -> None:
        """Valid UUIDv4 passes."""
        assert validate_trace_id(str(uuid.uuid4()))

    def test_invalid(self) -> None:
        """Invalid string fails."""
        assert not validate_trace_id("not-a-uuid")

    def test_exp_suffix(self) -> None:
        """EXP0 suffix is normalized and accepted."""
        suffixed = str(uuid.uuid4())[:-4] + "EXP0"
        assert validate_trace_id(suffixed)
|
||||
|
||||
|
||||
class TestCodeStrAndExplanation:
    """Tests for CodeStrAndExplanation."""

    def test_attributes(self) -> None:
        """Code and explanation are stored."""
        item = CodeStrAndExplanation(code="x = 1", explanation="optimized")
        assert item.code == "x = 1"
        assert item.explanation == "optimized"
|
||||
|
||||
|
||||
class TestBaseOptimizerContext:
    """Tests for BaseOptimizerContext."""

    def test_get_dynamic_context_single(self) -> None:
        """Plain source code returns SingleOptimizerContext."""
        context = BaseOptimizerContext.get_dynamic_context("system", "user", "x = 1\n")
        assert isinstance(context, SingleOptimizerContext)

    def test_get_dynamic_context_multi(self) -> None:
        """Multi-file markdown returns MultiOptimizerContext."""
        multi_source = "```python:a.py\nx = 1\n```\n```python:b.py\ny = 2\n```"
        context = BaseOptimizerContext.get_dynamic_context("system", "user", multi_source)
        assert isinstance(context, MultiOptimizerContext)

    def test_get_dynamic_context_single_file_multi_format(self) -> None:
        """Single file in multi format with NO_DIFF returns Single."""
        context = BaseOptimizerContext.get_dynamic_context(
            "system", "user", "```python:a.py\nx = 1\n```", DiffMethod.NO_DIFF
        )
        assert isinstance(context, SingleOptimizerContext)

    def test_is_valid_code_empty(self) -> None:
        """None extracted code is invalid."""
        context = BaseOptimizerContext("s", "u", "x = 1")
        assert not context.is_valid_code()

    def test_validate_and_parse_valid(self) -> None:
        """Valid Python passes validation."""
        context = BaseOptimizerContext("s", "u", "x = 1\n")
        context.validate_and_parse_source_code("x = 1\n", feature_version=(3, 12))

    def test_validate_and_parse_empty(self) -> None:
        """Empty body raises SyntaxError."""
        context = BaseOptimizerContext("s", "u", "# comment\n")
        with pytest.raises(SyntaxError, match="empty"):
            context.validate_and_parse_source_code("# comment\n", feature_version=(3, 12))
|
||||
|
||||
|
||||
class TestSingleOptimizerContext:
    """Tests for SingleOptimizerContext."""

    def test_get_system_prompt(self) -> None:
        """System prompt includes version and instructions."""
        context = SingleOptimizerContext(
            "{python_version_str} {critical_instructions}", "user", "x = 1\n"
        )
        prompt = context.get_system_prompt("3.12.9")
        assert "3.12.9" in prompt
        assert "FULL" in prompt

    def test_extract_code_from_llm_response(self) -> None:
        """Code block is extracted from LLM markdown response."""
        context = SingleOptimizerContext("s", "u", "x = 1\n")
        extracted = context.extract_code_and_explanation_from_llm_res(
            "explanation\n```python\ny = 2\n```"
        )
        assert extracted.code == "y = 2\n"
        assert extracted.explanation == "explanation"

    def test_parse_generates_response_item(self) -> None:
        """Valid extracted code produces a response item."""
        context = SingleOptimizerContext("s", "u", "x = 1\n")
        context.extract_code_and_explanation_from_llm_res(
            "faster\n```python\ny = 2\n```"
        )
        candidate = context.parse_and_generate_candidate_schema()
        assert candidate is not None
        assert candidate.source_code.strip() == "y = 2"

    def test_parse_returns_none_for_empty(self) -> None:
        """Empty extracted code returns None."""
        context = SingleOptimizerContext("s", "u", "x = 1\n")
        context.extracted_code_and_expl = CodeStrAndExplanation(code="", explanation="")
        assert context.parse_and_generate_candidate_schema() is None
|
||||
|
||||
|
||||
class TestMultiOptimizerContext:
    """Tests for MultiOptimizerContext."""

    def test_get_system_prompt_no_diff(self) -> None:
        """NO_DIFF uses markdown format prompt."""
        multi_source = "```python:a.py\nx = 1\n```\n```python:b.py\ny = 2\n```"
        context = MultiOptimizerContext(
            "{python_version_str} {critical_instructions}",
            "user",
            multi_source,
            diff_method=DiffMethod.NO_DIFF,
        )
        prompt = context.get_system_prompt("3.12.9")
        assert "3.12.9" in prompt
        assert "Code Input format" in prompt

    def test_is_valid_code_rejects_new_files(self) -> None:
        """Extracted files not in original are rejected."""
        context = MultiOptimizerContext(
            "s",
            "u",
            "```python:a.py\nx = 1\n```",
            diff_method=DiffMethod.NO_DIFF,
        )
        context.extracted_code_and_expl = CodeStrAndExplanation(
            code="```python:new_file.py\nz = 3\n```",
            explanation="added",
        )
        assert not context.is_valid_code()
|
||||
|
||||
|
||||
class TestGetModelDistribution:
    """Tests for get_model_distribution."""

    def test_default_five(self) -> None:
        """5 candidates: 3 GPT + 2 Claude."""
        distribution = get_model_distribution(5, MAX_OPTIMIZER_CALLS)
        assert len(distribution) == 2
        (_gpt_model, gpt_calls), (_claude_model, claude_calls) = distribution
        assert gpt_calls == 3
        assert claude_calls == 2

    def test_capped(self) -> None:
        """Candidates are capped at max_calls."""
        distribution = get_model_distribution(100, MAX_OPTIMIZER_CALLS)
        assert sum(calls for _, calls in distribution) == MAX_OPTIMIZER_CALLS

    def test_single(self) -> None:
        """Single candidate: 1 GPT + 0 Claude."""
        distribution = get_model_distribution(1, MAX_OPTIMIZER_CALLS)
        assert sum(calls for _, calls in distribution) == 1
|
||||
|
||||
|
||||
class TestOptimizeSchemas:
    """Tests for Pydantic optimize schemas."""

    def test_request_defaults(self) -> None:
        """Default fields are set correctly."""
        request = OptimizeRequest(source_code="x = 1", trace_id=str(uuid.uuid4()))
        assert request.n_candidates == 5
        assert request.language == "python"
        assert request.is_async is False

    def test_request_version_resolution(self) -> None:
        """language_version backfills python_version."""
        request = OptimizeRequest(
            source_code="x = 1",
            trace_id=str(uuid.uuid4()),
            language_version="3.12.9",
        )
        assert request.python_version == "3.12.9"

    def test_response_item(self) -> None:
        """Response item is serializable."""
        dumped = OptimizeResponseItem(
            source_code="x = 2",
            explanation="faster",
            optimization_id="abc",
        ).model_dump()
        assert dumped["source_code"] == "x = 2"
        assert dumped["parent_id"] is None

    def test_response(self) -> None:
        """Response wraps items."""
        response = OptimizeResponse(
            optimizations=[
                OptimizeResponseItem(
                    source_code="x = 2",
                    explanation="faster",
                    optimization_id="abc",
                )
            ]
        )
        assert len(response.optimizations) == 1

    def test_error_response(self) -> None:
        """Error response has error field."""
        assert OptimizeErrorResponse(error="bad").error == "bad"

    def test_candidate_source_enum(self) -> None:
        """Enum values are accessible."""
        assert OptimizedCandidateSource.OPTIMIZE.value == "OPTIMIZE"
|
||||
|
|
@ -103,6 +103,15 @@ ignore = [
|
|||
"PLC0415", # conditional imports for event loop safety (clients recreated on loop change)
|
||||
"TRY301", # raise inside try is the intended pattern for cost-tracking on unsupported model type
|
||||
]
|
||||
"packages/codeflash-api/src/codeflash_api/optimize/_context.py" = [
|
||||
"PLR2004", # magic values in faithfully ported version parsing and humanize_ns
|
||||
]
|
||||
"packages/codeflash-api/src/codeflash_api/optimize/_pipeline.py" = [
|
||||
"PLR0913", # faithfully ported signatures require many args for LLM call orchestration
|
||||
]
|
||||
"packages/codeflash-api/src/codeflash_api/_app.py" = [
|
||||
"PLC0415", # local import for router registration avoids circular imports
|
||||
]
|
||||
"packages/codeflash-core/src/codeflash_core/_model.py" = [
|
||||
"C901", # humanize_runtime is complex but faithfully ported
|
||||
"PLR2004", # magic values in humanize_runtime thresholds
|
||||
|
|
|
|||
Loading…
Reference in a new issue