From 9fe6ef797aa8fd8d33a72530027542eda8fff35c Mon Sep 17 00:00:00 2001
From: aseembits93
Date: Thu, 22 Jan 2026 17:20:52 -0800
Subject: [PATCH] docs: describe the JIT flags for Numba, PyTorch, TensorFlow,
 and JAX

---
 docs/support-for-jit/index.mdx | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 54 insertions(+), 12 deletions(-)

diff --git a/docs/support-for-jit/index.mdx b/docs/support-for-jit/index.mdx
index 3dfc68917..9aa91a6a4 100644
--- a/docs/support-for-jit/index.mdx
+++ b/docs/support-for-jit/index.mdx
@@ -8,49 +8,91 @@ keywords: ["JIT", "just-in-time", "numba", "pytorch", "tensorflow", "jax", "GPU"
 
 # Support for Just-in-Time Compilation
 
-Codeflash supports optimizing code using Just-in-Time (JIT) compilation. This allows Codeflash to suggest optimizations that leverage JIT compilers from popular frameworks including **Numba**, **PyTorch**, **TensorFlow**, and **JAX**.
+Codeflash can optimize numerical code with Just-in-Time (JIT) compilation by suggesting rewrites that leverage the JIT compilers of popular frameworks, including **Numba**, **PyTorch**, **TensorFlow**, and **JAX**.
 
 ## Supported JIT Frameworks
 
 Each framework uses different compilation strategies to accelerate Python code:
 
-### Numba
+### Numba (CPU Code)
 
 Numba compiles Python functions to optimized machine code using the LLVM compiler infrastructure. Codeflash can suggest Numba optimizations that use:
 
-- **`@jit` / `@njit`** - General-purpose JIT compilation with `nopython` mode for removing Python interpreter overhead
-- **`parallel=True`** - Enables automatic SIMD parallelization
-- **`fastmath=True`** - Uses aggressive floating-point optimizations via LLVM's fastmath flag
-- **`@vectorize` / `@guvectorize`** - Creates NumPy universal functions (ufuncs)
-- **`@cuda.jit`** - Compiles functions to run on NVIDIA GPUs
+- **`@jit`** - General-purpose JIT compilation, tuned through optional flags (see the example below this list):
+  - **`nopython=True`** - Compiles the function entirely to machine code, raising an error instead of falling back to the slower, interpreter-backed object mode.
+  - **`parallel=True`** - Enables automatic thread-level parallelization of supported operations across multiple CPU cores, outside the GIL.
+  - **`fastmath=True`** - Allows aggressive floating-point optimizations via LLVM's fastmath flags.
+  - **`cache=True`** - Writes compiled machine code to disk so later runs skip recompilation.
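+
+For illustration, a Numba-accelerated version of a simple row-averaging function might combine these flags as follows (the `row_means` function below is a hypothetical example, not Codeflash output):
+
+```python
+import numpy as np
+from numba import njit, prange
+
+# @njit is shorthand for @jit(nopython=True); the remaining flags are the
+# optional ones described above.
+@njit(parallel=True, fastmath=True, cache=True)
+def row_means(matrix):
+    out = np.empty(matrix.shape[0])
+    # prange marks this loop as safe to split across CPU threads (parallel=True)
+    for i in prange(matrix.shape[0]):
+        out[i] = matrix[i, :].mean()
+    return out
+
+row_means(np.random.rand(1000, 1000))  # the first call triggers compilation
+```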
 
 ### PyTorch
 
 PyTorch provides multiple compilation approaches:
 
 - **`torch.compile()`** - The recommended compilation API that uses TorchDynamo to trace operations and create optimized CUDA graphs
-- **`torch.jit.script`** - Compiles functions using TorchScript
-- **`torch.jit.trace`** - Traces tensor operations to create optimized execution graphs
+  - **`mode="reduce-overhead"` / `mode="max-autotune"`** - Selects the optimization profile; `"reduce-overhead"` targets kernel-launch overhead with CUDA Graphs, while `"max-autotune"` spends longer compiling in search of faster kernels.
+  - **`fullgraph=True`** - Requires the whole function to be captured as a single graph, turning graph breaks into errors.
+  - **`dynamic=True`** - Compiles kernels that tolerate varying tensor shapes, reducing recompilation when input sizes change.
+  - **`backend="inductor"`** - Selects the compiler backend; TorchInductor is the default.
 
 ### TensorFlow
 
 TensorFlow uses the XLA (Accelerated Linear Algebra) backend for JIT compilation:
 
 - **`@tf.function`** - Compiles Python functions into optimized TensorFlow graphs using XLA
+  - **`jit_compile=True`** - Compiles the traced graph with XLA, fusing operations into device-specific kernels.
+  - **`input_signature=[...]`** - Fixes the expected argument shapes and dtypes with `tf.TensorSpec`s so the function is traced once rather than once per new input signature.
+  - **`reduce_retracing=True`** - Relaxes traced shapes to avoid expensive retracing when inputs vary between calls.
 
 ### JAX
 
 JAX captures side-effect-free operations and optimizes them:
 
 - **`@jax.jit`** - JIT compiles functions using XLA, with automatic operation fusion for improved performance
+  - **`static_argnums` / `static_argnames`** - Treats the given arguments as compile-time constants; passing a new value triggers recompilation.
+  - **`donate_argnums`** - Donates input buffers to the computation so XLA can reuse their memory for the outputs.
 
 ## How Codeflash Optimizes with JIT
 
 When Codeflash identifies a function that could benefit from JIT compilation, it:
 
-1. **Rewrites the code** in a JIT-compatible format, which may involve breaking down complex functions into separate JIT-compiled components
-2. **Generates appropriate tests** that are compatible with JIT-compiled code, carefully handling data types since JIT compilers have stricter type requirements
-3. **Adds GPU synchronization calls** for accurate profiling when code runs on GPU, since GPU operations are inherently non-blocking
+1. **Rewrites the code** in a JIT-compatible format, which may involve breaking down complex functions into separate JIT-compiled components.
+2. **Generates appropriate tests** that are compatible with JIT-compiled code, carefully handling data types since JIT compilers have stricter input type requirements.
 
 ## Accurate Benchmarking with GPU Code
 
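+Because GPU kernels launch asynchronously, a timer that only brackets the Python call measures the kernel launch, not the work itself. As a minimal sketch of the synchronization pattern, assuming PyTorch on a CUDA device (the matrix multiply stands in for whatever function is being measured):
+
+```python
+import time
+
+import torch
+
+x = torch.randn(4096, 4096, device="cuda")
+
+torch.cuda.synchronize()  # drain pending GPU work before starting the clock
+start = time.perf_counter()
+y = x @ x  # the launch returns immediately; the GPU runs the kernel asynchronously
+torch.cuda.synchronize()  # block until the kernel finishes so the timing is real
+elapsed = time.perf_counter() - start
+print(f"matmul took {elapsed * 1000:.2f} ms")
+```
+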