todo write more about the flags in torch/tensorflow and jax

This commit is contained in:
aseembits93 2026-01-22 17:20:52 -08:00
parent 454f20d6fc
commit 9fe6ef797a

View file

@ -8,49 +8,58 @@ keywords: ["JIT", "just-in-time", "numba", "pytorch", "tensorflow", "jax", "GPU"
# Support for Just-in-Time Compilation
Codeflash supports optimizing code using Just-in-Time (JIT) compilation. This allows Codeflash to suggest optimizations that leverage JIT compilers from popular frameworks including **Numba**, **PyTorch**, **TensorFlow**, and **JAX**.
Codeflash supports optimizing numerical code using Just-in-Time (JIT) compilation via leveraging JIT compilers from popular frameworks including **Numba**, **PyTorch**, **TensorFlow**, and **JAX**.
## Supported JIT Frameworks
Each framework uses different compilation strategies to accelerate Python code:
### Numba
### Numba (CPU Code)
Numba compiles Python functions to optimized machine code using the LLVM compiler infrastructure. Codeflash can suggest Numba optimizations that use:
- **`@jit` / `@njit`** - General-purpose JIT compilation with `nopython` mode for removing Python interpreter overhead
- **`parallel=True`** - Enables automatic SIMD parallelization
- **`fastmath=True`** - Uses aggressive floating-point optimizations via LLVM's fastmath flag
- **`@vectorize` / `@guvectorize`** - Creates NumPy universal functions (ufuncs)
- **`@cuda.jit`** - Compiles functions to run on NVIDIA GPUs
- **`@jit`** - General-purpose JIT compilation with optional flags.
- **`noython=True`** - Compiles to machine code without falling back to the python interpreter.
- **`parallel=True`** - Enables automatic thread-level parallelization of the function across multiple CPU cores (no GIL!).
- **`fastmath=True`** - Uses aggressive floating-point optimizations via LLVM's fastmath flag
- **`cache=True`** - Numba writes the result of function compilation to disk which significantly reduces future compilation times.
### PyTorch
PyTorch provides multiple compilation approaches:
- **`torch.compile()`** - The recommended compilation API that uses TorchDynamo to trace operations and create optimized CUDA graphs
- **`torch.jit.script`** - Compiles functions using TorchScript
- **`torch.jit.trace`** - Traces tensor operations to create optimized execution graphs
- **`noython=True`** - Compiles to machine code without falling back to the python interpreter.
- **`parallel=True`** - Enables automatic thread-level parallelization of the function across multiple CPU cores (no GIL!).
- **`fastmath=True`** - Uses aggressive floating-point optimizations via LLVM's fastmath flag
- **`cache=True`** - Numba writes the result of function compilation to disk which significantly reduces future compilation times.
### TensorFlow
TensorFlow uses the XLA (Accelerated Linear Algebra) backend for JIT compilation:
- **`@tf.function`** - Compiles Python functions into optimized TensorFlow graphs using XLA
- **`noython=True`** - Compiles to machine code without falling back to the python interpreter.
- **`parallel=True`** - Enables automatic thread-level parallelization of the function across multiple CPU cores (no GIL!).
- **`fastmath=True`** - Uses aggressive floating-point optimizations via LLVM's fastmath flag
- **`cache=True`** - Numba writes the result of function compilation to disk which significantly reduces future compilation times.
### JAX
JAX captures side-effect-free operations and optimizes them:
- **`@jax.jit`** - JIT compiles functions using XLA, with automatic operation fusion for improved performance
- **`noython=True`** - Compiles to machine code without falling back to the python interpreter.
- **`parallel=True`** - Enables automatic thread-level parallelization of the function across multiple CPU cores (no GIL!).
- **`fastmath=True`** - Uses aggressive floating-point optimizations via LLVM's fastmath flag
- **`cache=True`** - Numba writes the result of function compilation to disk which significantly reduces future compilation times.
## How Codeflash Optimizes with JIT
When Codeflash identifies a function that could benefit from JIT compilation, it:
1. **Rewrites the code** in a JIT-compatible format, which may involve breaking down complex functions into separate JIT-compiled components
2. **Generates appropriate tests** that are compatible with JIT-compiled code, carefully handling data types since JIT compilers have stricter type requirements
3. **Adds GPU synchronization calls** for accurate profiling when code runs on GPU, since GPU operations are inherently non-blocking
1. **Rewrites the code** in a JIT-compatible format, which may involve breaking down complex functions into separate JIT-compiled components.
2. **Generates appropriate tests** that are compatible with JIT-compiled code, carefully handling data types since JIT compilers have stricter input type requirements.
## Accurate Benchmarking with GPU Code