Add optimize endpoint: context, pipeline, router, prompt templates

Faithful port of the Python optimization pipeline from Django aiservice:
- schemas.py: Pydantic request/response models (OptimizeRequest, OptimizeResponse)
- _markdown.py: markdown code block extraction, splitting, grouping
- _context.py: BaseOptimizerContext with Single/Multi variants for
  prompt assembly, LLM response extraction, and postprocessing
- _pipeline.py: parallel LLM orchestration with model distribution
  (GPT-5-mini + Claude Sonnet 4.5), diversity via line profiler toggling
- _router.py: POST /ai/optimize with auth, rate limiting, usage tracking
- 11 prompt templates copied verbatim from Django reference
- LLM client wired into app lifespan
This commit is contained in:
Kevin Turcios 2026-04-21 22:16:22 -05:00
parent 3e62f502e7
commit 6c04324e25
19 changed files with 2312 additions and 0 deletions

View file

@ -12,6 +12,7 @@ from sentry_sdk.integrations.fastapi import FastApiIntegration
from codeflash_api._config import settings
from codeflash_api.db._engine import create_pool
from codeflash_api.db._queries import Queries
from codeflash_api.llm._client import LLMClient
if TYPE_CHECKING:
from collections.abc import AsyncIterator
@ -28,6 +29,7 @@ async def _lifespan(app: FastAPI) -> AsyncIterator[dict[str, object]]:
app.state.queries = queries
app.state.rate_limit_cache = {}
app.state.llm_client = LLMClient()
state: dict[str, object] = {}
yield state
@ -65,6 +67,10 @@ def create_app() -> FastAPI:
def _register_routes(app: FastAPI) -> None:
from codeflash_api.optimize._router import router as optimize_router
app.include_router(optimize_router)
@app.get("/healthcheck")
async def healthcheck() -> dict[str, str]:
return {"status": "ok"}

View file

@ -0,0 +1,135 @@
from __future__ import annotations
import re
# Every ```python fenced block in a document; an optional ":path" suffix may
# follow the language tag.  DOTALL lets block bodies span newlines.
MARKDOWN_CODE_BLOCK_PATTERN = re.compile(
    r"```python(?::[^\n]*)?\n(.*?)```", re.DOTALL
)
# Strict match anchored at the first opening fence; requires a closing fence
# on its own line (trailing spaces/tabs tolerated).  NOTE: the body group is
# greedy, so it runs to the LAST closing fence in the text.
FIRST_CODE_BLOCK_PATTERN = re.compile(
    r"^```python(?::[^\n]*)?\s*\n(.*)\n```[ \t]*$",
    re.MULTILINE | re.DOTALL,
)
# Lenient fallback for truncated output: an opening fence with no closing
# fence captures everything to the end of the text.
FIRST_CODE_BLOCK_FALLBACK_PATTERN = re.compile(
    r"^```python(?::[^\n]*)?\s*\n(.*)",
    re.MULTILINE | re.DOTALL,
)
# A literal backslash followed by a long run (10+) of quote characters marks
# degenerate, runaway LLM output.
_PATHOLOGICAL_PATTERN = re.compile(r"(\\['\"]{10,})")


def truncate_pathological_output(code: str) -> str:
    """Cut *code* short at the first pathological escaped-quote run.

    Degenerate LLM completions can devolve into an endless stream of escaped
    quotes; everything from the first such run onward is dropped and trailing
    whitespace is stripped.  Returns *code* unchanged when no run is found.
    """
    found = _PATHOLOGICAL_PATTERN.search(code)
    if found is None:
        return code
    return code[: found.start()].rstrip()
def extract_all_code_from_markdown(markdown: str) -> str:
    """Join the bodies of every ```python fenced block in *markdown*.

    Non-code prose is discarded; block bodies are separated by one blank
    line.  Returns an empty string when there are no code blocks.
    """
    code_blocks = MARKDOWN_CODE_BLOCK_PATTERN.findall(markdown)
    return "\n\n".join(code_blocks)
def extract_code_block(markdown: str) -> str | None:
    """Return the body of the first ```python block, or None if absent.

    The strict pattern requires a closing fence; the fallback tolerates a
    truncated response missing one.  Either result is cleaned with
    truncate_pathological_output before being returned.
    """
    for pattern in (
        FIRST_CODE_BLOCK_PATTERN,
        FIRST_CODE_BLOCK_FALLBACK_PATTERN,
    ):
        found = pattern.search(markdown)
        if found is not None:
            return truncate_pathological_output(found.group(1))
    return None
def split_markdown_code(
    markdown: str, language: str = "python"
) -> dict[str, str]:
    """Split multi-file markdown into a {file_path: code} mapping.

    Fenced blocks must carry a ":<path>" suffix after the language tag.
    JavaScript and TypeScript accept either the full name or the short
    alias.  When the same path appears more than once, the first
    occurrence wins; mapping order follows appearance order.
    """
    aliases = {
        "javascript": r"(?:javascript|js)",
        "js": r"(?:javascript|js)",
        "typescript": r"(?:typescript|ts)",
        "ts": r"(?:typescript|ts)",
    }
    lang_pattern = aliases.get(language, re.escape(language))
    block_re = re.compile(
        rf"```{lang_pattern}:([^\n]+)\n(.*?)\n```", re.DOTALL
    )
    result: dict[str, str] = {}
    for raw_path, body in block_re.findall(markdown):
        result.setdefault(raw_path.strip(), body)
    return result
def extract_code_block_with_context(
text: str, language: str = "python"
) -> tuple[str, str, str] | None:
"""
Extract a code block and its surrounding context.
"""
pattern_with_path = (
rf"(.*?)```{language}:[^\n]+(?:\n|\\n)(.*?)```(.*)"
)
if match := re.match(
pattern_with_path, text, re.DOTALL | re.MULTILINE
):
return (
match.group(1).strip(),
match.group(2),
match.group(3).strip(),
)
pattern = (
rf"(.*?)```{language}(?::[^\n]*)?(?:\n|\\n)(.*?)```(.*)"
)
if match := re.match(pattern, text, re.DOTALL | re.MULTILINE):
return (
match.group(1).strip(),
match.group(2),
match.group(3).strip(),
)
return None
def wrap_code_in_markdown(
    code: str, language: str = "python"
) -> str:
    """Fence *code* inside a markdown block tagged with *language*."""
    opening = "```" + language + "\n"
    return opening + code + "\n```"
def group_code(
    file_to_code: dict[str, str], language: str = "python"
) -> str:
    """Render a {file_path: code} mapping as path-tagged markdown blocks.

    Each file becomes a ```<language>:<path> block; a trailing newline is
    appended to any body that lacks one so the closing fence sits on its
    own line.  Blocks are joined by single newlines in mapping order.
    """
    blocks = []
    for file_path, code in file_to_code.items():
        if not code.endswith("\n"):
            code += "\n"
        blocks.append(f"```{language}:{file_path}\n{code}```")
    return "\n".join(blocks)
def is_multi_context(code: str) -> bool:
    """Return True when *code* looks like multi-file markdown input.

    Multi-file input always opens with a path-tagged ```python: fence once
    leading whitespace is ignored.
    """
    return code.lstrip().startswith("```python:")

View file

@ -0,0 +1,63 @@
You are a professional computer programmer who specializes in writing high-performance **asynchronous** Python code. Your goal is to optimize the runtime and memory efficiency of the provided **async** code through safe and meaningful rewrites that would pass senior-level code review.
**CRITICAL: ASYNC CODE REQUIREMENTS**
- The code contains **async functions** that must remain async
- ALL async functions must maintain their `async def` signature
- ALL `await` expressions must be preserved where they exist
- Do NOT convert async functions to synchronous functions
- Do NOT remove `await` keywords unless replacing with functionally equivalent async operations
- Preserve concurrency patterns and async context manager usage
- Maintain proper async/await flow and error handling in async contexts
**Behavioral Preservation (CRITICAL)**
- Do NOT rename functions or change their signatures.
- You MUST NOT change the behavior, return values, side effects, printed/logged output, or raised exceptions - they MUST remain exactly the same.
- Do NOT mutate inputs in a different way than the original implementation.
- The same exception types should be raised in the same circumstances.
- Preserve existing type annotations - all function parameters, return types, and variable annotations must be preserved exactly as written.
- **Preserve the original code style**: Keep existing variable names unless the logic fundamentally changes
- Preserve ALL existing comments exactly as written, unless the corresponding code logic is changed or the comment becomes factually incorrect
- Avoid excessive inline comments - only add new comments for significant or non-obvious logic changes
**Async-Specific Optimization Focus**
- Optimize async patterns such as concurrent execution with `asyncio.gather()` or `asyncio.create_task()`
- Consider using `asyncio.as_completed()` for better performance when appropriate
- Identify and replace blocking operations with async equivalents (e.g., `time.sleep()` → `asyncio.sleep()`, sync file I/O → `aiofiles`, blocking network calls → async libraries)
- Optimize async context managers and async iterators
- Improve async I/O operations and resource management
- Consider async comprehensions where they provide performance benefits
- Use `asyncio.to_thread()` for CPU-intensive tasks that would block the event loop
- Maintain proper async exception handling and cleanup
**Code Style & Structure**
- Do NOT replace walrus operators (`:=`) for optimization purposes.
- Keep `assert` statements as-is - do NOT convert them to `if/raise AssertionError` patterns, it doesn't improve the performance.
- **DO NOT convert `isinstance()` checks to `type()` checks**. `isinstance()` correctly handles inheritance and subclasses, while `type()` checks are incorrect for subclass instances and represent a micro-optimization that should be avoided.
- You may write new async helper functions that do not already exist in the codebase.
- Avoid purely stylistic changes unless they result in noticeable performance improvements
- Ensure all new async code follows proper async patterns and conventions
**Optimization Focus**
- Create production-ready async code that professional programmers would merge without further edits
- Prioritize changes that provide measurable runtime or memory efficiency gains in async contexts
- Consider async-specific performance patterns like batching operations or reducing context switching
**Code Quality Standards**
- Ensure all async optimizations are safe and would pass senior-level code review
- Maintain code readability and maintainability alongside performance improvements
- Verify that async operations are properly awaited and handled
**Response Format (REQUIRED)**
- ALWAYS start your response with a brief explanation (2-4 sentences) of what optimization you made and why it improves performance
- Then provide the optimized code in a markdown code block
- Example format:
```
**Optimization Explanation:**
[Your explanation here describing the optimization technique and expected performance improvement]
```python:filename.py
[optimized code]
```
```
The current Python version is {python_version_str}

