# Support for Just-in-Time Compilation
Codeflash supports optimizing numerical code using Just-in-Time (JIT) compilation. This allows Codeflash to suggest optimizations that leverage JIT compilers from popular frameworks, including **Numba**, **PyTorch**, **TensorFlow**, and **JAX**.
## Supported JIT Frameworks
Each framework uses different compilation strategies to accelerate Python code:
### Numba (CPU Code)

Numba compiles Python functions to optimized machine code using the LLVM compiler infrastructure. Codeflash can suggest Numba optimizations that use:
- **`@jit`** - General-purpose JIT compilation, with optional flags:
  - **`nopython=True`** - Compiles to machine code without falling back to the Python interpreter.
  - **`parallel=True`** - Enables automatic thread-level parallelization of the function across multiple CPU cores (no GIL!).
  - **`fastmath=True`** - Uses aggressive floating-point optimizations via LLVM's fastmath flag.
  - **`cache=True`** - Writes the result of function compilation to disk, which significantly reduces future compilation times.
### PyTorch
PyTorch provides multiple compilation approaches:
- **`torch.compile()`** - The recommended compilation API, which uses TorchDynamo to trace operations and generate optimized kernels, with optional flags:
  - **`mode="reduce-overhead"`** - Uses CUDA graphs to reduce per-call kernel launch overhead.
  - **`mode="max-autotune"`** - Spends longer at compile time searching for the fastest kernel configurations.
  - **`fullgraph=True`** - Requires the whole function to compile as a single graph rather than falling back to eager execution on graph breaks.
- **`torch.jit.script`** - Compiles functions using TorchScript
- **`torch.jit.trace`** - Traces tensor operations to create optimized execution graphs
### TensorFlow
TensorFlow uses the XLA (Accelerated Linear Algebra) backend for JIT compilation:
- **`@tf.function`** - Compiles Python functions into optimized TensorFlow graphs, with optional flags:
  - **`jit_compile=True`** - Compiles the graph with XLA for operation fusion and faster kernels.
  - **`reduce_retracing=True`** - Relaxes input specializations to reduce costly retracing when argument shapes or values vary.
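A minimal sketch, assuming TensorFlow is installed (`squared_norm` is a hypothetical example, not Codeflash output):

```python
# Hypothetical example: XLA-compiling a small reduction with tf.function.
import tensorflow as tf

@tf.function(jit_compile=True)  # jit_compile=True asks XLA to fuse the multiply and the sum
def squared_norm(x):
    return tf.reduce_sum(x * x)

x = tf.range(10, dtype=tf.float32)
result = float(squared_norm(x))  # 0^2 + 1^2 + ... + 9^2 = 285.0
```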
### JAX
JAX captures side-effect-free operations and optimizes them:
- **`@jax.jit`** - JIT-compiles functions using XLA, with automatic operation fusion for improved performance, and optional flags:
  - **`static_argnums`** - Treats the given positional arguments as compile-time constants; changing them triggers recompilation.
  - **`donate_argnums`** - Lets XLA reuse input buffers for outputs, reducing memory pressure.
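A minimal sketch, assuming JAX is installed (`normalize` is a hypothetical example, not Codeflash output):

```python
# Hypothetical example: JIT-compiling a vector normalization with jax.jit.
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    # the norm and the division are fused into a single XLA computation
    return x / jnp.linalg.norm(x)

out = normalize(jnp.array([3.0, 4.0]))  # unit vector [0.6, 0.8]
```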
## How Codeflash Optimizes with JIT
When Codeflash identifies a function that could benefit from JIT compilation, it:
1. **Rewrites the code** in a JIT-compatible format, which may involve breaking down complex functions into separate JIT-compiled components.
2. **Generates appropriate tests** that are compatible with JIT-compiled code, carefully handling data types since JIT compilers have stricter input type requirements.
## Accurate Benchmarking with GPU Code
Because GPU operations are inherently non-blocking, Codeflash adds GPU synchronization calls around the measured region so that profiling reflects the actual time spent on the GPU rather than just the time to enqueue kernels.
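A minimal sketch of such synchronization, assuming PyTorch (`benchmark` is a hypothetical helper, not Codeflash's internal API; it also runs correctly on CPU, where the synchronization calls are skipped):

```python
# Hypothetical timing helper: synchronize so queued GPU work is included in the measurement.
import time
import torch

def benchmark(fn, x, iters=10):
    fn(x)  # warm-up call (triggers any lazy compilation)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for queued kernels before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait again so all timed work has actually finished
    return (time.perf_counter() - start) / iters
```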