# Device and Memory Management
JAX provides comprehensive device management and distributed computing capabilities, enabling efficient use of CPUs, GPUs, and TPUs. This includes device placement, memory management, sharding for multi-device computation, and distributed array operations.
Core Imports
import jax
from jax import devices, device_put, make_mesh
from jax.sharding import NamedSharding, PartitionSpec as P
Capabilities
Device Discovery and Information
Query available devices and their properties for computation placement and resource management.
def devices(backend=None) -> list[Device]:
"""
Get list of all available devices.
Args:
backend: Optional backend name ('cpu', 'gpu', 'tpu')
Returns:
List of available Device objects
"""
def local_devices(process_index=None, backend=None) -> list[Device]:
"""
Get list of devices local to current process.
Args:
process_index: Process index (None for current process)
backend: Optional backend name
Returns:
List of local Device objects
"""
def device_count(backend=None) -> int:
"""
Get total number of devices across all processes.
Args:
backend: Optional backend name
Returns:
Total device count
"""
def local_device_count(backend=None) -> int:
"""
Get number of devices on current process.
Args:
backend: Optional backend name
Returns:
Local device count
"""
def host_count(backend=None) -> int:
"""
Get number of hosts in distributed computation.
Deprecated alias for process_count(); prefer the process_* APIs.
Args:
backend: Optional backend name
Returns:
Host count
"""
def host_id(backend=None) -> int:
"""
Get ID of current host.
Deprecated alias for process_index(); prefer the process_* APIs.
Args:
backend: Optional backend name
Returns:
Current host ID
"""
def host_ids(backend=None) -> list[int]:
"""
Get list of all host IDs.
Deprecated alias for process_indices(); prefer the process_* APIs.
Args:
backend: Optional backend name
Returns:
List of host IDs
"""
def process_count(backend=None) -> int:
"""
Get number of processes in distributed computation.
Args:
backend: Optional backend name
Returns:
Process count
"""
def process_index(backend=None) -> int:
"""
Get index of current process.
Args:
backend: Optional backend name
Returns:
Current process index
"""
def process_indices(backend=None) -> list[int]:
"""
Get list of all process indices.
Args:
backend: Optional backend name
Returns:
List of process indices
"""
def default_backend() -> str:
"""
Get name of default backend.
Returns:
Default backend name string
"""
Device Placement and Data Movement
Control where computations run and move data between devices and host memory.
def device_put(x, device=None, src=None) -> Array:
"""
Move array to specified device.
Args:
x: Array or array-like object to move
device: Target device (None for default device)
src: Source device for the transfer
Returns:
Array placed on target device
"""
def device_put_sharded(
sharded_values: list,
devices: list[Device],
indices=None
) -> Array:
"""
Create sharded array from per-device values.
Args:
sharded_values: List of arrays, one per device
devices: List of target devices
indices: Optional sharding indices
Returns:
Distributed array sharded across devices
"""
def device_put_replicated(x, devices: list[Device]) -> Array:
"""
Replicate array across multiple devices.
Args:
x: Array to replicate
devices: List of target devices
Returns:
Array replicated across all specified devices
"""
def device_get(x) -> Any:
"""
Move array from device to host memory as NumPy array.
Args:
x: Array to move to host
Returns:
NumPy array in host memory
"""
def copy_to_host_async(x) -> Any:
"""
Asynchronously copy array to host memory.
Args:
x: Array to copy
Returns:
Future-like object for async copy
"""
def block_until_ready(x) -> Array:
"""
Block until array computation is complete and ready.
Args:
x: Array to wait for
Returns:
The same array, guaranteed to be ready
"""
Usage examples:
# Check available devices
all_devices = jax.devices()
print(f"Available devices: {all_devices}")
print(f"Device count: {jax.device_count()}")
# Move data to specific device
cpu_data = jnp.array([1, 2, 3, 4])
try:
    gpu_data = jax.device_put(cpu_data, jax.devices('gpu')[0])
    print(f"Data is on: {gpu_data.device()}")
except RuntimeError:  # jax.devices('gpu') raises if no GPU backend is available
    gpu_data = cpu_data
