add examples

aseembits93 2026-01-23 18:32:05 -08:00
parent 85344f5fd4
commit eb9b3dff1a

@ -18,40 +18,36 @@ Each framework uses different compilation strategies to accelerate Python code:
Numba compiles Python functions to optimized machine code using the LLVM compiler infrastructure. Codeflash can suggest Numba optimizations that use:
- **`@jit`** - General-purpose JIT compilation with optional flags.
  - **`nopython=True`** - Compiles to machine code without falling back to the Python interpreter.
  - **`fastmath=True`** - Uses aggressive floating-point optimizations via LLVM's fastmath flag.
  - **`cache=True`** - Caches the compiled function to disk, which reduces compilation time on subsequent runs.
  - **`parallel=True`** - Parallelizes loops across multiple CPU cores. A brief usage sketch follows this list.
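For illustration, here is a minimal sketch of a function written for Numba with these flags; the function and array sizes are invented for the example:
```python
import numpy as np
from numba import jit, prange

@jit(nopython=True, fastmath=True, cache=True, parallel=True)
def scaled_sum(a, b):
    # Explicit loop compiled to machine code; prange lets Numba split the
    # iterations across CPU cores when parallel=True is set.
    out = np.empty_like(a)
    for i in prange(a.shape[0]):
        out[i] = 2.0 * a[i] + b[i]
    return out

x = np.random.rand(1_000_000)
y = np.random.rand(1_000_000)
scaled_sum(x, y)  # first call compiles (and caches to disk); later calls reuse the machine code
```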
### PyTorch
PyTorch provides JIT compilation through `torch.compile()`, the recommended compilation API introduced in PyTorch 2.0. It uses TorchDynamo to capture Python bytecode and TorchInductor to generate optimized kernels.
- **`torch.compile()`** - Compiles a function or module for optimized execution.
  - **`mode`** - Controls the compilation strategy:
    - `"default"` - Balanced compilation with moderate optimization.
    - `"reduce-overhead"` - Minimizes Python overhead using CUDA graphs, ideal for small batches.
    - `"max-autotune"` - Spends more time autotuning to find the fastest kernels.
  - **`fullgraph=True`** - Requires the entire function to be captured as a single graph. Raises an error if graph breaks occur, useful for ensuring complete optimization.
  - **`dynamic=True`** - Enables dynamic shape support, allowing the compiled function to handle varying input sizes without recompilation. A combined usage sketch follows this list.
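As a minimal sketch of how these arguments combine (the function here is a made-up example, not something Codeflash emits):
```python
import torch

def smooth_act(x):
    # Small element-wise function used only to illustrate the flags.
    return 0.5 * x * (1 + torch.tanh(x))

compiled = torch.compile(smooth_act, mode="max-autotune", fullgraph=True, dynamic=True)

x = torch.randn(64, 128)
y = compiled(x)  # first call triggers compilation; later calls reuse the optimized kernels
```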
### TensorFlow
TensorFlow uses `@tf.function` to compile Python functions into optimized TensorFlow graphs. When combined with XLA (Accelerated Linear Algebra), it can generate highly optimized machine code for both CPU and GPU.
- **`@tf.function`** - Converts Python functions into TensorFlow graphs for optimized execution.
  - **`jit_compile=True`** - Enables XLA compilation, which performs whole-function optimization including operation fusion, memory layout optimization, and target-specific code generation. A minimal sketch follows this list.
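A minimal sketch of XLA-compiled TensorFlow code, using an invented element-wise function:
```python
import tensorflow as tf

@tf.function(jit_compile=True)  # jit_compile=True turns on XLA for this function
def fused_reduce(x):
    # Several element-wise ops that XLA can fuse into a single kernel.
    return tf.reduce_sum(tf.sin(x) * tf.exp(-x))

x = tf.random.normal([1000, 1000])
fused_reduce(x)  # first call traces and compiles; later calls reuse the compiled graph
```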
### JAX
JAX uses XLA to JIT compile pure functions into optimized machine code. It emphasizes functional programming patterns and captures side-effect-free operations for optimization.
- **`@jax.jit`** - JIT compiles functions using XLA with automatic operation fusion. A minimal sketch follows.
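A minimal sketch with an invented pure function:
```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    # Pure function: XLA fuses the mean/std reductions with the element-wise ops.
    return (x - jnp.mean(x)) / (jnp.std(x) + 1e-8)

x = jnp.ones((1000, 1000))
normalize(x).block_until_ready()  # first call compiles; block_until_ready waits for the result
```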
## How Codeflash Optimizes with JIT
@ -60,34 +56,209 @@ When Codeflash identifies a function that could benefit from JIT compilation, it
1. **Rewrites the code** in a JIT-compatible format, which may involve breaking down complex functions into separate JIT-compiled components (a rough illustration follows this list).
2. **Generates appropriate tests** that are compatible with JIT-compiled code, carefully handling data types since JIT compilers have stricter input type requirements.
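As a rough illustration of the first step (not actual Codeflash output), a function with a numerical core might be split into a JIT-compiled kernel plus a thin Python wrapper:
```python
import numpy as np
from numba import jit

@jit(nopython=True)
def _smooth_kernel(values):
    # Numerical core extracted into its own JIT-compiled component.
    out = np.empty_like(values)
    out[0] = values[0]
    for i in range(1, values.shape[0]):
        out[i] = 0.5 * values[i] + 0.5 * out[i - 1]
    return out

def smooth(values, as_list=False):
    # Thin wrapper keeps dynamic behavior (type coercion, options) outside the compiled region.
    result = _smooth_kernel(np.asarray(values, dtype=np.float64))
    return result.tolist() if as_list else result
```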
## Accurate Benchmarking on Non-CPU Devices
Since GPU operations execute asynchronously, Codeflash automatically inserts synchronization barriers before measuring performance. This ensures timing measurements reflect actual computation time rather than just the time to queue operations:
- **PyTorch**: Uses `torch.cuda.synchronize()` (NVIDIA GPUs) or `torch.mps.synchronize()` (Apple Metal Performance Shaders on macOS) depending on the device.
- **JAX**: Uses `jax.block_until_ready()` to wait for computation to complete (see the example after this list).
- **TensorFlow**: Uses `tf.test.experimental.sync_devices()` for device synchronization.
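For example, a JAX timing loop would look roughly like this (the workload is invented; the same pattern with `torch.cuda.synchronize()` appears in the PyTorch benchmarks below):
```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def heavy_op(x):
    return jnp.sin(x) @ jnp.cos(x).T  # placeholder computation

x = jnp.ones((2000, 2000))
heavy_op(x).block_until_ready()  # warmup: compile and finish once before timing

start = time.time()
y = heavy_op(x)
jax.block_until_ready(y)  # wait for the async computation before stopping the clock
print(f"Elapsed: {time.time() - start:.4f}s")
```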
## When JIT Compilation Helps
JIT compilation is most effective for:
- Numerical computations with loops that can't be easily vectorized.
- Custom algorithms not covered by existing optimized libraries.
- Functions that are called repeatedly with consistent input types.
- Code that benefits from hardware-specific optimizations (SIMD, GPU acceleration).
### Example
#### Function Definition
```python
import torch

def complex_activation(x):
    """A custom activation with many small operations - compile makes a huge difference"""
    # Many sequential element-wise ops create kernel launch overhead
    x = torch.sin(x)
    x = x * torch.cos(x)
    x = x + torch.exp(-x.abs())
    x = x / (1 + x.pow(2))
    x = torch.tanh(x) * torch.sigmoid(x)
    x = x - 0.5 * x.pow(3)
    return x
```
#### Benchmarking Snippet (replace `cuda` with `mps` to run on your Mac)
```python
import time

# Create compiled version
complex_activation_compiled = torch.compile(complex_activation)

# Benchmark
x = torch.randn(1000, 1000, device='cuda')

# Warmup
for _ in range(10):
    _ = complex_activation(x)
    _ = complex_activation_compiled(x)

# Time uncompiled
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
    y = complex_activation(x)
torch.cuda.synchronize()
uncompiled_time = time.time() - start

# Time compiled
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
    y = complex_activation_compiled(x)
torch.cuda.synchronize()
compiled_time = time.time() - start

print(f"Uncompiled: {uncompiled_time:.4f}s")
print(f"Compiled: {compiled_time:.4f}s")
print(f"Speedup: {uncompiled_time/compiled_time:.2f}x")
```
Expected Output on CUDA
```
Uncompiled: 0.0176s
Compiled: 0.0063s
Speedup: 2.80x
```
Here, JIT compilation via `torch.compile` is the only viable option because:
1. Already vectorized - All operations are already PyTorch tensor ops.
2. Multiple kernel launches - The uncompiled code launches ~10 separate kernels; `torch.compile` fuses them into 1-2 kernels, eliminating kernel launch overhead.
3. No algorithmic improvement - The computation itself is already optimal.
4. Python overhead elimination - Removes Python interpreter overhead between operations.
## When JIT Compilation May Not Help
JIT compilation may not provide speedups when:
- The code already uses highly optimized libraries (e.g., NumPy with MKL, cuBLAS, cuDNN).
- Functions have variable input types or shapes that prevent effective compilation.
- The compilation overhead exceeds the runtime savings for short-running functions.
- The code relies heavily on Python objects or dynamic features that JIT compilers can't optimize.
### Example
#### Function Definition
```python
import torch

def adaptive_processing(x, threshold=0.5):
    """Function with data-dependent control flow - compile struggles here"""
    # Check how many values exceed threshold (data-dependent!)
    mask = x > threshold
    num_large = mask.sum().item()  # .item() causes graph break

    if num_large > x.numel() * 0.3:
        # Path 1: Many large values - use expensive operation
        result = torch.matmul(x, x.T)  # Already optimized by cuBLAS
        result = result.mean(dim=0)
    else:
        # Path 2: Few large values - use cheap operation
        result = x.mean(dim=1)
    return result
```
#### Benchmarking Snippet (replace `cuda` with `mps` to run on your Mac)
```python
import time

# Create compiled version
adaptive_processing_compiled = torch.compile(adaptive_processing)

# Test with data that causes branch variation
x = torch.randn(500, 500, device='cuda')

# Warmup
for _ in range(10):
    _ = adaptive_processing(x)
    _ = adaptive_processing_compiled(x)

# Benchmark with varying data (causes recompilation)
torch.cuda.synchronize()
start = time.time()
for i in range(100):
    # Vary the data to trigger different branches
    x_test = torch.randn(500, 500, device='cuda') + (i % 2)
    y = adaptive_processing(x_test)
torch.cuda.synchronize()
uncompiled_time = time.time() - start

torch.cuda.synchronize()
start = time.time()
for i in range(100):
    x_test = torch.randn(500, 500, device='cuda') + (i % 2)
    y = adaptive_processing_compiled(x_test)  # Recompiles frequently!
torch.cuda.synchronize()
compiled_time = time.time() - start

print(f"Uncompiled: {uncompiled_time:.4f}s")
print(f"Compiled: {compiled_time:.4f}s")
print(f"Slowdown: {compiled_time/uncompiled_time:.2f}x")
```
Expected Output on CUDA
```
Uncompiled: 0.0296s
Compiled: 0.2847s
Slowdown: 9.63x
```
Why `torch.compile` is detrimental here:
1. Graph breaks - `.item()` forces a graph break, negating compile benefits.
2. Recompilation overhead - Different branches cause expensive recompilation each time.
3. Dynamic control flow - Data-dependent conditionals can't be optimized away.
4. Already optimized ops - `matmul` already uses `cuBLAS`; compile adds overhead without benefit.
#### Better Optimization Strategy
```python
import torch

def optimized_version(x, threshold=0.5):
    """Remove data-dependent control flow - vectorize instead"""
    mask = (x > threshold).float()
    weight = (mask.mean() > 0.3).float()  # Keep on GPU

    # Compute both paths, blend based on weight (branchless)
    expensive = torch.matmul(x, x.T).mean(dim=0)
    cheap = x.mean(dim=1).squeeze()

    # Pad cheap result to match expensive dimensions
    cheap_padded = cheap.expand(expensive.shape[0])
    result = weight * expensive + (1 - weight) * cheap_padded
    return result
```
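A benchmarking loop analogous to the ones above can produce numbers like those below; this sketch assumes a CUDA device and reuses `uncompiled_time` from the earlier snippet (exact figures will vary by hardware and run):
```python
import time
import torch

x = torch.randn(500, 500, device='cuda')

# Warmup
for _ in range(10):
    _ = optimized_version(x)

torch.cuda.synchronize()
start = time.time()
for i in range(100):
    x_test = torch.randn(500, 500, device='cuda') + (i % 2)
    y = optimized_version(x_test)
torch.cuda.synchronize()
optimized_time = time.time() - start

print(f"Optimized: {optimized_time:.4f}s")
print(f"Speedup compared to Uncompiled: {uncompiled_time / optimized_time:.2f}x")
```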
Expected Output on CUDA
```
Optimized: 0.0277s
Speedup compared to Uncompiled: 1.57x
```
Key improvements:
1. Eliminate `.item()` - Keep computation on GPU.
2. Branchless execution - Compute both paths, blend results.
3. Vectorization - Replace conditionals with masked operations.
4. Reduce Python overhead - Minimize host-device synchronization.
## Configuration
JIT compilation support is **enabled automatically** in Codeflash. You don't need to modify any configuration to enable JIT-based optimizations. Codeflash will automatically detect when JIT compilation could improve performance and suggest appropriate optimizations.
When running tests with coverage measurement, Codeflash temporarily disables JIT compilation to ensure accurate coverage data, then re-enables it for performance benchmarking.
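As an illustration of the kind of switches this relies on (these are standard knobs in the respective libraries, not Codeflash-specific APIs; the test entry point is hypothetical):
```python
import os

# Numba: with this environment variable set before functions are compiled,
# @jit becomes a pass-through, so coverage sees the original Python code.
os.environ["NUMBA_DISABLE_JIT"] = "1"

import jax

def run_tests_under_coverage():
    ...  # hypothetical test entry point

# JAX: jit can be turned off for a block of code with a context manager.
with jax.disable_jit():
    run_tests_under_coverage()
```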