---
title: "Support for Just-in-Time Compilation"
description: "Learn how Codeflash optimizes code using JIT compilation with Numba, PyTorch, TensorFlow, and JAX"
icon: "bolt"
sidebarTitle: "JIT Compilation"
keywords: ["JIT", "just-in-time", "numba", "pytorch", "tensorflow", "jax", "GPU", "CUDA", "compilation", "performance"]
---
|
# Support for Just-in-Time Compilation
|
Codeflash supports optimizing code using Just-in-Time (JIT) compilation. This allows Codeflash to suggest optimizations that leverage JIT compilers from popular frameworks including **Numba**, **PyTorch**, **TensorFlow**, and **JAX**.
|
## Supported JIT Frameworks
|
Each framework uses different compilation strategies to accelerate Python code:
|
### Numba
|
Numba compiles Python functions to optimized machine code using the LLVM compiler infrastructure. Codeflash can suggest Numba optimizations that use:
|
- **`@jit` / `@njit`** - General-purpose JIT compilation; `@njit` is shorthand for `@jit(nopython=True)`, which removes Python interpreter overhead
- **`parallel=True`** - Enables automatic multi-threaded parallelization of supported array operations and explicit `prange` loops
- **`fastmath=True`** - Relaxes strict IEEE 754 semantics using LLVM's fast-math flags, allowing more aggressive floating-point optimizations
- **`@vectorize` / `@guvectorize`** - Creates NumPy universal functions (ufuncs) and generalized ufuncs
- **`@cuda.jit`** - Compiles functions into CUDA kernels that run on NVIDIA GPUs
|
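The sketch below shows how these options combine on a simple reduction. It is an illustration rather than Codeflash output: `sum_of_squares` is a hypothetical function, and the sketch assumes NumPy and Numba are installed. The `try`/`except` fallback exists only so the sketch still runs as plain Python when Numba is missing; in real code you would import Numba directly.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # fallback so the sketch runs as plain Python without Numba
    prange = range

    def njit(*args, **kwargs):
        # Accept both @njit and @njit(...) and return the function unchanged.
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func


@njit(parallel=True, fastmath=True)
def sum_of_squares(values):
    # prange asks Numba to split the loop across threads; `total +=` is treated as a reduction.
    total = 0.0
    for i in prange(values.shape[0]):
        total += values[i] * values[i]
    return total


data = np.arange(100_000, dtype=np.float64)
result = sum_of_squares(data)  # first call compiles for a 1-D float64 array
```

Numba compiles `sum_of_squares` on the first call for the given argument types and reuses the compiled code afterwards. Because `fastmath=True` permits reordering floating-point operations, results on general inputs can differ slightly from a strictly sequential sum.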
|
### PyTorch
|
PyTorch provides multiple compilation approaches:
|
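- **`torch.compile()`** - The recommended compilation API: TorchDynamo captures PyTorch operations into graphs, and the default TorchInductor backend generates optimized, fused kernels (CUDA graphs can optionally be used, for example with `mode="reduce-overhead"`)
- **`torch.jit.script`** - Compiles functions or modules to TorchScript by analyzing their source code
- **`torch.jit.trace`** - Records the tensor operations executed on example inputs to build a TorchScript graph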
|
|
### TensorFlow
|
TensorFlow can JIT-compile computations with the XLA (Accelerated Linear Algebra) compiler:

- **`@tf.function`** - Traces Python functions into TensorFlow graphs; with `jit_compile=True`, the graph is also compiled with XLA
|
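A minimal sketch assuming TensorFlow 2.x is installed; `scaled_sum` is a hypothetical function. Without `jit_compile=True`, `tf.function` still builds a graph but does not force XLA compilation.

```python
import tensorflow as tf


@tf.function(jit_compile=True)  # trace to a graph, then compile the graph with XLA
def scaled_sum(x, y):
    # XLA can fuse the multiply, add, and reduction instead of running three separate ops.
    return tf.reduce_sum(x * 2.0 + y)


x = tf.ones([1000])
y = tf.ones([1000])
# The first call traces and compiles; later calls with the same input signature reuse it.
result = scaled_sum(x, y)
```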
|
### JAX
|
JAX traces pure (side-effect-free) functions and compiles them with XLA:

- **`@jax.jit`** - JIT-compiles functions using XLA, with automatic operation fusion for improved performance
|
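A minimal `jax.jit` sketch, assuming JAX is installed; `normalize` is a hypothetical function.

```python
import jax
import jax.numpy as jnp


@jax.jit
def normalize(x):
    # A pure function of its input: XLA can fuse the mean, subtraction, std, and division.
    return (x - x.mean()) / x.std()


x = jnp.arange(10.0)
y = normalize(x)  # first call traces and compiles for this shape and dtype
y = jax.block_until_ready(y)  # dispatch is asynchronous; wait for the result
```

Calling `normalize` with a different shape or dtype triggers a new trace and compilation. Python side effects such as `print` inside a jitted function run only while tracing, which is why JAX expects pure functions.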
|
## How Codeflash Optimizes with JIT
|
When Codeflash identifies a function that could benefit from JIT compilation, it:
|
1. **Rewrites the code** in a JIT-compatible format, which may involve breaking complex functions into separate JIT-compiled components
2. **Generates appropriate tests** that are compatible with JIT-compiled code, handling data types carefully because JIT compilers have stricter type requirements
3. **Adds GPU synchronization calls** for accurate profiling when code runs on a GPU, because GPU operations are launched asynchronously
|
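As a hypothetical illustration of the split described in step 1 (not actual Codeflash output), the sketch below keeps dictionary handling in a plain Python wrapper and moves the numeric loop into a Numba-compiled kernel. `max_drawdown` and `_max_drawdown_kernel` are made-up names; the sketch assumes NumPy, and uses Numba if it is installed. Converting the input to an explicit `float64` array also mirrors step 2: the compiled kernel always receives one consistent type.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch runs without Numba
    def njit(func):
        return func


@njit
def _max_drawdown_kernel(prices):
    # Numeric core: a loop over a float64 array, which Numba compiles in nopython mode.
    peak = prices[0]
    worst = 0.0
    for i in range(1, prices.shape[0]):
        if prices[i] > peak:
            peak = prices[i]
        drawdown = (peak - prices[i]) / peak
        if drawdown > worst:
            worst = drawdown
    return worst


def max_drawdown(records):
    # Python wrapper: handles plain Python dicts, which nopython mode does not accept,
    # and hands the kernel a single, consistent input type (a 1-D float64 array).
    prices = np.array([r["price"] for r in records], dtype=np.float64)
    if prices.size == 0:
        return 0.0
    return _max_drawdown_kernel(prices)


history = [{"price": 100.0}, {"price": 120.0}, {"price": 90.0}, {"price": 130.0}]
print(max_drawdown(history))
```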
|
## Accurate Benchmarking with GPU Code
|
Since GPU operations execute asynchronously, Codeflash automatically inserts synchronization barriers before measuring performance. This ensures timing measurements reflect actual computation time rather than just the time to queue operations:
|
- **PyTorch**: Uses `torch.cuda.synchronize()` or `torch.mps.synchronize()`, depending on the device
- **JAX**: Uses `jax.block_until_ready()` to wait for computation to complete
- **TensorFlow**: Uses `tf.test.experimental.sync_devices()` for device synchronization
|
|
## When JIT Compilation Helps
|
JIT compilation is most effective for:
|
- Numerical computations with loops that can't be easily vectorized
- Custom algorithms not covered by existing optimized libraries
- Functions that are called repeatedly with consistent input types
- Code that benefits from hardware-specific optimizations (SIMD, GPU acceleration)
|
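For example, an exponential moving average has a loop-carried dependency: each output depends on the previous one, so it does not reduce to a single NumPy array expression. A minimal sketch, assuming NumPy and (optionally) Numba; `exponential_moving_average` is a hypothetical function.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch runs without Numba
    def njit(func):
        return func


@njit
def exponential_moving_average(values, alpha):
    # Assumes a non-empty 1-D float array. Each step reads the previous output,
    # which is the loop-carried dependency that blocks simple vectorization.
    out = np.empty_like(values)
    out[0] = values[0]
    for i in range(1, values.shape[0]):
        out[i] = alpha * values[i] + (1.0 - alpha) * out[i - 1]
    return out


smoothed = exponential_moving_average(np.array([1.0, 3.0, 5.0]), 0.5)
```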
|
## When JIT Compilation May Not Help
|
JIT compilation may not provide speedups when:
|
- The code already uses highly optimized libraries (e.g., NumPy with MKL, cuBLAS, cuDNN)
- Functions receive varying input types or shapes, which can force repeated recompilation
- The compilation overhead exceeds the runtime savings, such as for functions that run briefly and are called only a few times
- The code relies heavily on Python objects or dynamic features that JIT compilers can't optimize
|
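To check whether compilation overhead pays off, time the first call (which includes compilation) separately from later calls. A minimal sketch, assuming NumPy and (optionally) Numba; `count_above` and `time_call` are hypothetical names. With Numba installed, the first timing typically dominates for a kernel this small, so a function called only a few times may never recover the cost.

```python
import time

import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch runs without Numba
    def njit(func):
        return func


@njit
def count_above(values, threshold):
    # A tiny loop kernel: its runtime is small compared with compilation time.
    count = 0
    for v in values:
        if v > threshold:
            count += 1
    return count


def time_call(fn, *args):
    # Wall-clock time of a single call, in seconds.
    start = time.perf_counter()
    fn(*args)
    return time.perf_counter() - start


values = np.arange(1000.0)
first = time_call(count_above, values, 499.5)   # with Numba, includes compilation
second = time_call(count_above, values, 499.5)  # reuses the compiled machine code
print(f"first call: {first:.6f}s, second call: {second:.6f}s")
```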
|
## Configuration
|
JIT compilation support is **enabled automatically** in Codeflash. You don't need to modify any configuration to enable JIT-based optimizations. Codeflash will automatically detect when JIT compilation could improve performance and suggest appropriate optimizations.
|
When running tests with coverage measurement, Codeflash temporarily disables JIT compilation to ensure accurate coverage data, then re-enables it for performance benchmarking.