# Move back to host
host_data = jax.device_get(gpu_data) # Returns NumPy array
# Explicit device placement in computations
with jax.default_device(jax.devices('cpu')[0]):
cpu_result = jnp.sum(jnp.array([1, 2, 3]))
Sharding and Distributed Arrays
Define how arrays are distributed across multiple devices for parallel computation.
class NamedSharding:
"""
Sharding specification using named mesh axes.
Defines how arrays are partitioned across devices using logical axis names.
"""
def __init__(self, mesh, spec):
"""
Create named sharding specification.
Args:
mesh: Device mesh with named axes
spec: Partition specification (PartitionSpec)
"""
self.mesh = mesh
self.spec = spec
class PartitionSpec:
"""
Specification for how to partition array dimensions across mesh axes.
Use P(axis_names...) to create partition specifications.
"""
pass
# Alias for PartitionSpec
P = PartitionSpec
def make_mesh(mesh_shape, axis_names) -> Mesh:
"""
Create device mesh for distributed computation.
Args:
mesh_shape: Shape of device mesh (tuple of integers)
axis_names: Names for mesh axes (tuple of strings)
Returns:
Mesh object representing device layout
"""
class Mesh:
"""Device mesh for distributed computation."""
devices: Array # Device array in mesh shape
axis_names: tuple[str, ...] # Names of mesh axes
@property
def shape(self) -> dict[str, int]:
"""Dictionary mapping axis names to sizes."""
@property
def size(self) -> int:
"""Total number of devices in mesh."""
def make_array_from_single_device_arrays(
arrays: list[Array],
sharding: Sharding
) -> Array:
"""
Create distributed array from per-device arrays.
Args:
arrays: List of arrays on different devices
sharding: Sharding specification
Returns:
Distributed array with specified sharding
"""
def make_array_from_callback(
shape: tuple[int, ...],
sharding: Sharding,
data_callback: Callable
) -> Array:
"""
Create distributed array using callback function.
Args:
shape: Global array shape
sharding: Sharding specification
data_callback: Function to generate data for each shard
Returns:
Distributed array created from callback
"""
def make_array_from_process_local_data(
sharding: Sharding,
local_data: Array
) -> Array:
"""
Create distributed array from process-local data.
Args:
sharding: Sharding specification
local_data: Data local to current process
Returns:
Distributed array assembled from local data
"""
Sharded Computation
Execute computations on sharded arrays with explicit control over parallelization.
def shard_map(
f: Callable,
mesh: Mesh,
in_specs,
out_specs,
check_rep=True
) -> Callable:
"""
Transform function to operate on sharded arrays.
Args:
f: Function to transform
mesh: Device mesh for computation
in_specs: Input sharding specifications
out_specs: Output sharding specifications
check_rep: Whether to check for replication consistency
Returns:
Function that operates on globally sharded arrays
"""
# Alias for shard_map
smap = shard_map
def with_sharding_constraint(x, sharding) -> Array:
"""
Add sharding constraint to array.
Args:
x: Input array
sharding: Desired sharding specification
Returns:
Array with sharding constraint applied
"""
Usage examples:
# Create 2x2 device mesh (requires at least 4 devices)
devices_array = np.array(jax.devices()[:4]).reshape(2, 2)  # Device objects need NumPy, not jnp
mesh = jax.make_mesh((2, 2), ('data', 'model'))
# Define sharding specifications
data_sharding = NamedSharding(mesh, P('data', None)) # Shard first axis across 'data'
model_sharding = NamedSharding(mesh, P(None, 'model')) # Shard second axis across 'model'
replicated_sharding = NamedSharding(mesh, P()) # Replicated across all devices
# Create sharded arrays
x = jax.random.normal(jax.random.key(0), (8, 4))
x_sharded = jax.device_put(x, data_sharding)
weights = jax.random.normal(jax.random.key(1), (4, 8))
weights_sharded = jax.device_put(weights, model_sharding)
# Computation with sharded arrays automatically parallelized
@jax.jit
def matmul_fn(x, w):
return x @ w
result = matmul_fn(x_sharded, weights_sharded) # Automatically sharded computation
# Explicit sharding control
def single_device_fn(x_shard, w_shard):
return x_shard @ w_shard
parallel_fn = jax.shard_map(
single_device_fn,
mesh=mesh,
in_specs=(P('data', None), P(None, 'model')),
out_specs=P('data', 'model')
)
result = parallel_fn(x_sharded, weights_sharded)
Memory Management
Control memory usage and optimize performance through explicit memory management.
def live_arrays() -> list[Array]:
"""
Get list of arrays currently alive in memory.
Returns:
List of live Array objects
"""
def clear_caches() -> None:
"""
Clear JAX's internal caches.
Clears the JIT compilation and staging caches. Note: this does not free
live device arrays; drop references to arrays to release their buffers.
"""
Configuration and Backend Management
Configure device behavior and backend selection.
# Configuration through jax.config (set before any JAX computation runs)
jax.config.update('jax_platform_name', 'cpu')  # Force CPU backend
jax.config.update('jax_platform_name', 'gpu')  # Force GPU backend
jax.config.update('jax_platform_name', 'tpu')  # Force TPU backend
# Transfer guards to catch unintentional device transfers
jax.config.update('jax_transfer_guard', 'allow') # Default: allow all transfers
jax.config.update('jax_transfer_guard', 'log') # Log transfers
jax.config.update('jax_transfer_guard', 'disallow') # Disallow transfers
jax.config.update('jax_transfer_guard', 'log_explicit')  # Log all transfers, including explicit device_put
# Default device configuration
jax.config.update('jax_default_device', jax.devices('gpu')[0]) # Set default device
Array and Device Properties
Inspect array placement and device properties.
# Array device methods
array.device() -> Device # Get device containing array
array.devices() -> set[Device] # Get all devices for distributed array
array.sharding -> Sharding # Get array's sharding specification
array.is_fully_replicated -> bool # Check if array is replicated
array.is_fully_addressable -> bool # Check if array is fully addressable
# Device properties
class Device:
"""Device object representing compute accelerator."""
platform: str # Platform name ('cpu', 'gpu', 'tpu')
device_kind: str # Device kind string
id: int # Device ID within platform
host_id: int # Host ID containing device
process_index: int # Process index containing device
def __str__(self) -> str: ...
def __repr__(self) -> str: ...
Advanced Usage Patterns
Multi-Device Training
# Setup for data-parallel training
def create_train_setup(num_devices):
# Create mesh for data parallelism
mesh = jax.make_mesh((num_devices,), ('batch',))
# Sharding specifications
batch_sharding = NamedSharding(mesh, P('batch')) # Batch dimension sharded
replicated_sharding = NamedSharding(mesh, P()) # Parameters replicated
return mesh, batch_sharding, replicated_sharding
def distributed_train_step(params, batch, optimizer_state):
# All arrays should already have appropriate sharding
grads = jax.grad(loss_fn)(params, batch)
# Update step automatically uses sharding from inputs
new_params, new_state = optimizer.update(grads, optimizer_state, params)
return new_params, new_state
# JIT compile with sharding
distributed_train_step = jax.jit(
distributed_train_step,
in_shardings=(replicated_sharding, batch_sharding, replicated_sharding),
out_shardings=(replicated_sharding, replicated_sharding)
)
Model Parallelism
# Setup for model-parallel computation
def create_model_parallel_setup():
# 2D mesh: batch x model dimensions
mesh = jax.make_mesh((2, 4), ('batch', 'model'))
# Different sharding strategies
input_sharding = NamedSharding(mesh, P('batch', None))
weight_sharding = NamedSharding(mesh, P(None, 'model'))
output_sharding = NamedSharding(mesh, P('batch', 'model'))
return mesh, input_sharding, weight_sharding, output_sharding
def model_parallel_layer(x, weights):
# Matrix multiply with different sharding patterns
return x @ weights # JAX handles the communication automatically
# Shard arrays according to strategy
x = jax.device_put(x, input_sharding)
weights = jax.device_put(weights, weight_sharding)
result = model_parallel_layer(x, weights) # Result has output_sharding
Memory-Efficient Inference
def memory_efficient_inference(model_fn, large_input):
# Process in chunks to manage memory
chunk_size = 1000
chunks = [large_input[i:i+chunk_size] for i in range(0, len(large_input), chunk_size)]
results = []
for chunk in chunks:
# Move to device, compute, move back to host
device_chunk = jax.device_put(chunk)
device_result = model_fn(device_chunk)
host_result = jax.device_get(device_result)
results.append(host_result)
# Optional: clear caches to free memory
jax.clear_caches()
return np.concatenate(results)  # results are NumPy host arrays; np keeps the concat on host
Cross-Device Communication Patterns
# Collective operations using pmap
# NOTE: the axis name must be declared on pmap itself (e.g. via functools.partial);
# otherwise psum/all_gather cannot resolve axis_name='batch'.
from functools import partial
@partial(jax.pmap, axis_name='batch')
def allreduce_example(x):
    # Sum across all devices
    return jax.lax.psum(x, axis_name='batch')
@partial(jax.pmap, axis_name='batch')
def allgather_example(x):
    # Gather from all devices
    return jax.lax.all_gather(x, axis_name='batch')
# Use with replicated data
replicated_data = jax.device_put_replicated(data, jax.devices())
summed_result = allreduce_example(replicated_data)
gathered_result = allgather_example(replicated_data)