View file

@ -0,0 +1,19 @@
Rewrite this **asynchronous** Python program to run faster while preserving all async behavior.
**CRITICAL ASYNC REQUIREMENTS:**
- The code contains **async functions** - you MUST keep them async
- ALL `async def` function signatures must be preserved exactly
- ALL `await` expressions must be maintained (unless replaced with functionally equivalent async operations)
- Do NOT convert async functions to synchronous functions
- Preserve concurrent execution patterns and async context managers
- Maintain proper async/await flow and exception handling
**Async Optimization Guidelines:**
- Consider using `asyncio.gather()` for concurrent execution when beneficial
- Replace blocking operations with async equivalents (e.g., `time.sleep()` with `asyncio.sleep()`, sync I/O with async libraries)
- Optimize async I/O operations and batching where appropriate
- Use `asyncio.to_thread()` for CPU-intensive tasks to avoid blocking the event loop
- Optimize async context managers and async iterators
- Maintain async exception handling and resource cleanup
{source_code}

View file

@ -0,0 +1,5 @@
Here are some READ ONLY Python module code snippets for understanding internal codebase dependencies of the Python program you will be optimizing, provided for your reference.
Do not edit these Python modules, treat them as *read-only*.
{dependency_code}

View file

@ -0,0 +1 @@
When making modifications to the `__init__` function, do not change the values and behavior of any existing instance attributes. You may only add new attributes.

View file

@ -0,0 +1,146 @@
# JIT Compilation Optimization Guidelines
## Decision Logic (Framework Router)
Before rewriting, analyze the imports and data structures to select the correct backend:
- IF the code uses `torch.nn` or `torch.Tensor` → Use **PyTorch**.
- IF the code uses `tensorflow`, `keras`, or `tf.Tensor` → Use **TensorFlow**.
- IF the code uses `jax.numpy` or `jax.grad` → Use **JAX**.
- IF the code uses standard `numpy` arrays and Python loops → Use **Numba**.
- IF the code relies heavily on un-compilable Python features (pandas, dynamic dicts, strings) → **Return Original Code**.
## When NOT to Use JIT Compilation
Do NOT generate JIT-compiled candidates in the following scenarios. Instead, focus on other optimization strategies:
### 1. I/O-Bound Functions
Functions that primarily perform file operations, network requests, database queries, or other I/O operations will not benefit from JIT compilation. The bottleneck is external latency, not CPU computation.
### 2. Functions Called Infrequently
JIT compilation has an upfront compilation cost. If a function is only called once or very rarely, the compilation overhead will exceed any runtime savings. JIT is best for hot paths called many times.
### 3. Heavy Use of Unsupported Python Features
Numba's nopython mode does not support:
- Dictionary comprehensions or complex dict operations
- String manipulation (except basic operations)
- List comprehensions with complex logic
- Most Python standard library functions
- Custom Python classes and objects
- Exception handling with complex try/except blocks
- Generators and iterators (limited support)
- Recursion with varying depths (can cause issues)
If the function relies heavily on these features, JIT compilation will either fail or fall back to slow object mode.
### 4. Functions with Dynamic or Heterogeneous Types
JIT compilers optimize for specific type signatures. Functions that receive different types on each call will trigger recompilation, negating performance benefits.
### 5. Small/Trivial Functions
Functions with minimal computation (e.g., simple getters, one-line calculations, basic conditionals) have negligible runtime. JIT overhead is not justified.
### 6. Code Using Incompatible Libraries
Many Python libraries are not JIT-compatible:
- pandas operations (use vectorized pandas methods instead)
- Most third-party libraries
- Complex numpy operations with object dtypes
- Symbolic math libraries (sympy, etc.)
### 7. Functions with Heavy Object Creation
Code that creates many Python objects, uses class instances extensively, or relies on Python's dynamic nature will not JIT well.
## JIT Output Format (CRITICAL)
When you determine that JIT compilation is viable for the given code:
- Apply the JIT decorator **directly** to the output function (e.g., `@numba.njit`, `@torch.compile`, `@tf.function`, `@jax.jit`).
- Add the necessary import (e.g., `import numba`) at the top of the code.
- Do **NOT** create conditional fallback paths. Never wrap JIT usage in `try/except ImportError`, `if HAS_NUMBA`, or any similar if/else branching that falls back to a non-JIT version.
- Do **NOT** create a separate "fast" helper function alongside the original. The output must be a single, clean function with the JIT decorator applied.
- If JIT is not viable for this code, optimize it using other strategies without any JIT decorators — do not include a JIT path at all.
## Guidelines for Numba (`@njit`)
Use Numba when the code:
- Performs numerical computations with NumPy arrays
- Contains loops over array elements
- Uses basic NumPy operations (arithmetic, indexing, slicing)
- Does not use unsupported Python features (classes, dictionaries with non-scalar keys, etc.)
### Numba Best Practices:
- Add `@numba.njit` or `@numba.njit(cache=True)` decorator
- Use `numba.prange` instead of `range` for parallel loops when appropriate
- Avoid Python objects inside JIT functions (use primitive types and arrays)
- Pre-allocate output arrays instead of using list.append()
- Do not use `fastmath=True`
- Ensure all array dtypes are consistent
### Numba Limitations to Avoid:
- No dictionary comprehensions or set operations
- No string operations
- No class instances (use structured arrays instead)
- Limited support for keyword arguments in nested calls
## Guidelines for PyTorch (`torch.compile`)
Use torch.compile when the code:
- Defines or uses PyTorch neural network modules (nn.Module)
- Performs tensor operations with torch tensors
- Contains forward passes or training loops
### torch.compile Best Practices:
- Apply `@torch.compile()` to the model or specific functions
- Use `mode="reduce-overhead"` for small models with many calls
- Use `mode="max-autotune"` for maximum performance when compile time is acceptable
- Consider `fullgraph=True` to ensure the entire function is compiled
- Use `dynamic=True` if input shapes vary
### torch.compile Limitations to Avoid:
- No Python print statements inside compiled functions (causes graph breaks)
- Avoid data-dependent control flow (if statements based on tensor values)
- No unsupported Python built-ins (input(), eval(), exec())
- Avoid in-place operations on input tensors when using certain backends
- No operations with data-dependent output shapes (use `dynamic=True` if needed)
- Avoid calling non-PyTorch functions that can't be traced (use `torch.compiler.disable` to skip them)
- No Python generators or iterators over tensors
- Avoid global variable modifications inside compiled functions
- Limited support for custom autograd functions (may cause graph breaks)
- No third-party libraries that aren't torch-compatible inside the compiled region
## Guidelines for TensorFlow (`tf.function`)
Use tf.function when the code:
- Uses TensorFlow tensors and operations
- Defines or uses Keras models (tf.keras.Model)
- Contains training loops or inference pipelines
- Performs tensor computations with tf.* operations
### tf.function Best Practices:
- Add `@tf.function` decorator to functions performing tensor operations
- Use `jit_compile=True` for XLA compilation: `@tf.function(jit_compile=True)`
- Specify `input_signature` for fixed input shapes to avoid retracing
- Use `experimental_relax_shapes=True` if input shapes vary slightly
- Avoid Python side effects inside tf.function (print, list.append, etc.)
- Use tf.print() instead of print() for debugging inside traced functions
### tf.function Limitations to Avoid:
- No Python control flow that depends on tensor values (use tf.cond, tf.while_loop instead)
- Avoid creating variables inside tf.function (create them outside or in __init__)
- No unsupported Python operations (file I/O, random.random(), etc.)
- Minimize Python object creation inside the function
## Guidelines for JAX (`jax.jit`)
Use jax.jit when the code:
- Uses JAX arrays (jax.numpy operations)
- Performs functional numerical computations
- Contains pure functions without side effects
- Uses JAX transformations (grad, vmap, pmap)
### jax.jit Best Practices:
- Add `@jax.jit` decorator to pure functions
- Use `static_argnums` for arguments that should not be traced (e.g., axis parameters, shapes)
- Use `static_argnames` for keyword arguments that should be static
- Prefer `jax.numpy` over `numpy` for array operations inside jitted functions
- Use `donate_argnums` to allow JAX to reuse input buffers for outputs
- Combine with other JAX transforms: `jax.jit(jax.vmap(fn))` for batched operations
### jax.jit Limitations to Avoid:
- No in-place mutations (JAX arrays are immutable)
- No Python side effects (print, file I/O, global variable modifications)
- No data-dependent control flow with dynamic shapes (use jax.lax.cond, jax.lax.fori_loop instead)
- Avoid Python loops over traced values (use jax.lax.scan or jax.vmap)
- No NumPy arrays inside jitted functions (convert to JAX arrays first)

View file

@ -0,0 +1,2 @@
Here are the results of the line profiling of the Python program you will be optimizing, provided for your reference.
{line_profiler_results}

View file

