Each framework uses different compilation strategies to accelerate Python code:

### Numba
Numba compiles Python functions to optimized machine code using the LLVM compiler infrastructure. Codeflash can suggest Numba optimizations that use:

- **`@jit`** - General-purpose JIT compilation with optional flags (a short sketch combining them follows this list).
  - **`nopython=True`** - Compiles to machine code without falling back to the Python interpreter.
  - **`fastmath=True`** - Uses aggressive floating-point optimizations via LLVM's fastmath flag.
  - **`cache=True`** - Caches the compiled function to disk, which reduces compilation time on subsequent runs.
  - **`parallel=True`** - Automatically parallelizes supported loops and array operations across CPU threads.
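
For illustration, here is a minimal sketch of how these flags combine on a loop-heavy function; the function name, array shapes, and flag choices below are hypothetical, not a Codeflash suggestion:

```python
import numpy as np
from numba import jit, prange

# Hypothetical example: all names and sizes here are illustrative.
@jit(nopython=True, fastmath=True, cache=True, parallel=True)
def row_norms(a):
    out = np.empty(a.shape[0])
    # prange lets Numba spread the outer loop across CPU threads
    for i in prange(a.shape[0]):
        s = 0.0
        for j in range(a.shape[1]):
            s += a[i, j] * a[i, j]
        out[i] = np.sqrt(s)
    return out

x = np.random.rand(1000, 1000)
row_norms(x)  # first call compiles; later calls reuse the cached machine code
```
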
### PyTorch
PyTorch provides JIT compilation through `torch.compile()`, the recommended compilation API introduced in PyTorch 2.0. It uses TorchDynamo to capture Python bytecode and TorchInductor to generate optimized kernels.

- **`torch.compile()`** - Compiles a function or module for optimized execution (a brief sketch of these options follows this list).
  - **`mode`** - Controls the compilation strategy:
    - `"default"` - Balanced compilation with moderate optimization.
    - `"reduce-overhead"` - Minimizes Python overhead using CUDA graphs, ideal for small batches.
    - `"max-autotune"` - Spends more time autotuning to find the fastest kernels.
  - **`fullgraph=True`** - Requires the entire function to be captured as a single graph. Raises an error if graph breaks occur, useful for ensuring complete optimization.
  - **`dynamic=True`** - Enables dynamic shape support, allowing the compiled function to handle varying input sizes without recompilation.
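
As a rough sketch of how these options are passed; the function, tensor shapes, and chosen values below are arbitrary, for illustration only:

```python
import torch

def scale_and_shift(x, w, b):
    # Illustrative function; any tensor computation works here
    return x * w + b

compiled_fn = torch.compile(
    scale_and_shift,
    mode="reduce-overhead",  # use CUDA graphs to reduce per-call Python overhead
    fullgraph=True,          # fail loudly instead of silently splitting the graph
    dynamic=True,            # tolerate varying input shapes without recompiling
)

x = torch.randn(8, 128)
y = compiled_fn(x, torch.randn(128), torch.randn(128))
```
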
### TensorFlow
TensorFlow uses `@tf.function` to compile Python functions into optimized TensorFlow graphs. When combined with XLA (Accelerated Linear Algebra), it can generate highly optimized machine code for both CPU and GPU.

- **`@tf.function`** - Converts Python functions into TensorFlow graphs for optimized execution (see the sketch after this list).
  - **`jit_compile=True`** - Enables XLA compilation, which performs whole-function optimization including operation fusion, memory layout optimization, and target-specific code generation.
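
A minimal sketch of these options in use; the function below is made up for illustration:

```python
import tensorflow as tf

# Hypothetical example: XLA-compile a small element-wise function.
@tf.function(jit_compile=True)
def fused_ops(x):
    # XLA can fuse these element-wise ops into far fewer kernels
    return tf.math.tanh(x) * tf.math.sigmoid(x) + tf.math.sin(x)

x = tf.random.normal([1024, 1024])
fused_ops(x)  # first call traces and compiles; later calls reuse the compiled graph
```
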
### JAX
JAX uses XLA to JIT compile pure functions into optimized machine code. It emphasizes functional programming patterns and captures side-effect-free operations for optimization.

- **`@jax.jit`** - JIT compiles functions using XLA with automatic operation fusion (a brief sketch follows).
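
A minimal sketch; the function below is illustrative and assumes JAX is installed:

```python
import jax
import jax.numpy as jnp

# Hypothetical example: jit a pure (side-effect-free) function.
@jax.jit
def normalize(x):
    # XLA fuses the reductions and element-wise ops into a few kernels
    return (x - x.mean()) / (x.std() + 1e-6)

x = jnp.ones((1024, 1024))
normalize(x).block_until_ready()  # first call compiles; later calls are fast
```
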
## How Codeflash Optimizes with JIT
When Codeflash identifies a function that could benefit from JIT compilation, it:

1. **Rewrites the code** in a JIT-compatible format, which may involve breaking down complex functions into separate JIT-compiled components (a hypothetical sketch follows this list).
2. **Generates appropriate tests** that are compatible with JIT-compiled code, carefully handling data types since JIT compilers have stricter input type requirements.
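
For intuition only, here is a hypothetical sketch of what such a rewrite can look like; the function and the choice of Numba below are illustrative, not Codeflash's literal output:

```python
import numpy as np
from numba import njit

# Hypothetical original: a pure-Python loop over a NumPy array.
def rolling_mean(a, window):
    out = np.empty(len(a) - window + 1)
    for i in range(len(out)):
        out[i] = a[i:i + window].mean()
    return out

# One possible JIT-compatible rewrite: same logic, but expressed with plain
# loops and scalars so Numba's nopython mode can compile it.
@njit(cache=True)
def rolling_mean_jit(a, window):
    out = np.empty(len(a) - window + 1)
    for i in range(len(out)):
        s = 0.0
        for j in range(window):
            s += a[i + j]
        out[i] = s / window
    return out
```
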
## Accurate Benchmarking on Non-CPU Devices
Since GPU operations execute asynchronously, Codeflash automatically inserts synchronization barriers before measuring performance. This ensures timing measurements reflect actual computation time rather than just the time to queue operations:

- **PyTorch**: Uses `torch.cuda.synchronize()` (NVIDIA GPUs) or `torch.mps.synchronize()` (macOS Metal Performance Shaders) depending on the device.
- **JAX**: Uses `jax.block_until_ready()` to wait for computation to complete (illustrated in the sketch after this list).
- **TensorFlow**: Uses `tf.test.experimental.sync_devices()` for device synchronization.
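
To see why the barrier matters, here is a minimal JAX timing sketch (the function and array size are arbitrary):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return (x @ x).sum()

x = jnp.ones((2000, 2000))
f(x).block_until_ready()  # warm up and compile outside the timed region

start = time.perf_counter()
y = f(x)                       # returns almost immediately: work is only queued
queued = time.perf_counter() - start

y.block_until_ready()          # wait for the device to actually finish
finished = time.perf_counter() - start

print(f"queued: {queued:.6f}s, finished: {finished:.6f}s")
```
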
## When JIT Compilation Helps
JIT compilation is most effective for:

- Numerical computations with loops that can't be easily vectorized.
- Custom algorithms not covered by existing optimized libraries.
- Functions that are called repeatedly with consistent input types.
- Code that benefits from hardware-specific optimizations (SIMD, GPU acceleration).
### Example
#### Function Definition
```python
import torch

def complex_activation(x):
    """A custom activation with many small operations - compile makes a huge difference"""
    # Many sequential element-wise ops create kernel launch overhead
    x = torch.sin(x)
    x = x * torch.cos(x)
    x = x + torch.exp(-x.abs())
    x = x / (1 + x.pow(2))
    x = torch.tanh(x) * torch.sigmoid(x)
    x = x - 0.5 * x.pow(3)
    return x
```
#### Benchmarking Snippet (replace `cuda` with `mps` to run on your Mac)
```python
import time

# Create compiled version
complex_activation_compiled = torch.compile(complex_activation)

# Benchmark
x = torch.randn(1000, 1000, device='cuda')

# Warmup
for _ in range(10):
    _ = complex_activation(x)
    _ = complex_activation_compiled(x)

# Time uncompiled
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
    y = complex_activation(x)
torch.cuda.synchronize()
uncompiled_time = time.time() - start

# Time compiled
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
    y = complex_activation_compiled(x)
torch.cuda.synchronize()
compiled_time = time.time() - start

print(f"Uncompiled: {uncompiled_time:.4f}s")
print(f"Compiled: {compiled_time:.4f}s")
print(f"Speedup: {uncompiled_time/compiled_time:.2f}x")
```
Expected Output on CUDA

```
Uncompiled: 0.0176s
Compiled: 0.0063s
Speedup: 2.80x
```
Here, JIT compilation via `torch.compile` is the only viable option because:

1. Already vectorized - All operations are already PyTorch tensor ops.
2. Multiple kernel launches - Uncompiled code launches ~10 separate kernels; `torch.compile` fuses them into 1-2 kernels, eliminating kernel launch overhead.
3. No algorithmic improvement - The computation itself is already optimal.
4. Python overhead elimination - Removes Python interpreter overhead between operations.
## When JIT Compilation May Not Help
JIT compilation may not provide speedups when:

- The code already uses highly optimized libraries (e.g., NumPy with MKL, cuBLAS, cuDNN).
- Functions have variable input types or shapes that prevent effective compilation.
- The compilation overhead exceeds the runtime savings for short-running functions.
- The code relies heavily on Python objects or dynamic features that JIT compilers can't optimize.
### Example
#### Function Definition
```python
import torch

def adaptive_processing(x, threshold=0.5):
    """Function with data-dependent control flow - compile struggles here"""
    # Check how many values exceed threshold (data-dependent!)
    mask = x > threshold
    num_large = mask.sum().item()  # .item() causes graph break

    if num_large > x.numel() * 0.3:
        # Path 1: Many large values - use expensive operation
        result = torch.matmul(x, x.T)  # Already optimized by cuBLAS
        result = result.mean(dim=0)
    else:
        # Path 2: Few large values - use cheap operation
        result = x.mean(dim=1)

    return result
```
#### Benchmarking Snippet (replace `cuda` with `mps` to run on your Mac)
```python
import time

# Create compiled version
adaptive_processing_compiled = torch.compile(adaptive_processing)

# Test with data that causes branch variation
x = torch.randn(500, 500, device='cuda')

# Warmup
for _ in range(10):
    _ = adaptive_processing(x)
    _ = adaptive_processing_compiled(x)

# Benchmark with varying data (causes recompilation)
torch.cuda.synchronize()
start = time.time()
for i in range(100):
    # Vary the data to trigger different branches
    x_test = torch.randn(500, 500, device='cuda') + (i % 2)
    y = adaptive_processing(x_test)
torch.cuda.synchronize()
uncompiled_time = time.time() - start

torch.cuda.synchronize()
start = time.time()
for i in range(100):
    x_test = torch.randn(500, 500, device='cuda') + (i % 2)
    y = adaptive_processing_compiled(x_test)  # Recompiles frequently!
torch.cuda.synchronize()
compiled_time = time.time() - start

print(f"Uncompiled: {uncompiled_time:.4f}s")
print(f"Compiled: {compiled_time:.4f}s")
print(f"Slowdown: {compiled_time/uncompiled_time:.2f}x")
```
Expected Output on CUDA

```
Uncompiled: 0.0296s
Compiled: 0.2847s
Slowdown: 9.63x
```
Why `torch.compile` is detrimental here:

1. Graph breaks - `.item()` forces a graph break, negating compile benefits.
2. Recompilation overhead - Different branches cause expensive recompilation each time.
3. Dynamic control flow - Data-dependent conditionals can't be optimized away.
4. Already optimized ops - `matmul` already uses `cuBLAS`; compile adds overhead without benefit.
#### Better Optimization Strategy
```python
def optimized_version(x, threshold=0.5):
    """Remove data-dependent control flow - vectorize instead"""
    mask = (x > threshold).float()
    weight = (mask.mean() > 0.3).float()  # Keep on GPU

    # Compute both paths, blend based on weight (branchless)
    expensive = torch.matmul(x, x.T).mean(dim=0)
    cheap = x.mean(dim=1).squeeze()

    # Broadcast cheap result to match the shape of the expensive path
    cheap_padded = cheap.expand(expensive.shape[0])

    result = weight * expensive + (1 - weight) * cheap_padded
    return result
```
Expected Output on CUDA

```
Optimized: 0.0277s
Speedup compared to Uncompiled: 1.57x
```
Key improvements:

1. Eliminate `.item()` - Keep computation on GPU.
2. Branchless execution - Compute both paths, blend results.
3. Vectorization - Replace conditionals with masked operations.
4. Reduce Python overhead - Minimize host-device synchronization.
## Configuration
JIT compilation support is **enabled automatically** in Codeflash. You don't need to modify any configuration to enable JIT-based optimizations. Codeflash will automatically detect when JIT compilation could improve performance and suggest appropriate optimizations.

When running tests with coverage measurement, Codeflash temporarily disables JIT compilation to ensure accurate coverage data, then re-enables it for performance benchmarking.