# Experimental Features

JAX experimental features provide access to cutting-edge capabilities, performance optimizations, and research functionality through `jax.experimental`. These features may change or be moved to the main JAX API in future versions.

**Warning**: Experimental APIs may change without notice between JAX versions. Use with caution in production code.

## Core Imports

```python
import jax.experimental as jex
from jax.experimental import io_callback, enable_x64
```

## Capabilities

### Precision Control

Control floating-point precision for a scoped region of JAX computations.

```python { .api }
def enable_x64(new_val: bool = True):
    """
    Context manager that temporarily enables or disables 64-bit precision.

    Args:
        new_val: Whether to enable 64-bit precision (default: True)

    Note:
        Overrides the jax_enable_x64 config flag only inside the `with` block.
        For a global setting, use jax.config.update("jax_enable_x64", True).
    """

def disable_x64():
    """
    Context manager that temporarily disables 64-bit precision.

    Convenience function equivalent to enable_x64(False).
    """
```

Usage examples:

```python
import jax.numpy as jnp
from jax.experimental import enable_x64, disable_x64

# Enable double precision inside the block
with enable_x64():
    x = jnp.array(1.0)  # Defaults to float64 instead of float32
    print(x.dtype)  # float64

# Disable double precision inside the block
with disable_x64():
    y = jnp.array(1.0)  # float32
    print(y.dtype)  # float32
```

### I/O and Callbacks

Enable host callbacks for I/O operations and side effects within JAX computations.

```python { .api }
def io_callback(
    callback: Callable,
    result_shape_dtypes,
    *args,
    sharding=None,
    vmap_method=None,
    ordered=False,
    **kwargs
) -> Any:
    """
    Call host function from within JAX computation with I/O side effects.

    Args:
        callback: Host function to call (should be pure except for I/O)
        result_shape_dtypes: Shape and dtype specification for callback result
        args: Arguments to pass to callback
        sharding: Sharding specification for result
        vmap_method: How to handle vmapping ('sequential', 'expand_dims', etc.)
        ordered: Whether to maintain call ordering across devices
        kwargs: Additional keyword arguments for callback

    Returns:
        Result of callback with specified shape and dtype
    """
```

Usage examples:

```python
import pickle

import jax
import jax.numpy as jnp
from jax.experimental import io_callback

# Logging during computation (debugging)
def log_value(x, step):
    print(f"Step {step}: value = {x}")
    return x

@jax.jit
def training_step(x, step):
    # Log intermediate values during training
    x = io_callback(
        log_value,
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        x, step
    )
    return x * 2

# File I/O during computation
def save_checkpoint(params, step):
    with open(f'checkpoint_{step}.pkl', 'wb') as f:
        pickle.dump(params, f)
    return step

# compute_loss and update_params are user-defined helpers (not shown)
@jax.jit
def train_with_checkpointing(params, data, step):
    # Training computation
    loss = compute_loss(params, data)
    grads = jax.grad(compute_loss)(params, data)
    new_params = update_params(params, grads)

    # Save a checkpoint through the host callback (runs on every call)
    step = io_callback(
        save_checkpoint,
        jax.ShapeDtypeStruct((), jnp.int32),
        new_params, step
    )

    return new_params, loss
```

### Advanced Differentiation

Experimental differentiation features and optimizations.

```python { .api }
def saved_input_vjp(f, *primals) -> tuple[Any, Callable]:
    """
    Vector-Jacobian product with saved inputs for memory efficiency.

    Args:
        f: Function to differentiate
        primals: Input values

    Returns:
        Tuple of (primal_out, vjp_fun) where vjp_fun has access to saved inputs
    """

# Alias for saved_input_vjp
si_vjp = saved_input_vjp
```

Usage example:

```python
import jax
import jax.experimental
import jax.numpy as jnp

def expensive_function(x, y):
    # Some expensive computation that we want to differentiate
    z = jnp.exp(x) + jnp.sin(y)
    return jnp.sum(z ** 2)

# Use saved input VJP for memory efficiency
x, y = jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])
primal_out, vjp_fn = jax.experimental.saved_input_vjp(expensive_function, x, y)

# Compute VJP with cotangent
cotangent = 1.0
x_grad, y_grad = vjp_fn(cotangent)
```

### Extended Array Types

Experimental array types and extended functionality.
```python { .api }
class EArray:
    """
    Extended array type with additional metadata and functionality.

    Experimental array type that may include additional features
    beyond standard JAX arrays.
    """
    pass

class MutableArray:
    """
    Experimental mutable array type for specific use cases.

    Warning: Breaks JAX's functional programming model. Use carefully.
    """
    pass

def mutable_array(init_val) -> MutableArray:
    """
    Create mutable array from initial value.

    Args:
        init_val: Initial array value

    Returns:
        MutableArray that can be modified in-place
    """
```

### Type System Extensions

Experimental extensions to JAX's type system.

```python { .api }
def primal_tangent_dtype(primal_dtype, tangent_dtype=None):
    """
    Create dtype for primal-tangent pairs in forward-mode AD.

    Args:
        primal_dtype: Data type for primal values
        tangent_dtype: Data type for tangent values (defaults to primal_dtype)

    Returns:
        Combined dtype for primal-tangent computation
    """
```

### Compilation and Performance

Experimental compilation features and performance optimizations.

```python { .api }
# Compilation control
def disable_jit_cache() -> None:
    """Disable JIT compilation cache for debugging."""

def enable_jit_cache() -> None:
    """Re-enable JIT compilation cache."""

# Performance monitoring
def compilation_cache_stats() -> dict:
    """Get statistics about JIT compilation cache."""

def clear_compilation_cache() -> None:
    """Clear JIT compilation cache."""
```

### Hardware-Specific Features

Experimental features for specific hardware accelerators.

```python { .api }
# TPU-specific features
class TPUMemoryFraction:
    """Control TPU memory usage fraction."""

def set_tpu_memory_fraction(fraction: float) -> None:
    """
    Set fraction of TPU memory to use.

    Args:
        fraction: Memory fraction (0.0 to 1.0)
    """

# GPU-specific features
def gpu_memory_stats() -> dict:
    """Get GPU memory usage statistics."""

def set_gpu_memory_growth(enable: bool) -> None:
    """
    Enable/disable GPU memory growth.
    Args:
        enable: Whether to enable incremental memory allocation
    """
```

### Automatic Mixed Precision

Experimental automatic mixed precision for training acceleration.

```python { .api }
class AutoMixedPrecision:
    """Automatic mixed precision policy for training."""

    def __init__(self, policy='float16'):
        """
        Initialize AMP policy.

        Args:
            policy: Precision policy ('float16', 'bfloat16', etc.)
        """
        self.policy = policy

    def __call__(self, fn):
        """Apply AMP to function."""
        pass

def amp_policy(policy_name: str) -> AutoMixedPrecision:
    """
    Create automatic mixed precision policy.

    Args:
        policy_name: Name of precision policy

    Returns:
        AMP policy object
    """
```

### Distributed Computing Extensions

Experimental distributed computing features beyond standard pmap/shard_map.

```python { .api }
def multi_host_utils():
    """Utilities for multi-host distributed computation."""
    pass

class GlobalDeviceArray:
    """
    Experimental global device array for large-scale distributed computation.

    Represents arrays that span multiple hosts in distributed setting.
    """
    pass

def create_global_device_array(
    shape, dtype, mesh, partition_spec
) -> GlobalDeviceArray:
    """
    Create global device array across distributed system.

    Args:
        shape: Global array shape
        dtype: Array data type
        mesh: Device mesh specification
        partition_spec: How to partition array

    Returns:
        Global device array
    """
```

### Research and Prototype Features

Cutting-edge research features that may be highly experimental.
```python { .api }
# Sparsity support
class SparseArray:
    """Experimental sparse array support."""
    pass

def sparse_ops():
    """Sparse operations module (highly experimental)."""
    pass

# Quantization support
def quantized_dot(lhs, rhs, **kwargs):
    """Experimental quantized matrix multiplication."""
    pass

def quantization_utils():
    """Utilities for quantized computation."""
    pass

# Custom operators
def custom_op_builder():
    """Builder for custom XLA operations."""
    pass

# Advanced compilation
def ahead_of_time_compile(fn, *args, **kwargs):
    """Ahead-of-time compilation (experimental)."""
    pass
```

### Debugging and Profiling

Experimental debugging and profiling tools.

```python { .api }
def debug_callback(callback, *args, **kwargs):
    """
    Debug callback that doesn't affect computation graph.

    Args:
        callback: Debug function to call
        args: Arguments to callback
        kwargs: Keyword arguments to callback
    """

def trace_function(fn):
    """
    Trace function execution for debugging.

    Args:
        fn: Function to trace

    Returns:
        Traced version of function
    """

def memory_profiler():
    """Memory profiling utilities."""
    pass

def computation_graph_visualizer():
    """Tools for visualizing computation graphs."""
    pass
```

## Migration Patterns

When an experimental feature graduates to the main JAX API, update its import path:

```python
# Old experimental usage
from jax.experimental import feature_name

# New main API usage (after graduation)
from jax import feature_name

# Or sometimes moves to different module
from jax.some_module import feature_name
```

## Usage Guidelines

### Best Practices for Experimental Features

```python
# 1. Version pinning when using experimental features
# requirements.txt:
jax==0.7.1  # Pin exact version

# 2. Graceful fallbacks
try:
    from jax.experimental import new_feature
    use_experimental = True
except ImportError:
    use_experimental = False

def my_function(x):
    if use_experimental:
        return new_feature.optimized_op(x)
    else:
        return traditional_op(x)

# 3. Feature flags for experimental code
USE_EXPERIMENTAL_AMP = False

if USE_EXPERIMENTAL_AMP:
    amp_policy = jax.experimental.amp_policy('float16')
    train_fn = amp_policy(train_fn)

# 4. Documentation and warnings
def experimental_model_fn(x):
    """
    Model function using experimental JAX features.

    Warning: Uses jax.experimental.* APIs that may change.
    Tested with JAX v0.7.1.
    """
    # Implementation using experimental features
    pass
```

### Testing Experimental Features

```python
import jax
import jax.experimental
import pytest

# Skip tests if experimental feature not available
@pytest.mark.skipif(
    not hasattr(jax.experimental, 'new_feature'),
    reason="Experimental feature not available"
)
def test_experimental_feature():
    # Test experimental functionality
    pass

# Conditional testing based on JAX version
jax_version = tuple(map(int, jax.__version__.split('.')[:2]))

@pytest.mark.skipif(
    jax_version < (0, 7),
    reason="Feature requires JAX >= 0.7"
)
def test_version_dependent_feature():
    # Test version-dependent experimental feature
    pass
```
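
The graceful-fallback and version-gating patterns above can be wrapped in two small helpers. The following is a minimal sketch using only the standard library; the names `optional_attr` and `version_at_least` are hypothetical helpers for illustration and are not part of JAX.

```python
import importlib
import re


def optional_attr(module_name: str, attr_name: str, default=None):
    """Return getattr(module, attr_name), or `default` if either is missing."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Module is not installed or cannot be found
        return default
    return getattr(module, attr_name, default)


def version_at_least(version: str, minimum: tuple) -> bool:
    """Compare the leading numeric components of `version` against `minimum`."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # Stop at the first non-numeric component, e.g. "dev20250101"
        parts.append(int(match.group()))
    # Compare only as many components as `minimum` specifies
    return tuple(parts[: len(minimum)]) >= tuple(minimum)
```

Under these assumptions, `optional_attr("jax.experimental", "new_feature")` would return `None` when the placeholder symbol is absent, and `version_at_least(jax.__version__, (0, 7))` could replace the inline tuple parsing in a `skipif` condition while tolerating suffixes such as `rc1` or `.dev…`.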