@ -0,0 +1,35 @@
### Code Input format
You will receive Markdown code blocks containing multiple Python file segments. Each file is represented as a separate Markdown code block with the format ```python:<path_of_the_file>.
Example:
```python:main/app1.py
def say_hello_from_app1():
print("Hello")
```
```python:main/app2.py
def say_hello_from_app2():
print("Hello")
```
**IMPORTANT RULES: (CRITICAL)**
- Each Markdown code block represents a **separate Python module**
- The **first file** contains the function to optimize. The remaining files are **context only** — they show helper functions and dependencies that the target function uses. Do NOT optimize or return the context files unless you actually modified them.
- You MUST preserve all file paths in their exact original format and maintain the same order
- You are NOT allowed to create new files that did not exist in the provided code
- Treat imports and dependencies between files appropriately (e.g., if app2.py imports from app1.py)
- Maintain the exact same order of files as provided in the input
- Each file should remain as a cohesive unit.
- Only return files you actually modified — you do NOT need to return unchanged context files
### Explanation Style:
Keep explanations **developer-focused and concise**. Focus on:
- **What** specific optimizations were applied.
- **Performance impact** when significant (e.g., "Reduced time complexity from O(n²) to O(n)")
- **Key changes** that affect behavior or dependencies
- Avoid mentioning obvious preservation details (file structure, imports, signatures) unless they were specifically modified
Your response **MUST** contain:
1. A **brief, technical explanation** of optimizations applied
2. **One or more Markdown code block(s)** for the files you modified, preserving their original file paths
**Any deviation from this format is incorrect. {EXPLANATION_THEN_CODE}**

View file

@ -0,0 +1,2 @@
The current measured runtime of this function is {baseline_runtime_ns} nanoseconds ({baseline_runtime_human}) over {loop_count} benchmark loops. Focus your optimization on changes that will meaningfully reduce this runtime.

View file

@ -0,0 +1,43 @@
You are a professional computer programmer who specializes in writing high-performance Python code. Your goal is to optimize the runtime and memory efficiency of the provided code through safe and meaningful rewrites that would pass senior-level code review.
**Behavioral Preservation (CRITICAL)**
- Do NOT rename functions or change their signatures.
- You MUST NOT change the behavior, return values, side effects, printed/logged output, or raised exceptions - they MUST remain exactly the same.
- Do NOT mutate inputs in a different way than the original implementation.
- The same exception types should be raised in the same circumstances.
- Preserve existing type annotations - all function parameters, return types, and variable annotations must be preserved exactly as written.
- **Preserve the original code style**: Keep existing variable names unless the logic fundamentally changes
- Preserve ALL existing comments exactly as written, unless the corresponding code logic is changed or the comment becomes factually incorrect
- Avoid excessive inline comments - only add new comments for significant or non-obvious logic changes
{critical_instructions}
**Code Style & Structure**
- Do NOT replace walrus operators (`:=`) for optimization purposes.
- DO NOT introduce attribute lookup optimizations. The performance improvements are minimal and come at a substantial cost to readability.
- Keep `assert` statements as-is - do NOT convert them to `if/raise AssertionError` patterns, it doesn't improve the performance.
- **DO NOT convert `isinstance()` checks to `type()` checks**. `isinstance()` correctly handles inheritance and subclasses, while `type()` checks are incorrect for subclass instances and represent a micro-optimization that should be avoided.
- You may write new helper functions that do not already exist in the codebase.
- Avoid purely stylistic changes unless they result in noticeable performance improvements
**Optimization Focus**
- Create production-ready code that professional programmers would merge without further edits
- Prioritize changes that provide measurable runtime or memory efficiency gains
**Code Quality Standards**
- Ensure all optimizations are safe and would pass senior-level code review
- Maintain code readability and maintainability alongside performance improvements
**Response Format (REQUIRED)**
- ALWAYS start your response with a brief explanation (2-4 sentences) of what optimization you made and why it improves performance
- Then provide the optimized code in a markdown code block
- Example format:
```
**Optimization Explanation:**
[Your explanation here describing the optimization technique and expected performance improvement]
```python:filename.py
[optimized code]
```
```
The current Python version is {python_version_str}

View file

@ -0,0 +1,6 @@
Here are tests that define the expected behavior of the function you are optimizing. Your optimization MUST produce identical results for all these test cases.
Pay special attention to hand-written unit tests — they encode the developer's explicit behavioral expectations and edge cases. Any optimization that changes the output for these inputs is incorrect.
{test_input_examples}

View file

@ -0,0 +1,3 @@
Rewrite this Python program to run faster.
{source_code}

View file

@ -0,0 +1,775 @@
from __future__ import annotations
import ast
import logging
import re
import uuid
from pathlib import Path
from typing import TYPE_CHECKING
import libcst as cst
from codeflash_api.diff._base import DiffMethod
from codeflash_api.diff._search_replace import SearchAndReplaceDiff
from codeflash_api.diff._v4a import V4ADiff
from codeflash_api.languages.python._cst_utils import (
find_init,
parse_module_to_cst,
)
from codeflash_api.languages.python._markdown import (
extract_code_block_with_context,
group_code,
is_multi_context,
split_markdown_code,
wrap_code_in_markdown,
)
from codeflash_api.languages.python._postprocess import (
OptimizationCandidate,
optimizations_postprocessing_pipeline,
)
from codeflash_api.optimize.schemas import OptimizeResponseItem
if TYPE_CHECKING:
from codeflash_api.diff._base import Diff
log = logging.getLogger(__name__)

# Prompt templates live alongside the Python language-support package.
_PROMPTS_DIR = Path(__file__).parent.parent / "languages" / "python" / "prompts"

# Appended to single-file system prompts so responses lead with prose.
EXPLANATION_THEN_CODE = (
    "Begin your response with a short explanation of the changes"
    " (without any title or heading like 'Explanation' or 'Changes'),"
    " then output the code."
)
# Injected as {critical_instructions} when full-file (non-diff) output is
# wanted, forbidding placeholder elisions.
FULL_CODE_PROMPT_INSTRUCTIONS = (
    "- Always provide the FULL, updated content of the artifact."
    " This means:\n"
    " - Include ALL updated code, even if some parts are"
    " unchanged\n"
    ' - NEVER use placeholders like "// rest of the code'
    ' remains the same..." or "<- leave original code here ->"\n'
)
# Multi-file code-format instructions; {EXPLANATION_THEN_CODE} is resolved at
# import time, the remaining templates keep their {placeholders} and are
# formatted per request.
MARKDOWN_CONTEXT_PROMPT = (
    (_PROMPTS_DIR / "multi_file_code_format.md")
    .read_text()
    .format(EXPLANATION_THEN_CODE=EXPLANATION_THEN_CODE)
)
DEPS_CONTEXT_PROMPT = (
    (_PROMPTS_DIR / "dependency_context_prompt.md").read_text()
)
INIT_OPTIMIZATION_PROMPT = (
    (_PROMPTS_DIR / "init_optimization_prompt.md").read_text()
)
LINE_PROF_CONTEXT_PROMPT = (
    (_PROMPTS_DIR / "lineprof_context_prompt.md").read_text()
)
RUNTIME_CONTEXT_PROMPT = (
    (_PROMPTS_DIR / "runtime_context_prompt.md").read_text()
)
TEST_EXAMPLES_PROMPT = (
    (_PROMPTS_DIR / "test_examples_prompt.md").read_text()
)
# Diff-output instructions for the V4A patch format (inline, not file-backed).
V4A_DIFF_FORMAT_PROMPT = """Describe the changes using the V4A diff format, enclosed within `*** Begin Patch` and `*** End Patch` markers.
# V4A Diff Format Rules:
Your entire response containing the patch MUST start with `*** Begin Patch` on a line by itself.
Your entire response containing the patch MUST end with `*** End Patch` on a line by itself.
Use the *FULL* file path, as shown to you by the user.
For each file you need to modify, start with a marker line:
*** Update File: [path/to/file]
You ONLY update existing files, do not add or remove files.
Each file MUST appear only once in the patch.
Consolidate all changes for that file into the same block.
For `Update` actions, describe each snippet of code that needs to be changed using the following format:
1. Context lines: Include 3 lines of context *before* the change. These lines MUST start with a single space ` `.
2. Lines to remove: Precede each line to be removed with a minus sign `-`.
3. Lines to add: Precede each line to be added with a plus sign `+`.
4. Context lines: Include 3 lines of context *after* the change. These lines MUST start with a single space ` `.
Context lines MUST exactly match the existing file content, character for character, including indentation.
If a change is near the beginning or end of the file, include fewer than 3 context lines as appropriate.
If 3 lines of context is insufficient to uniquely identify the snippet, use `@@ [CLASS_OR_FUNCTION_NAME]` markers on their own lines *before* the context lines to specify the scope.
Do not include line numbers.
Only create patches for files that the user has provided.
ONLY EVER RETURN CODE IN THE SPECIFIED V4A DIFF FORMAT, followed by a short explanation of the changes.
"""
# Diff-output instructions for xml-tagged SEARCH/REPLACE blocks.
SEARCH_AND_REPLACE_FORMAT_PROMPT = """Describe the changes using the search/replace diff format with xml-style tags, like:
<replace_in_file>
<path>src/main.py</path>
<diff>
<<<<<<< SEARCH
a = 2
=======
a = 3
>>>>>>> REPLACE
</diff>
</replace_in_file>
Description: Request to replace sections of content in an existing file using SEARCH/REPLACE blocks that define exact changes to specific parts of the file.
Parameters:
- path: (required) The path of the file to modify
- diff: (required) One or more SEARCH/REPLACE blocks following this exact format:
```
<<<<<<< SEARCH
[exact content to find]
=======
[new content to replace with]
>>>>>>> REPLACE
```
Critical rules:
1. SEARCH content must match the associated file section to find EXACTLY.
2. SEARCH/REPLACE blocks will ONLY replace the first match occurrence.
3. Keep SEARCH/REPLACE blocks concise.
4. Special operations:
* To move code: Use two SEARCH/REPLACE blocks
* To delete code: Use empty REPLACE section
"""
def _humanize_ns(ns: int) -> str:
"""
Convert nanoseconds to a human-readable string.
"""
if ns < 1_000:
return f"{ns}ns"
if ns < 1_000_000:
return f"{ns / 1_000:.1f}us"
if ns < 1_000_000_000:
return f"{ns / 1_000_000:.1f}ms"
return f"{ns / 1_000_000_000:.2f}s"
def parse_python_version(version: str | None) -> tuple[int, int, int]:
"""
Parse a version string like '3.12.9' into a tuple.
"""
if not version:
msg = "Python version is required"
raise ValueError(msg)
if len(version) > 30:
msg = "Invalid version format"
raise ValueError(msg)
parts = version.split(".")
if len(parts) != 3:
msg = "Invalid version format"
raise ValueError(msg)
patch_match = re.match(r"\d+", parts[2])
if not patch_match:
msg = "Invalid patch version"
raise ValueError(msg)
major = int(parts[0])
minor = int(parts[1])
patch = int(patch_match.group())
if major != 3:
msg = "Only Python 3 is supported"
raise ValueError(msg)
if minor < 9 or minor > 15:
msg = "Only Python 3.9 and above is supported"
raise ValueError(msg)
if patch < 0 or patch >= 100:
msg = "Invalid version format"
raise ValueError(msg)
return (major, minor, patch)
def validate_trace_id(trace_id: str) -> bool:
    """Return True when *trace_id* is a canonical UUIDv4 string.

    Experiment-tagged ids ending in 'EXP0'/'EXP1' are validated with the
    tag replaced by '0000', so the remainder must still form a valid
    UUIDv4.
    """
    candidate = trace_id
    # The experiment suffix occupies the last four characters; neutralize it.
    if candidate.endswith(("EXP0", "EXP1")):
        candidate = candidate[:-4] + "0000"
    try:
        parsed = uuid.UUID(candidate, version=4)
    except (ValueError, AttributeError):
        return False
    # Round-trip equality rejects non-canonical spellings (uppercase,
    # braces, urn: prefix) that uuid.UUID would otherwise accept.
    return str(parsed) == candidate
class CodeStrAndExplanation:
    """Extracted code and its accompanying explanation from an LLM response.

    Attributes:
        code: The extracted source code.
        explanation: The prose explanation that accompanied the code.
    """

    __slots__ = ("code", "explanation")

    def __init__(self, code: str, explanation: str) -> None:
        self.code = code
        self.explanation = explanation

    def __repr__(self) -> str:
        # Truncate both fields: code bodies can run to thousands of
        # characters, which would make pipeline logs unreadable.
        return (
            f"{type(self).__name__}(code={self.code[:40]!r}..., "
            f"explanation={self.explanation[:40]!r}...)"
        )
class BaseOptimizerContext:
    """Shared state and hooks for one optimization-candidate attempt.

    Holds the base prompts and source code, plus the code/explanation pair
    extracted from the LLM response.  Subclasses implement response parsing
    and post-processing for single-file and multi-file inputs respectively.
    """

    def __init__(
        self,
        base_system_prompt: str,
        base_user_prompt: str,
        source_code: str,
    ) -> None:
        self.base_system_prompt = base_system_prompt
        self.base_user_prompt = base_user_prompt
        self.source_code = source_code
        # Populated once the LLM response has been parsed.
        self.extracted_code_and_expl: CodeStrAndExplanation | None = None
        # Keyed snapshots taken before the postprocessing pipeline runs.
        self.code_and_explanation_before_post_processing: dict[
            str, CodeStrAndExplanation
        ] = {}

    @staticmethod
    def get_dynamic_context(
        system_prompt: str,
        user_prompt: str,
        source_code: str,
        diff_method: DiffMethod = DiffMethod.NO_DIFF,
    ) -> BaseOptimizerContext:
        """Choose the Single or Multi context variant for *source_code*.

        Multi-file markdown that actually contains a single file (and needs
        no diff output) is unwrapped into a single-file context so the model
        sees plain code.
        """
        if not is_multi_context(source_code):
            return SingleOptimizerContext(
                system_prompt, user_prompt, source_code
            )
        file_to_code = split_markdown_code(source_code)
        if len(file_to_code) == 1 and diff_method == DiffMethod.NO_DIFF:
            (file_name,) = file_to_code
            return SingleOptimizerContext(
                system_prompt,
                user_prompt,
                source_code=file_to_code[file_name],
                file_name=file_name,
            )
        return MultiOptimizerContext(
            system_prompt,
            user_prompt,
            source_code,
            diff_method=diff_method,
        )

    def get_system_prompt(
        self, python_version_str: str
    ) -> str:
        """Return the system prompt; subclasses inject format arguments."""
        return self.base_system_prompt

    def get_user_prompt(
        self,
        dependency_code: str,
        line_profiler_results: str | None,
        *,
        baseline_runtime_ns: int | None = None,
        loop_count: int | None = None,
        test_input_examples: str | None = None,
    ) -> str:
        """Return the user prompt; subclasses append context sections."""
        return self.base_user_prompt

    def extract_code_and_explanation_from_llm_res(
        self, content: str
    ) -> CodeStrAndExplanation:
        """Parse LLM output into code and explanation (subclass hook)."""
        raise NotImplementedError

    def parse_and_generate_candidate_schema(
        self,
    ) -> OptimizeResponseItem | None:
        """Post-process extracted code into a response item (subclass hook)."""
        raise NotImplementedError

    def is_valid_code(self) -> bool:
        """Return True when a non-empty code body has been extracted."""
        extracted = self.extracted_code_and_expl
        return (
            extracted is not None
            and extracted.code is not None
            and extracted.code.strip() != ""
        )

    def validate_and_parse_source_code(
        self,
        code: str,
        feature_version: tuple[int, ...],
    ) -> None:
        """Raise SyntaxError unless the code is non-empty, valid Python.

        Falls back to the original source when *code* is empty/falsy.
        """
        candidate = code or self.source_code
        tree = ast.parse(candidate, feature_version=feature_version)
        if not tree.body:
            raise SyntaxError("Source code is empty or invalid")
class SingleOptimizerContext(BaseOptimizerContext):
    """Context for single-file optimizations.

    Prompts wrap the raw source in one markdown block; LLM replies are
    expected to contain exactly one python code block.
    """

    def __init__(
        self,
        base_system_prompt: str,
        base_user_prompt: str,
        source_code: str,
        file_name: str | None = None,
    ) -> None:
        super().__init__(
            base_system_prompt, base_user_prompt, source_code
        )
        # Set only when the single file arrived as ```python:<path>```
        # markdown; used to re-wrap the optimized code under the same path.
        self.file_name = file_name

    def get_system_prompt(
        self, python_version_str: str
    ) -> str:
        """Format the system prompt with full-code instructions.

        Args:
            python_version_str: Target version, e.g. ``"3.12.9"``.
        """
        return (
            self.base_system_prompt.format(
                python_version_str=python_version_str,
                critical_instructions=FULL_CODE_PROMPT_INSTRUCTIONS,
            )
            + "\n"
            + EXPLANATION_THEN_CODE
        )

    def get_user_prompt(
        self,
        dependency_code: str,
        line_profiler_results: str | None,
        *,
        baseline_runtime_ns: int | None = None,
        loop_count: int | None = None,
        test_input_examples: str | None = None,
    ) -> str:
        """Build the user prompt from optional context sections.

        Sections are appended only when their data is present: dependency
        code, runtime baseline, line-profiler output, and test examples.
        """
        markdown_source_code = wrap_code_in_markdown(self.source_code)
        runtime_part = ""
        # Both values are required to render a meaningful baseline section.
        if (
            baseline_runtime_ns is not None
            and loop_count is not None
        ):
            runtime_part = RUNTIME_CONTEXT_PROMPT.format(
                baseline_runtime_ns=baseline_runtime_ns,
                baseline_runtime_human=_humanize_ns(
                    baseline_runtime_ns
                ),
                loop_count=loop_count,
            )
        test_examples_part = ""
        if test_input_examples:
            test_examples_part = TEST_EXAMPLES_PROMPT.format(
                test_input_examples=test_input_examples
            )
        # Sources defining __init__ get an extra instruction block.
        has_init = find_init(ast.parse(self.source_code))
        return (
            f"{DEPS_CONTEXT_PROMPT.format(dependency_code=dependency_code) if dependency_code else ''}"
            f"{self.base_user_prompt.format(source_code=markdown_source_code, init_optimization_prompt=INIT_OPTIMIZATION_PROMPT if has_init else '')}"
            f"{runtime_part}"
            f"{LINE_PROF_CONTEXT_PROMPT.format(line_profiler_results=line_profiler_results) if line_profiler_results else ''}"
            f"{test_examples_part}"
        )

    def extract_code_and_explanation_from_llm_res(
        self, content: str
    ) -> CodeStrAndExplanation:
        """Extract the code block and preceding explanation from *content*.

        Falls back to empty strings when no python block is found; callers
        detect that via :meth:`is_valid_code`.
        """
        code = ""
        explanation = ""
        if result := extract_code_block_with_context(
            content, language="python"
        ):
            explanation, code, _ = result
        extracted = CodeStrAndExplanation(
            code=code, explanation=explanation
        )
        self.extracted_code_and_expl = extracted
        return extracted

    def parse_and_generate_candidate_schema(
        self,
    ) -> OptimizeResponseItem | None:
        """Post-process extracted code into a response item.

        Returns:
            The candidate, or ``None`` when extraction was empty, parsing
            failed, or post-processing dropped the candidate.

        Raises:
            AttributeError: if called before
                :meth:`extract_code_and_explanation_from_llm_res`.
        """
        if self.extracted_code_and_expl is None:
            msg = "Call extract_code_and_explanation_from_llm_res first"
            raise AttributeError(msg)
        extracted = self.extracted_code_and_expl
        if extracted.code == "":
            return None
        try:
            op_id = str(uuid.uuid4())
            cst_module = parse_module_to_cst(extracted.code)
            # Keep the pre-post-processing version for telemetry/debugging.
            self.code_and_explanation_before_post_processing[
                op_id
            ] = CodeStrAndExplanation(
                code=cst_module.code,
                explanation=extracted.explanation,
            )
            original_cst_module = parse_module_to_cst(
                self.source_code
            )
            postprocessed_list = optimizations_postprocessing_pipeline(
                original_cst_module,
                [
                    OptimizationCandidate(
                        cst_module=cst_module,
                        id=op_id,
                        explanation=extracted.explanation,
                    )
                ],
            )
            if len(postprocessed_list) == 0:
                return None
            postprocessed = postprocessed_list[0]
            # Re-wrap under the original path when the input was markdown.
            source = (
                group_code(
                    {self.file_name: postprocessed.cst_module.code}
                )
                if self.file_name is not None
                else postprocessed.cst_module.code
            )
            return OptimizeResponseItem(
                explanation=postprocessed.explanation,
                optimization_id=postprocessed.id,
                source_code=source,
            )
        except (
            ValueError,
            cst.ParserSyntaxError,
        ):
            log.warning("Error parsing optimization result", exc_info=True)
            return None

    # NOTE: the non-empty check from BaseOptimizerContext.is_valid_code is
    # inherited unchanged; the previous no-op override was removed.

    def validate_and_parse_source_code(
        self,
        code: str,
        feature_version: tuple[int, ...],
    ) -> None:
        """Validate single-file source code.

        NOTE(review): the *code* argument is intentionally ignored and
        ``self.source_code`` validated instead (for contexts built from a
        one-file markdown bundle, ``self.source_code`` holds the unwrapped
        file while callers pass the raw request body) — confirm callers
        rely on this.

        Raises:
            SyntaxError: when the stored source does not compile/parse.
        """
        compile(self.source_code, "source_code", "exec")
        super().validate_and_parse_source_code(
            self.source_code, feature_version
        )
class MultiOptimizerContext(BaseOptimizerContext):
    """Context for multi-file optimizations.

    The source is a markdown bundle of ```python:<path>``` blocks; prompt
    format and response parsing both depend on the configured diff method.
    """

    def __init__(
        self,
        base_system_prompt: str,
        base_user_prompt: str,
        source_code: str,
        diff_method: DiffMethod,
    ) -> None:
        # Assigned before super().__init__ so any hook run during base
        # construction can already read it.
        self.diff_method = diff_method
        super().__init__(
            base_system_prompt, base_user_prompt, source_code
        )
        # Per-file view of the original bundle, plus its key set for fast
        # "did the LLM invent a file?" membership checks.
        self.original_file_to_code = split_markdown_code(source_code)
        self._original_files_set = set(
            self.original_file_to_code.keys()
        )

    def get_system_prompt(
        self, python_version_str: str
    ) -> str:
        """Format the system prompt for the configured diff method.

        V4A and SEARCH_AND_REPLACE share the diff-first instruction text but
        use different format prompts; NO_DIFF asks for full files in
        markdown form.
        """
        critical_instructions = ""
        code_format_instructions = ""
        if self.diff_method == DiffMethod.V4A:
            code_format_instructions = V4A_DIFF_FORMAT_PROMPT
            critical_instructions = (
                "Begin your response with the diff followed by a"
                " short explanation of the changes (without any"
                " title or heading like 'Explanation' or"
                " 'Changes')."
            )
        elif self.diff_method == DiffMethod.SEARCH_AND_REPLACE:
            code_format_instructions = (
                SEARCH_AND_REPLACE_FORMAT_PROMPT
            )
            critical_instructions = (
                "Begin your response with the diff followed by a"
                " short explanation of the changes (without any"
                " title or heading like 'Explanation' or"
                " 'Changes')."
            )
        elif self.diff_method == DiffMethod.NO_DIFF:
            code_format_instructions = MARKDOWN_CONTEXT_PROMPT
            critical_instructions = FULL_CODE_PROMPT_INSTRUCTIONS
        return (
            self.base_system_prompt.format(
                python_version_str=python_version_str,
                critical_instructions=critical_instructions,
            )
            + "\n"
            + code_format_instructions
        )

    def get_user_prompt(
        self,
        dependency_code: str,
        line_profiler_results: str | None,
        *,
        baseline_runtime_ns: int | None = None,
        loop_count: int | None = None,
        test_input_examples: str | None = None,
    ) -> str:
        """Build the user prompt for the multi-file bundle.

        Unlike the single-file variant, the raw markdown bundle is passed
        through as-is; optional context sections are appended only when
        their data is present.
        """
        # The init instruction applies when ANY file defines __init__.
        has_init = any(
            find_init(ast.parse(code))
            for code in self.original_file_to_code.values()
        )
        runtime_part = ""
        # Both values are required to render a meaningful baseline section.
        if (
            baseline_runtime_ns is not None
            and loop_count is not None
        ):
            runtime_part = RUNTIME_CONTEXT_PROMPT.format(
                baseline_runtime_ns=baseline_runtime_ns,
                baseline_runtime_human=_humanize_ns(
                    baseline_runtime_ns
                ),
                loop_count=loop_count,
            )
        test_examples_part = ""
        if test_input_examples:
            test_examples_part = TEST_EXAMPLES_PROMPT.format(
                test_input_examples=test_input_examples
            )
        return (
            f"{DEPS_CONTEXT_PROMPT.format(dependency_code=dependency_code) if dependency_code else ''}"
            f"{self.base_user_prompt.format(source_code=self.source_code, init_optimization_prompt=INIT_OPTIMIZATION_PROMPT if has_init else '')}"
            f"{runtime_part}"
            f"{LINE_PROF_CONTEXT_PROMPT.format(line_profiler_results=line_profiler_results) if line_profiler_results else ''}"
            f"{test_examples_part}"
        )

    def extract_code_and_explanation_from_llm_res(
        self, content: str
    ) -> CodeStrAndExplanation:
        """Extract code from the LLM response based on the diff method.

        NO_DIFF: regroup the returned markdown blocks directly. V4A /
        SEARCH_AND_REPLACE: apply the diff to the original files and take
        the text after the closing marker as the explanation.
        """
        markdown_code = ""
        explanation = ""
        if self.diff_method == DiffMethod.NO_DIFF:
            # Everything before the first file fence is the explanation.
            explanation = (
                content.split("```python:", maxsplit=1)[
                    0
                ].strip()
            )
            markdown_code = group_code(
                split_markdown_code(content)
            )
        elif self.diff_method == DiffMethod.V4A:
            diff: Diff = V4ADiff(
                content=content,
                source_code=self.original_file_to_code,
            )
            new_file_to_code = diff.run()
            markdown_code = group_code(new_file_to_code)
            patch_end_marker = "*** End Patch"
            # NOTE(review): if the marker is absent, rfind returns -1 and
            # the explanation slice starts mid-content — confirm acceptable.
            index_of_end = content.rfind(patch_end_marker)
            explanation = content[
                index_of_end + len(patch_end_marker) :
            ].strip()
        elif self.diff_method == DiffMethod.SEARCH_AND_REPLACE:
            diff = SearchAndReplaceDiff(
                content=content,
                source_code=self.original_file_to_code,
            )
            new_file_to_code = diff.run()
            end_tag = "</replace_in_file>"
            index_of_end = content.rfind(end_tag)
            explanation = content[
                index_of_end + len(end_tag) :
            ].strip()
            markdown_code = group_code(new_file_to_code)
        result = CodeStrAndExplanation(
            code=markdown_code, explanation=explanation
        )
        self.extracted_code_and_expl = result
        return result

    def parse_and_generate_candidate_schema(
        self,
    ) -> OptimizeResponseItem | None:
        """Post-process each extracted file individually.

        Returns ``None`` when nothing was extracted, no extracted file maps
        to an original, post-processing of any file errors, or every file's
        optimization was dropped by the pipeline.

        Raises:
            AttributeError: if called before
                extract_code_and_explanation_from_llm_res.
        """
        if self.extracted_code_and_expl is None:
            msg = "Call extract_code_and_explanation_from_llm_res first"
            raise AttributeError(msg)
        extracted = self.extracted_code_and_expl
        if extracted.code == "":
            return None
        original_code_files = dict(
            self.original_file_to_code.items()
        )
        extracted_code_files = dict(
            split_markdown_code(extracted.code).items()
        )
        op_id = str(uuid.uuid4())
        new_post_processed_code: dict[str, str] = {}
        new_explanation: str = ""
        # Keep the pre-post-processing bundle for telemetry/debugging.
        self.code_and_explanation_before_post_processing[
            op_id
        ] = CodeStrAndExplanation(
            code=extracted.code, explanation=extracted.explanation
        )
        # Only files that exist in the original bundle are considered.
        valid_extracted_files = {
            f: c
            for f, c in extracted_code_files.items()
            if f in original_code_files
        }
        if not valid_extracted_files:
            log.warning(
                "No matching files. Original: %s, Extracted: %s",
                list(original_code_files.keys()),
                list(extracted_code_files.keys()),
            )
            return None
        # Flipped to False once any file survives post-processing.
        optimization_ignored = True
        for original_file, new_code in valid_extracted_files.items():
            original_code = original_code_files[original_file]
            try:
                new_cst_module = parse_module_to_cst(new_code)
                original_cst_module = parse_module_to_cst(
                    original_code
                )
                code_and_explanation = OptimizationCandidate(
                    cst_module=new_cst_module,
                    id=f"{original_file}:post-processing",
                    explanation=extracted.explanation,
                )
                postprocessed_list = (
                    optimizations_postprocessing_pipeline(
                        original_cst_module,
                        [code_and_explanation],
                    )
                )
                if len(postprocessed_list) == 0:
                    # Pipeline dropped this file's change: keep the LLM's
                    # code verbatim; does not count as a kept optimization.
                    new_post_processed_code[original_file] = (
                        new_code
                    )
                    continue
                optimization_ignored = False
                postprocessed = postprocessed_list[0]
                # First surviving file's explanation wins.
                if new_explanation == "":
                    new_explanation = postprocessed.explanation
                new_post_processed_code[original_file] = (
                    postprocessed.cst_module.code
                )
            except (
                ValueError,
                cst.ParserSyntaxError,
            ):
                # Any single bad file invalidates the whole candidate.
                log.warning(
                    "Error processing file %s",
                    original_file,
                    exc_info=True,
                )
                return None
        if optimization_ignored:
            return None
        return OptimizeResponseItem(
            optimization_id=op_id,
            explanation=new_explanation,
            source_code=group_code(new_post_processed_code),
        )

    def is_valid_code(self) -> bool:
        """Validate that extracted files are a subset of the originals."""
        if self.extracted_code_and_expl is None:
            return False
        extracted_files = set(
            split_markdown_code(
                self.extracted_code_and_expl.code
            ).keys()
        )
        # Reject candidates that invent files not present in the request.
        invalid_files = extracted_files - self._original_files_set
        if invalid_files:
            log.warning(
                "LLM returned files not in original: %s",
                invalid_files,
            )
            return False
        return super().is_valid_code()

    def validate_and_parse_source_code(
        self,
        code: str,
        feature_version: tuple[int, ...],
    ) -> None:
        """Validate each file in the multi-file source bundle.

        Raises:
            SyntaxError: when any file fails to compile/parse.
        """
        final_code = code or self.source_code
        code_blocks = split_markdown_code(final_code).values()
        for code_block in code_blocks:
            compile(code_block, "source_code", "exec")
            super().validate_and_parse_source_code(
                code_block, feature_version
            )

