first draft, need to refine
This commit is contained in:
parent
70a069bc49
commit
454f20d6fc
2 changed files with 87 additions and 1 deletions
|
|
@ -66,7 +66,8 @@
|
|||
"group": "🧠 Core Concepts",
|
||||
"pages": [
|
||||
"codeflash-concepts/how-codeflash-works",
|
||||
"codeflash-concepts/benchmarking"
|
||||
"codeflash-concepts/benchmarking",
|
||||
"support-for-jit/index"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
|||
85
docs/support-for-jit/index.mdx
Normal file
85
docs/support-for-jit/index.mdx
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
---
|
||||
title: "Support for Just-in-Time Compilation"
|
||||
description: "Learn how Codeflash optimizes code using JIT compilation with Numba, PyTorch, TensorFlow, and JAX"
|
||||
icon: "bolt"
|
||||
sidebarTitle: "JIT Compilation"
|
||||
keywords: ["JIT", "just-in-time", "numba", "pytorch", "tensorflow", "jax", "GPU", "CUDA", "compilation", "performance"]
|
||||
---
|
||||
|
||||
# Support for Just-in-Time Compilation
|
||||
|
||||
Codeflash supports optimizing code using Just-in-Time (JIT) compilation. This allows Codeflash to suggest optimizations that leverage JIT compilers from popular frameworks including **Numba**, **PyTorch**, **TensorFlow**, and **JAX**.
|
||||
|
||||
## Supported JIT Frameworks
|
||||
|
||||
Each framework uses different compilation strategies to accelerate Python code:
|
||||
|
||||
### Numba
|
||||
|
||||
Numba compiles Python functions to optimized machine code using the LLVM compiler infrastructure. Codeflash can suggest Numba optimizations that use:
|
||||
|
||||
- **`@jit` / `@njit`** - General-purpose JIT compilation with `nopython` mode for removing Python interpreter overhead
|
||||
- **`parallel=True`** - Enables automatic SIMD parallelization
|
||||
- **`fastmath=True`** - Uses aggressive floating-point optimizations via LLVM's fastmath flag
|
||||
- **`@vectorize` / `@guvectorize`** - Creates NumPy universal functions (ufuncs)
|
||||
- **`@cuda.jit`** - Compiles functions to run on NVIDIA GPUs
|
||||
|
||||
### PyTorch
|
||||
|
||||
PyTorch provides multiple compilation approaches:
|
||||
|
||||
- **`torch.compile()`** - The recommended compilation API that uses TorchDynamo to trace operations and create optimized CUDA graphs
|
||||
- **`torch.jit.script`** - Compiles functions using TorchScript
|
||||
- **`torch.jit.trace`** - Traces tensor operations to create optimized execution graphs
|
||||
|
||||
### TensorFlow
|
||||
|
||||
TensorFlow uses the XLA (Accelerated Linear Algebra) backend for JIT compilation:
|
||||
|
||||
- **`@tf.function`** - Compiles Python functions into optimized TensorFlow graphs using XLA
|
||||
|
||||
### JAX
|
||||
|
||||
JAX captures side-effect-free operations and optimizes them:
|
||||
|
||||
- **`@jax.jit`** - JIT compiles functions using XLA, with automatic operation fusion for improved performance
|
||||
|
||||
## How Codeflash Optimizes with JIT
|
||||
|
||||
When Codeflash identifies a function that could benefit from JIT compilation, it:
|
||||
|
||||
1. **Rewrites the code** in a JIT-compatible format, which may involve breaking down complex functions into separate JIT-compiled components
|
||||
2. **Generates appropriate tests** that are compatible with JIT-compiled code, carefully handling data types since JIT compilers have stricter type requirements
|
||||
3. **Adds GPU synchronization calls** for accurate profiling when code runs on GPU, since GPU operations are inherently non-blocking
|
||||
|
||||
## Accurate Benchmarking with GPU Code
|
||||
|
||||
Since GPU operations execute asynchronously, Codeflash automatically inserts synchronization barriers before measuring performance. This ensures timing measurements reflect actual computation time rather than just the time to queue operations:
|
||||
|
||||
- **PyTorch**: Uses `torch.cuda.synchronize()` or `torch.mps.synchronize()` depending on the device
|
||||
- **JAX**: Uses `jax.block_until_ready()` to wait for computation to complete
|
||||
- **TensorFlow**: Uses `tf.test.experimental.sync_devices()` for device synchronization
|
||||
|
||||
## When JIT Compilation Helps
|
||||
|
||||
JIT compilation is most effective for:
|
||||
|
||||
- Numerical computations with loops that can't be easily vectorized
|
||||
- Custom algorithms not covered by existing optimized libraries
|
||||
- Functions that are called repeatedly with consistent input types
|
||||
- Code that benefits from hardware-specific optimizations (SIMD, GPU acceleration)
|
||||
|
||||
## When JIT Compilation May Not Help
|
||||
|
||||
JIT compilation may not provide speedups when:
|
||||
|
||||
- The code already uses highly optimized libraries (e.g., NumPy with MKL, cuBLAS, cuDNN)
|
||||
- Functions have variable input types or shapes that prevent effective compilation
|
||||
- The compilation overhead exceeds the runtime savings for short-running functions
|
||||
- The code relies heavily on Python objects or dynamic features that JIT compilers can't optimize
|
||||
|
||||
## Configuration
|
||||
|
||||
JIT compilation support is **enabled automatically** in Codeflash. You don't need to modify any configuration to enable JIT-based optimizations. Codeflash will automatically detect when JIT compilation could improve performance and suggest appropriate optimizations.
|
||||
|
||||
When running tests with coverage measurement, Codeflash temporarily disables JIT compilation to ensure accurate coverage data, then re-enables it for performance benchmarking.
|
||||
Loading…
Reference in a new issue