View file

@ -0,0 +1,212 @@
from __future__ import annotations
import asyncio
import logging
from typing import TYPE_CHECKING, Any
import libcst as cst
from codeflash_api.llm._client import LLMOutputUnparseableError
from codeflash_api.llm._models import (
ANTHROPIC_CLAUDE_SONNET_4_5,
LLM,
OPENAI_GPT_5_MINI,
)
from codeflash_api.optimize._context import BaseOptimizerContext
if TYPE_CHECKING:
from codeflash_api.llm._client import LLMClient
from codeflash_api.optimize.schemas import OptimizeResponseItem
log = logging.getLogger(__name__)

# Models used for candidate generation; calls are split between the two
# vendors by get_model_distribution().
OPENAI_MODEL: LLM = OPENAI_GPT_5_MINI
ANTHROPIC_MODEL: LLM = ANTHROPIC_CLAUDE_SONNET_4_5

# Hard cap on optimizer LLM calls issued per request.
MAX_OPTIMIZER_CALLS = 6
def get_model_distribution(
    total_calls: int, max_calls: int
) -> list[tuple[LLM, int]]:
    """Split calls between OpenAI and Anthropic models.

    GPT gets the larger share (half of the capped total, rounded up);
    Claude gets the remainder.

    Args:
        total_calls: Requested number of candidate generations.
        max_calls: Upper bound on total LLM calls.

    Returns:
        ``[(OPENAI_MODEL, gpt_calls), (ANTHROPIC_MODEL, claude_calls)]``
        with non-negative counts summing to ``min(max(total_calls, 0),
        max_calls)``.
    """
    # Clamp to [0, max_calls]; without the lower bound, total_calls == 0
    # produced claude_calls == -1 and gpt_calls == 1 via floor division.
    final_total = max(0, min(total_calls, max_calls))
    claude_calls = max(0, (final_total - 1) // 2)
    gpt_calls = final_total - claude_calls
    return [
        (OPENAI_MODEL, gpt_calls),
        (ANTHROPIC_MODEL, claude_calls),
    ]
async def generate_optimization_candidate(
    llm_client: LLMClient,
    user_id: str,
    ctx: BaseOptimizerContext,
    trace_id: str,
    *,
    dependency_code: str | None = None,
    optimize_model: LLM = OPENAI_MODEL,
    python_version: tuple[int, int, int] = (3, 12, 9),
    call_sequence: int | None = None,
    baseline_runtime_ns: int | None = None,
    loop_count: int | None = None,
    line_profiler_results: str | None = None,
    test_input_examples: str | None = None,
) -> tuple[OptimizeResponseItem | None, float | None, str]:
    """Generate a single optimization candidate via one LLM call.

    Returns:
        ``(candidate_or_None, cost_or_None, model_name)``. The cost is
        reported even when parsing fails, since the call was billed.
    """
    log.info("Generating optimization candidate")
    version_label = ".".join(map(str, python_version))
    messages: list[dict[str, Any]] = [
        {
            "role": "system",
            "content": ctx.get_system_prompt(version_label),
        },
        {
            "role": "user",
            "content": ctx.get_user_prompt(
                dependency_code or "",
                line_profiler_results,
                baseline_runtime_ns=baseline_runtime_ns,
                loop_count=loop_count,
                test_input_examples=test_input_examples,
            ),
        },
    ]
    try:
        output = await llm_client.call(
            llm=optimize_model,
            messages=messages,
            call_type="optimization",
            trace_id=trace_id,
        )
    except LLMOutputUnparseableError as err:
        # Provider replied but the payload was unusable; surface the cost.
        return None, err.cost, optimize_model.name
    except Exception:
        log.exception("Failed to generate optimization")
        return None, None, optimize_model.name
    cost = output.cost
    ctx.extract_code_and_explanation_from_llm_res(output.content)
    try:
        # CST parsing/post-processing is CPU-bound; keep it off the loop.
        candidate = await asyncio.to_thread(
            ctx.parse_and_generate_candidate_schema
        )
    except (ValueError, cst.ParserSyntaxError):
        log.warning(
            "Error parsing optimization result", exc_info=True
        )
    else:
        if candidate is not None and ctx.is_valid_code():
            return candidate, cost, optimize_model.name
    return None, cost, optimize_model.name
async def optimize_python_code(
    llm_client: LLMClient,
    user_id: str,
    ctx: BaseOptimizerContext,
    trace_id: str,
    original_source_code: str,
    *,
    dependency_code: str | None = None,
    python_version: tuple[int, int, int] = (3, 12, 9),
    n_candidates: int = 0,
    baseline_runtime_ns: int | None = None,
    loop_count: int | None = None,
    line_profiler_results: str | None = None,
    test_input_examples: str | None = None,
) -> tuple[
    list[OptimizeResponseItem],
    float,
    dict[str, dict[str, str]],
    dict[str, str],
]:
    """Run parallel optimizations with multiple models.

    Fans out up to MAX_OPTIMIZER_CALLS LLM calls (split across models by
    get_model_distribution) and aggregates the results.

    Returns:
        ``(candidates, total_cost, raw_code_and_explanations_by_op_id,
        model_name_by_optimization_id)``.
    """
    tasks: list[
        tuple[
            asyncio.Task[
                tuple[
                    OptimizeResponseItem | None,
                    float | None,
                    str,
                ]
            ],
            BaseOptimizerContext,
        ]
    ] = []
    call_sequence = 1
    if n_candidates == 0:
        return [], 0.0, {}, {}
    async with asyncio.TaskGroup() as tg:
        for model, num_calls in get_model_distribution(
            n_candidates, MAX_OPTIMIZER_CALLS
        ):
            for _ in range(num_calls):
                # Each call gets a fresh context so concurrent extractions
                # never share mutable state with each other or the caller's
                # ctx.
                task_ctx = BaseOptimizerContext.get_dynamic_context(
                    ctx.base_system_prompt,
                    ctx.base_user_prompt,
                    original_source_code,
                )
                # Diversity: only odd-numbered calls see profiler output.
                lp_for_this_call = (
                    line_profiler_results
                    if call_sequence % 2 == 1
                    else None
                )
                task = tg.create_task(
                    generate_optimization_candidate(
                        llm_client,
                        user_id=user_id,
                        ctx=task_ctx,
                        trace_id=trace_id,
                        dependency_code=dependency_code,
                        optimize_model=model,
                        python_version=python_version,
                        call_sequence=call_sequence,
                        baseline_runtime_ns=baseline_runtime_ns,
                        loop_count=loop_count,
                        line_profiler_results=lp_for_this_call,
                        test_input_examples=test_input_examples,
                    )
                )
                tasks.append((task, task_ctx))
                call_sequence += 1
    optimization_results: list[OptimizeResponseItem] = []
    total_cost = 0.0
    code_and_explanations: dict[str, dict[str, str]] = {}
    optimization_models: dict[str, str] = {}
    for task, task_ctx in tasks:
        # TaskGroup exit guarantees all tasks completed here.
        result, cost, model_name = task.result()
        if cost:
            total_cost += cost
        if result is not None:
            optimization_results.append(result)
            optimization_models[result.optimization_id] = (
                model_name
            )
        # Raw (pre-post-processing) code collected for every task so failed
        # candidates are still observable.
        # NOTE(review): confirm this loop should run for failed candidates
        # too, and is not meant to be nested under `result is not None`.
        for (
            op_id,
            cei,
        ) in (
            task_ctx.code_and_explanation_before_post_processing.items()
        ):
            code_and_explanations[op_id] = {
                "code": cei.code,
                "explanation": cei.explanation,
            }
    return (
        optimization_results,
        total_cost,
        code_and_explanations,
        optimization_models,
    )

View file

@ -0,0 +1,194 @@
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Annotated
from fastapi import APIRouter, Depends, HTTPException, Request
from codeflash_api.auth._deps import (
check_rate_limit,
require_auth,
track_usage,
)
from codeflash_api.diff._base import DiffMethod
if TYPE_CHECKING:
from codeflash_api.auth.models import AuthenticatedUser
from codeflash_api.optimize._context import (
BaseOptimizerContext,
parse_python_version,
validate_trace_id,
)
from codeflash_api.optimize._pipeline import optimize_python_code
from codeflash_api.optimize.schemas import (
OptimizeErrorResponse,
OptimizeRequest,
OptimizeResponse,
)
log = logging.getLogger(__name__)

router = APIRouter()

# Prompt templates live under the language-specific package; they are read
# once at import time so request handling never touches the filesystem.
_PROMPTS_DIR = (
    Path(__file__).parent.parent
    / "languages"
    / "python"
    / "prompts"
)

# Sync vs async variants are selected per request in `optimize`.
SYSTEM_PROMPT = (_PROMPTS_DIR / "system_prompt.md").read_text()
USER_PROMPT = (_PROMPTS_DIR / "user_prompt.md").read_text()
ASYNC_SYSTEM_PROMPT = (
    (_PROMPTS_DIR / "async_system_prompt.md").read_text()
)
ASYNC_USER_PROMPT = (
    (_PROMPTS_DIR / "async_user_prompt.md").read_text()
)
# Appended to the system prompt when the request flags numerical code.
JIT_INSTRUCTIONS = (
    (_PROMPTS_DIR / "jit_instructions.md").read_text()
)
def _validate_request(
    data: OptimizeRequest,
    ctx: BaseOptimizerContext,
) -> tuple[int, int, int]:
    """Validate request fields and parse the source code.

    Raises HTTPException(400) on any invalid field; returns the parsed
    (major, minor, patch) Python version on success.
    """

    def _bad_request(detail: str) -> HTTPException:
        # All validation failures map to the same status code.
        return HTTPException(status_code=400, detail=detail)

    if not data.source_code:
        raise _bad_request("Source code cannot be empty.")
    if not validate_trace_id(data.trace_id):
        raise _bad_request(
            "Invalid trace ID. Please provide a valid UUIDv4."
        )
    if not data.python_version:
        raise _bad_request("Python version is required.")
    try:
        python_version = parse_python_version(data.python_version)
    except ValueError as e:
        raise _bad_request(
            "Invalid Python version, it should look like"
            " 3.x.x. We only support Python 3.9 and above."
        ) from e
    try:
        # (major, minor) is what ast.parse's feature_version expects.
        ctx.validate_and_parse_source_code(
            data.source_code,
            feature_version=python_version[:2],
        )
    except SyntaxError as e:
        raise _bad_request(
            "Invalid source code. It is not valid Python"
            " code. Please check syntax of your code."
        ) from e
    return python_version
@router.post(
    "/ai/optimize",
    response_model=OptimizeResponse,
    responses={
        400: {"model": OptimizeErrorResponse},
        422: {"model": OptimizeErrorResponse},
        500: {"model": OptimizeErrorResponse},
    },
)
async def optimize(
    request: Request,
    data: OptimizeRequest,
    user: Annotated[
        AuthenticatedUser, Depends(require_auth)
    ],
    _rate: Annotated[None, Depends(check_rate_limit)],
    _usage: Annotated[None, Depends(track_usage)],
) -> OptimizeResponse:
    """Optimize Python code for performance using LLMs.

    Flow: pick prompt variants -> build context -> validate request ->
    fan out parallel LLM calls -> return candidates (422 when none).
    Auth, rate limiting and usage tracking run as dependencies.
    """
    # Async code gets dedicated prompt variants.
    system_prompt = (
        ASYNC_SYSTEM_PROMPT if data.is_async else SYSTEM_PROMPT
    )
    user_prompt = (
        ASYNC_USER_PROMPT if data.is_async else USER_PROMPT
    )
    if data.is_numerical_code:
        # Numerical code additionally gets JIT-oriented guidance.
        system_prompt += f"\n{JIT_INSTRUCTIONS}\n"
    ctx = BaseOptimizerContext.get_dynamic_context(
        system_prompt,
        user_prompt,
        data.source_code,
        DiffMethod.NO_DIFF,
    )
    try:
        python_version = _validate_request(data, ctx)
    except HTTPException:
        # Validation failures already carry the right status/detail.
        raise
    except Exception as e:
        # Unexpected validator crash: log with trace id, return generic 400.
        log.exception(
            "Validation error: trace_id=%s", data.trace_id
        )
        raise HTTPException(
            status_code=400,
            detail="Request validation failed.",
        ) from e
    # Shared client created during app lifespan startup.
    llm_client = request.app.state.llm_client
    (
        optimization_results,
        _llm_cost,
        _code_and_explanations,
        _optimization_models,
    ) = await optimize_python_code(
        llm_client,
        user_id=user.user_id,
        ctx=ctx,
        trace_id=data.trace_id,
        original_source_code=data.source_code,
        dependency_code=data.dependency_code,
        python_version=python_version,
        n_candidates=data.n_candidates,
        baseline_runtime_ns=data.baseline_runtime_ns,
        loop_count=data.loop_count,
        line_profiler_results=data.line_profiler_results,
        test_input_examples=data.test_input_examples,
    )
    if len(optimization_results) == 0:
        log.warning(
            "No optimizations found: trace_id=%s, repo=%s/%s,"
            " n_candidates=%d",
            data.trace_id,
            data.repo_owner,
            data.repo_name,
            data.n_candidates,
        )
        raise HTTPException(
            status_code=422,
            detail=(
                "Could not generate any optimizations."
                " Please try again."
            ),
        )
    return OptimizeResponse(optimizations=optimization_results)

View file

@ -0,0 +1,86 @@
from __future__ import annotations
import enum
from typing import Self
from pydantic import BaseModel, model_validator
class OptimizedCandidateSource(str, enum.Enum):
    """Source of an optimized candidate.

    Inherits ``str`` so members serialize as plain strings in JSON bodies.
    """

    OPTIMIZE = "OPTIMIZE"
    # NOTE(review): "LP" presumably = line-profiler-guided pass — confirm.
    OPTIMIZE_LP = "OPTIMIZE_LP"
    REFINE = "REFINE"
    REPAIR = "REPAIR"
    ADAPTIVE = "ADAPTIVE"
    JIT_REWRITE = "JIT_REWRITE"
class OptimizeRequest(BaseModel):
    """Request body for POST /ai/optimize."""

    # Raw Python source, or a multi-file markdown bundle of
    # ```python:<path>``` blocks.
    source_code: str
    # Read-only context code shown to the LLM alongside the source.
    dependency_code: str | None = None
    # UUIDv4 correlating client and server logs (an EXPn-suffixed variant
    # is also accepted by validate_trace_id).
    trace_id: str
    # Target interpreter version, e.g. "3.12.9"; may be filled from
    # language_version by the validator below.
    python_version: str | None = None
    language: str = "python"
    # Legacy field mirrored into python_version for backward compat.
    language_version: str | None = None
    experiment_metadata: dict[str, str] | None = None
    codeflash_version: str | None = None
    current_username: str | None = None
    repo_owner: str | None = None
    repo_name: str | None = None
    # Selects the async-specific prompt variants when true.
    is_async: bool | None = False
    # Number of candidates requested (server caps apply).
    n_candidates: int = 5
    # When true, JIT instructions are appended to the system prompt.
    is_numerical_code: bool | None = None
    rerun_trace_id: str | None = None
    # Measured baseline used for the runtime context prompt section;
    # both must be set for the section to render.
    baseline_runtime_ns: int | None = None
    loop_count: int | None = None
    line_profiler_results: str | None = None
    test_input_examples: str | None = None

    @model_validator(mode="after")
    def resolve_python_version(self) -> Self:
        """Resolve python_version from language_version for backward compat."""
        # Only fill the gap; an explicit python_version always wins.
        if (
            self.python_version is None
            and self.language_version is not None
            and self.language == "python"
        ):
            self.python_version = self.language_version
        return self
class OptimizeResponseItem(BaseModel):
    """A single optimization candidate in the response."""

    # Optimized source; multi-file results are grouped markdown blocks.
    source_code: str
    # LLM-authored explanation of the change.
    explanation: str
    # Server-generated UUID identifying this candidate.
    optimization_id: str
    # NOTE(review): left unset by the optimize pipeline shown here —
    # presumably populated by refine/repair flows; confirm.
    parent_id: str | None = None
    optimization_event_id: str | None = None
class OptimizeResponse(BaseModel):
    """Successful response from POST /ai/optimize."""

    # One entry per candidate that survived post-processing.
    optimizations: list[OptimizeResponseItem]
class OptimizeErrorResponse(BaseModel):
    """Error response from POST /ai/optimize (400/422/500)."""

    # Human-readable error description.
    error: str

View file

@ -0,0 +1,570 @@
from __future__ import annotations
import uuid
import pytest
from codeflash_api.diff._base import DiffMethod
from codeflash_api.languages.python._markdown import (
extract_all_code_from_markdown,
extract_code_block,
extract_code_block_with_context,
group_code,
is_multi_context,
split_markdown_code,
truncate_pathological_output,
wrap_code_in_markdown,
)
from codeflash_api.optimize._context import (
BaseOptimizerContext,
CodeStrAndExplanation,
MultiOptimizerContext,
SingleOptimizerContext,
_humanize_ns,
parse_python_version,
validate_trace_id,
)
from codeflash_api.optimize._pipeline import (
MAX_OPTIMIZER_CALLS,
get_model_distribution,
)
from codeflash_api.optimize.schemas import (
OptimizedCandidateSource,
OptimizeErrorResponse,
OptimizeRequest,
OptimizeResponse,
OptimizeResponseItem,
)
class TestTruncatePathologicalOutput:
    """Tests for truncate_pathological_output."""

    def test_normal_code(self) -> None:
        """Well-formed code passes through untouched."""
        source = "x = 1\ny = 2\n"
        assert truncate_pathological_output(source) == source

    def test_pathological(self) -> None:
        """Runs of repeated escape characters are cut off."""
        source = "x = 1\n\\''''''''''''more"
        truncated = truncate_pathological_output(source)
        assert "x = 1" in truncated
        assert "more" not in truncated
class TestExtractAllCodeFromMarkdown:
    """Tests for extract_all_code_from_markdown."""

    def test_single_block(self) -> None:
        """A lone code block is extracted verbatim."""
        assert extract_all_code_from_markdown("```python\nx = 1\n```") == "x = 1\n"

    def test_multiple_blocks(self) -> None:
        """All code blocks contribute to the joined output."""
        joined = extract_all_code_from_markdown(
            "```python\nx = 1\n```\ntext\n```python\ny = 2\n```"
        )
        assert "x = 1" in joined
        assert "y = 2" in joined

    def test_no_blocks(self) -> None:
        """Input without code blocks yields an empty string."""
        assert extract_all_code_from_markdown("no code here") == ""
class TestExtractCodeBlock:
    """Tests for extract_code_block."""

    def test_basic(self) -> None:
        """The first fenced block's code is returned."""
        assert extract_code_block("```python\nx = 1\n```") == "x = 1"

    def test_with_filepath(self) -> None:
        """A ```python:path``` fence is handled the same way."""
        assert extract_code_block("```python:main.py\nx = 1\n```") == "x = 1"

    def test_no_block(self) -> None:
        """Input without a fence yields None."""
        assert extract_code_block("no code") is None
class TestSplitMarkdownCode:
    """Tests for split_markdown_code."""

    def test_basic(self) -> None:
        """Each ```python:path``` fence becomes a dict entry."""
        md = "```python:main.py\nx = 1\n```\n```python:util.py\ny = 2\n```"
        split = split_markdown_code(md)
        assert "main.py" in split
        assert "util.py" in split
        assert split["main.py"] == "x = 1"

    def test_plain_blocks_ignored(self) -> None:
        """Fences without a filepath produce no entries."""
        assert split_markdown_code("```python\nx = 1\n```") == {}

    def test_duplicate_paths(self) -> None:
        """When a path repeats, the first block wins."""
        md = "```python:a.py\nfirst\n```\n```python:a.py\nsecond\n```"
        assert split_markdown_code(md)["a.py"] == "first"
class TestExtractCodeBlockWithContext:
    """Tests for extract_code_block_with_context."""

    def test_basic(self) -> None:
        """Text before and after the fence is returned with the code."""
        parsed = extract_code_block_with_context(
            "explanation\n```python\nx = 1\n```\nafter"
        )
        assert parsed is not None
        before, code, after = parsed
        assert before == "explanation"
        assert code == "x = 1\n"
        assert after == "after"

    def test_with_filepath(self) -> None:
        """The filepath fence variant is preferred."""
        parsed = extract_code_block_with_context(
            "before\n```python:file.py\ny = 2\n```\nend"
        )
        assert parsed is not None
        assert parsed[1] == "y = 2\n"

    def test_no_block(self) -> None:
        """Input without a fence yields None."""
        assert extract_code_block_with_context("no code") is None
class TestWrapCodeInMarkdown:
    """Tests for wrap_code_in_markdown."""

    def test_basic(self) -> None:
        """Code is fenced inside a python markdown block."""
        wrapped = wrap_code_in_markdown("x = 1")
        assert wrapped == "```python\nx = 1\n```"
class TestGroupCode:
    """Tests for group_code."""

    def test_single_file(self) -> None:
        """A lone file renders as one ```python:path``` block."""
        rendered = group_code({"main.py": "x = 1\n"})
        assert "```python:main.py" in rendered
        assert "x = 1" in rendered

    def test_multiple_files(self) -> None:
        """Every file path appears in the joined output."""
        rendered = group_code({"a.py": "x = 1\n", "b.py": "y = 2\n"})
        assert "a.py" in rendered
        assert "b.py" in rendered
class TestIsMultiContext:
    """Tests for is_multi_context."""

    def test_multi(self) -> None:
        """A filepath fence marks multi-file input."""
        assert is_multi_context("```python:file.py\nx=1\n```")

    def test_single(self) -> None:
        """Bare source code is single-file."""
        assert not is_multi_context("x = 1\n")
class TestHumanizeNs:
    """Tests for _humanize_ns."""

    def test_nanoseconds(self) -> None:
        """Values under a microsecond render as ns."""
        assert _humanize_ns(500) == "500ns"

    def test_microseconds(self) -> None:
        """Values under a millisecond render as us."""
        assert _humanize_ns(1500) == "1.5us"

    def test_milliseconds(self) -> None:
        """Values under a second render as ms."""
        assert _humanize_ns(2_500_000) == "2.5ms"

    def test_seconds(self) -> None:
        """Values of a second or more render as s."""
        assert _humanize_ns(1_000_000_000) == "1.00s"
class TestParsePythonVersion:
    """Tests for parse_python_version."""

    def test_valid(self) -> None:
        """A well-formed version becomes an int tuple."""
        assert parse_python_version("3.12.9") == (3, 12, 9)

    def test_empty(self) -> None:
        """An empty string raises ValueError."""
        with pytest.raises(ValueError, match="required"):
            parse_python_version("")

    def test_wrong_format(self) -> None:
        """Fewer than three parts raises ValueError."""
        with pytest.raises(ValueError, match="format"):
            parse_python_version("3.12")

    def test_too_old(self) -> None:
        """Versions below 3.9 are rejected."""
        with pytest.raises(ValueError, match="3.9"):
            parse_python_version("3.8.0")

    def test_python2(self) -> None:
        """Python 2 versions are rejected."""
        with pytest.raises(ValueError, match="Python 3"):
            parse_python_version("2.7.18")
class TestValidateTraceId:
    """Tests for validate_trace_id."""

    def test_valid(self) -> None:
        """A freshly generated UUIDv4 is accepted."""
        assert validate_trace_id(str(uuid.uuid4()))

    def test_invalid(self) -> None:
        """A non-UUID string is rejected."""
        assert not validate_trace_id("not-a-uuid")

    def test_exp_suffix(self) -> None:
        """An EXP0-suffixed id is normalized and accepted."""
        trace = str(uuid.uuid4())[:-4] + "EXP0"
        assert validate_trace_id(trace)
class TestCodeStrAndExplanation:
    """Tests for CodeStrAndExplanation."""

    def test_attributes(self) -> None:
        """Constructor arguments land on the matching attributes."""
        cae = CodeStrAndExplanation(code="x = 1", explanation="optimized")
        assert cae.code == "x = 1"
        assert cae.explanation == "optimized"
class TestBaseOptimizerContext:
    """Tests for BaseOptimizerContext."""

    def test_get_dynamic_context_single(self) -> None:
        """Bare source code yields a SingleOptimizerContext."""
        ctx = BaseOptimizerContext.get_dynamic_context(
            "system", "user", "x = 1\n"
        )
        assert isinstance(ctx, SingleOptimizerContext)

    def test_get_dynamic_context_multi(self) -> None:
        """Multi-file markdown yields a MultiOptimizerContext."""
        md = "```python:a.py\nx = 1\n```\n```python:b.py\ny = 2\n```"
        assert isinstance(
            BaseOptimizerContext.get_dynamic_context("system", "user", md),
            MultiOptimizerContext,
        )

    def test_get_dynamic_context_single_file_multi_format(self) -> None:
        """One file in markdown form with NO_DIFF collapses to Single."""
        ctx = BaseOptimizerContext.get_dynamic_context(
            "system",
            "user",
            "```python:a.py\nx = 1\n```",
            DiffMethod.NO_DIFF,
        )
        assert isinstance(ctx, SingleOptimizerContext)

    def test_is_valid_code_empty(self) -> None:
        """With nothing extracted, the code is invalid."""
        assert not BaseOptimizerContext("s", "u", "x = 1").is_valid_code()

    def test_validate_and_parse_valid(self) -> None:
        """Syntactically valid Python validates without error."""
        BaseOptimizerContext("s", "u", "x = 1\n").validate_and_parse_source_code(
            "x = 1\n", feature_version=(3, 12)
        )

    def test_validate_and_parse_empty(self) -> None:
        """A body-less module raises SyntaxError."""
        ctx = BaseOptimizerContext("s", "u", "# comment\n")
        with pytest.raises(SyntaxError, match="empty"):
            ctx.validate_and_parse_source_code(
                "# comment\n", feature_version=(3, 12)
            )
class TestSingleOptimizerContext:
    """Tests for SingleOptimizerContext."""

    def test_get_system_prompt(self) -> None:
        """Version string and full-code instructions are interpolated."""
        ctx = SingleOptimizerContext(
            "{python_version_str} {critical_instructions}", "user", "x = 1\n"
        )
        prompt = ctx.get_system_prompt("3.12.9")
        assert "3.12.9" in prompt
        assert "FULL" in prompt

    def test_extract_code_from_llm_response(self) -> None:
        """The fenced block and the preceding text are pulled apart."""
        ctx = SingleOptimizerContext("s", "u", "x = 1\n")
        extracted = ctx.extract_code_and_explanation_from_llm_res(
            "explanation\n```python\ny = 2\n```"
        )
        assert extracted.code == "y = 2\n"
        assert extracted.explanation == "explanation"

    def test_parse_generates_response_item(self) -> None:
        """Well-formed extracted code post-processes into a response item."""
        ctx = SingleOptimizerContext("s", "u", "x = 1\n")
        ctx.extract_code_and_explanation_from_llm_res(
            "faster\n```python\ny = 2\n```"
        )
        item = ctx.parse_and_generate_candidate_schema()
        assert item is not None
        assert item.source_code.strip() == "y = 2"

    def test_parse_returns_none_for_empty(self) -> None:
        """Empty extracted code produces no candidate."""
        ctx = SingleOptimizerContext("s", "u", "x = 1\n")
        ctx.extracted_code_and_expl = CodeStrAndExplanation(
            code="", explanation=""
        )
        assert ctx.parse_and_generate_candidate_schema() is None
class TestMultiOptimizerContext:
    """Behavioral tests for MultiOptimizerContext."""

    def test_get_system_prompt_no_diff(self) -> None:
        """
        With NO_DIFF the system prompt uses the markdown code-input
        format and embeds the Python version string.
        """
        two_file_source = (
            "```python:a.py\nx = 1\n```\n```python:b.py\ny = 2\n```"
        )
        context = MultiOptimizerContext(
            "{python_version_str} {critical_instructions}",
            "user",
            two_file_source,
            diff_method=DiffMethod.NO_DIFF,
        )
        prompt = context.get_system_prompt("3.12.9")
        assert "3.12.9" in prompt
        assert "Code Input format" in prompt

    def test_is_valid_code_rejects_new_files(self) -> None:
        """
        Extracted output that introduces a file absent from the original
        source is rejected as invalid.
        """
        original_source = "```python:a.py\nx = 1\n```"
        context = MultiOptimizerContext(
            "s",
            "u",
            original_source,
            diff_method=DiffMethod.NO_DIFF,
        )
        context.extracted_code_and_expl = CodeStrAndExplanation(
            code="```python:new_file.py\nz = 3\n```",
            explanation="added",
        )
        assert not context.is_valid_code()
class TestGetModelDistribution:
    """Tests for get_model_distribution.

    Verifies how requested candidate counts are split across the two
    configured model families and capped at the call budget.
    """

    def test_default_five(self) -> None:
        """
        Five requested candidates split into 3 GPT + 2 Claude calls.
        """
        dist = get_model_distribution(5, MAX_OPTIMIZER_CALLS)
        assert len(dist) == 2
        # Only the per-model call counts are asserted; the model
        # identifiers themselves are not meaningful to this test, so
        # the unused unpacked names are discarded.
        _, gpt_calls = dist[0]
        _, claude_calls = dist[1]
        assert gpt_calls == 3
        assert claude_calls == 2

    def test_capped(self) -> None:
        """
        Requests above max_calls are clamped to max_calls in total.
        """
        dist = get_model_distribution(100, MAX_OPTIMIZER_CALLS)
        total = sum(calls for _, calls in dist)
        assert total == MAX_OPTIMIZER_CALLS

    def test_single(self) -> None:
        """
        A single requested candidate yields exactly one total call.
        """
        dist = get_model_distribution(1, MAX_OPTIMIZER_CALLS)
        total = sum(calls for _, calls in dist)
        assert total == 1
class TestOptimizeSchemas:
    """Tests for the Pydantic optimize request/response schemas."""

    def test_request_defaults(self) -> None:
        """
        A minimal request picks up the documented default field values.
        """
        request = OptimizeRequest(
            source_code="x = 1",
            trace_id=str(uuid.uuid4()),
        )
        assert request.n_candidates == 5
        assert request.language == "python"
        assert request.is_async is False

    def test_request_version_resolution(self) -> None:
        """
        Supplying language_version backfills the python_version field.
        """
        request = OptimizeRequest(
            source_code="x = 1",
            trace_id=str(uuid.uuid4()),
            language_version="3.12.9",
        )
        assert request.python_version == "3.12.9"

    def test_response_item(self) -> None:
        """
        A response item round-trips through model_dump with its defaults.
        """
        item = OptimizeResponseItem(
            source_code="x = 2",
            explanation="faster",
            optimization_id="abc",
        )
        dumped = item.model_dump()
        assert dumped["source_code"] == "x = 2"
        assert dumped["parent_id"] is None

    def test_response(self) -> None:
        """
        The response envelope holds the list of optimization items.
        """
        single_item = OptimizeResponseItem(
            source_code="x = 2",
            explanation="faster",
            optimization_id="abc",
        )
        response = OptimizeResponse(optimizations=[single_item])
        assert len(response.optimizations) == 1

    def test_error_response(self) -> None:
        """
        The error schema exposes the provided error message.
        """
        error = OptimizeErrorResponse(error="bad")
        assert error.error == "bad"

    def test_candidate_source_enum(self) -> None:
        """
        Enum members expose their string values.
        """
        assert OptimizedCandidateSource.OPTIMIZE.value == "OPTIMIZE"

View file

@ -103,6 +103,15 @@ ignore = [
"PLC0415", # conditional imports for event loop safety (clients recreated on loop change)
"TRY301", # raise inside try is the intended pattern for cost-tracking on unsupported model type
]
"packages/codeflash-api/src/codeflash_api/optimize/_context.py" = [
"PLR2004", # magic values in faithfully ported version parsing and humanize_ns
]
"packages/codeflash-api/src/codeflash_api/optimize/_pipeline.py" = [
"PLR0913", # faithfully ported signatures require many args for LLM call orchestration
]
"packages/codeflash-api/src/codeflash_api/_app.py" = [
"PLC0415", # local import for router registration avoids circular imports
]
"packages/codeflash-core/src/codeflash_core/_model.py" = [
"C901", # humanize_runtime is complex but faithfully ported
"PLR2004", # magic values in humanize_runtime thresholds