Device and Memory Management

JAX provides comprehensive device management and distributed computing capabilities, enabling efficient use of CPUs, GPUs, and TPUs. This includes device placement, memory management, sharding for multi-device computation, and distributed array operations.

Core Imports

import jax
import jax.numpy as jnp
from jax import devices, device_put, make_mesh
from jax.sharding import NamedSharding, PartitionSpec as P

Capabilities

Device Discovery and Information

Query available devices and their properties for computation placement and resource management.

def devices(backend=None) -> list[Device]:
    """
    Get list of all available devices.
    
    Args:
        backend: Optional backend name ('cpu', 'gpu', 'tpu')
        
    Returns:
        List of available Device objects
    """

def local_devices(process_index=None, backend=None) -> list[Device]:
    """
    Get list of devices local to current process.
    
    Args:
        process_index: Process index (None for current process)
        backend: Optional backend name
        
    Returns:
        List of local Device objects
    """

def device_count(backend=None) -> int:
    """
    Get total number of devices across all processes.
    
    Args:
        backend: Optional backend name
        
    Returns:
        Total device count
    """

def local_device_count(backend=None) -> int:
    """
    Get number of devices on current process.
    
    Args:
        backend: Optional backend name
        
    Returns:
        Local device count
    """

def host_count(backend=None) -> int:
    """
    Get number of hosts in distributed computation.
    
    Deprecated alias for process_count().
    
    Args:
        backend: Optional backend name
        
    Returns:
        Host count
    """

def host_id(backend=None) -> int:
    """
    Get ID of current host.
    
    Deprecated alias for process_index().
    
    Args:
        backend: Optional backend name
        
    Returns:
        Current host ID
    """

def host_ids(backend=None) -> list[int]:
    """
    Get list of all host IDs.
    
    Deprecated alias for process_indices().
    
    Args:
        backend: Optional backend name
        
    Returns:
        List of host IDs
    """

def process_count(backend=None) -> int:
    """
    Get number of processes in distributed computation.
    
    Args:
        backend: Optional backend name
        
    Returns:
        Process count
    """

def process_index(backend=None) -> int:
    """
    Get index of current process.
    
    Args:
        backend: Optional backend name
        
    Returns:
        Current process index
    """

def process_indices(backend=None) -> list[int]:
    """
    Get list of all process indices.
    
    Args:
        backend: Optional backend name
        
    Returns:
        List of process indices
    """

def default_backend() -> str:
    """
    Get name of default backend.
    
    Returns:
        Default backend name string
    """

Device Placement and Data Movement

Control where computations run and move data between devices and host memory.

def device_put(x, device=None, src=None) -> Array:
    """
    Move array to specified device.
    
    Args:
        x: Array or array-like object to move
        device: Target device (None for default device)
        src: Source device for the transfer
        
    Returns:
        Array placed on target device
    """

def device_put_sharded(
    shards: list,
    devices: list[Device]
) -> Array:
    """
    Create sharded array from per-device values.
    
    Superseded by make_array_from_single_device_arrays in newer JAX.
    
    Args:
        shards: List of arrays, one per device
        devices: List of target devices
        
    Returns:
        Distributed array sharded across devices
    """

def device_put_replicated(x, devices: list[Device]) -> Array:
    """
    Replicate array across multiple devices.
    
    Args:
        x: Array to replicate
        devices: List of target devices
        
    Returns:
        Array replicated across all specified devices
    """

def device_get(x) -> Any:
    """
    Move array from device to host memory as NumPy array.
    
    Args:
        x: Array to move to host
        
    Returns:
        NumPy array in host memory
    """

# Method on jax.Array
def copy_to_host_async(self) -> None:
    """
    Asynchronously begin copying the array to host memory.
    
    Starts the device-to-host transfer without blocking; a later
    device_get or NumPy conversion then completes without waiting.
    Returns None.
    """

def block_until_ready(x) -> Any:
    """
    Block until array computation is complete and ready.
    
    Args:
        x: Array, or pytree of arrays, to wait for
        
    Returns:
        The same array or pytree, guaranteed to be ready
    """

Usage examples:

# Check available devices
all_devices = jax.devices()
print(f"Available devices: {all_devices}")
print(f"Device count: {jax.device_count()}")

# Move data to a specific device
# (jax.devices('gpu') raises RuntimeError when no GPU backend is present)
cpu_data = jnp.array([1, 2, 3, 4])
try:
    gpu_data = jax.device_put(cpu_data, jax.devices('gpu')[0])
    print(f"Data is on: {gpu_data.device}")

    # Move back to host
    host_data = jax.device_get(gpu_data)  # Returns NumPy array
except RuntimeError:
    pass  # No GPU backend available

# Explicit device placement in computations
with jax.default_device(jax.devices('cpu')[0]):
    cpu_result = jnp.sum(jnp.array([1, 2, 3]))
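
Because JAX dispatches work asynchronously, timing code must wait for results explicitly; a minimal sketch using block_until_ready:

import time

x = jax.random.normal(jax.random.key(0), (2048, 2048))

start = time.perf_counter()
y = x @ x                 # Returns immediately; the matmul runs asynchronously
y.block_until_ready()     # Wait for the computation to actually finish
print(f"matmul took {time.perf_counter() - start:.4f}s")

# jax.block_until_ready also accepts whole pytrees of arrays
jax.block_until_ready({'a': y, 'b': x})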

Sharding and Distributed Arrays

Define how arrays are distributed across multiple devices for parallel computation.

class NamedSharding:
    """
    Sharding specification using named mesh axes.
    
    Defines how arrays are partitioned across devices using logical axis names.
    """
    
    def __init__(self, mesh, spec):
        """
        Create named sharding specification.
        
        Args:
            mesh: Device mesh with named axes
            spec: Partition specification (PartitionSpec)
        """
        self.mesh = mesh
        self.spec = spec

class PartitionSpec:
    """
    Specification for how to partition array dimensions across mesh axes.
    
    Use P(axis_names...) to create partition specifications.
    """
    pass

# Alias for PartitionSpec  
P = PartitionSpec

def make_mesh(mesh_shape, axis_names) -> Mesh:
    """
    Create device mesh for distributed computation.
    
    Args:
        mesh_shape: Shape of device mesh (tuple of integers)
        axis_names: Names for mesh axes (tuple of strings)
        
    Returns:
        Mesh object representing device layout
    """

class Mesh:
    """Device mesh for distributed computation."""
    devices: np.ndarray  # NumPy array of Device objects in mesh shape
    axis_names: tuple[str, ...]  # Names of mesh axes
    
    @property
    def shape(self) -> dict[str, int]:
        """Dictionary mapping axis names to sizes."""
        
    @property 
    def size(self) -> int:
        """Total number of devices in mesh."""

def make_array_from_single_device_arrays(
    shape: tuple[int, ...],
    sharding: Sharding,
    arrays: list[Array]
) -> Array:
    """
    Create distributed array from per-device arrays.
    
    Args:
        shape: Global array shape
        sharding: Sharding specification
        arrays: List of arrays on different devices, one per shard
        
    Returns:
        Distributed array with specified sharding
    """

def make_array_from_callback(
    shape: tuple[int, ...],
    sharding: Sharding, 
    data_callback: Callable
) -> Array:
    """
    Create distributed array using callback function.
    
    Args:
        shape: Global array shape
        sharding: Sharding specification  
        data_callback: Function to generate data for each shard
        
    Returns:
        Distributed array created from callback
    """

def make_array_from_process_local_data(
    sharding: Sharding,
    local_data: Array
) -> Array:
    """
    Create distributed array from process-local data.
    
    Args:
        sharding: Sharding specification
        local_data: Data local to current process
        
    Returns:
        Distributed array assembled from local data
    """

Sharded Computation

Execute computations on sharded arrays with explicit control over parallelization.

def shard_map(
    f: Callable,
    mesh: Mesh,
    in_specs,
    out_specs,
    check_rep=True
) -> Callable:
    """
    Transform function to operate on sharded arrays.
    
    Args:
        f: Function to transform
        mesh: Device mesh for computation
        in_specs: Input sharding specifications
        out_specs: Output sharding specifications  
        check_rep: Whether to check for replication consistency
        
    Returns:
        Function that operates on globally sharded arrays
    """

# Alias for shard_map
smap = shard_map

def with_sharding_constraint(x, sharding) -> Array:
    """
    Add sharding constraint to array.
    
    Args:
        x: Input array
        sharding: Desired sharding specification
        
    Returns:
        Array with sharding constraint applied
    """

Usage examples:

# Create 2x2 device mesh (requires at least 4 devices)
mesh = jax.make_mesh((2, 2), ('data', 'model'))

# Define sharding specifications
data_sharding = NamedSharding(mesh, P('data', None))  # Shard first axis across 'data'
model_sharding = NamedSharding(mesh, P(None, 'model'))  # Shard second axis across 'model'
replicated_sharding = NamedSharding(mesh, P())  # Replicated across all devices

# Create sharded arrays
x = jax.random.normal(jax.random.key(0), (8, 4))
x_sharded = jax.device_put(x, data_sharding)

weights = jax.random.normal(jax.random.key(1), (4, 8))
weights_sharded = jax.device_put(weights, model_sharding)

# Computation with sharded arrays automatically parallelized  
@jax.jit
def matmul_fn(x, w):
    return x @ w

result = matmul_fn(x_sharded, weights_sharded)  # Automatically sharded computation

# Explicit sharding control
def single_device_fn(x_shard, w_shard):
    return x_shard @ w_shard

parallel_fn = jax.shard_map(
    single_device_fn,
    mesh=mesh,
    in_specs=(P('data', None), P(None, 'model')),
    out_specs=P('data', 'model')
)

result = parallel_fn(x_sharded, weights_sharded)
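
with_sharding_constraint (exposed as jax.lax.with_sharding_constraint) is most useful inside jit, to pin an intermediate's layout; a sketch reusing the mesh above:

@jax.jit
def constrained_fn(x, w):
    y = x @ w
    # Ask the compiler to lay the intermediate out as ('data', 'model')
    y = jax.lax.with_sharding_constraint(y, NamedSharding(mesh, P('data', 'model')))
    return jax.nn.relu(y)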

Memory Management

Control memory usage and optimize performance through explicit memory management.

def live_arrays() -> list[Array]:
    """
    Get list of arrays currently alive in memory.
    
    Returns:
        List of live Array objects
    """

def clear_caches() -> None:
    """
    Clear JAX's internal caches to free memory.
    
    Clears compilation and staging caches (e.g., jit-compiled programs).
    Live array buffers are unaffected; they are freed when no longer referenced.
    """

Configuration and Backend Management

Configure device behavior and backend selection.

# Configuration through jax.config
# (set the platform before any computation runs; it cannot be changed
# once the backend has been initialized)
jax.config.update('jax_platform_name', 'cpu')  # Force CPU backend
jax.config.update('jax_platform_name', 'gpu')  # Force GPU backend
jax.config.update('jax_platform_name', 'tpu')  # Force TPU backend

# Transfer guards to catch unintentional device transfers
jax.config.update('jax_transfer_guard', 'allow')           # Default: allow all transfers
jax.config.update('jax_transfer_guard', 'log')             # Log implicit transfers
jax.config.update('jax_transfer_guard', 'disallow')        # Disallow implicit transfers
jax.config.update('jax_transfer_guard', 'log_explicitly')  # Also log explicit transfers (e.g. device_put)

# Default device configuration
jax.config.update('jax_default_device', jax.devices('gpu')[0])  # Set default device
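
Transfer guards can also be scoped with the jax.transfer_guard context manager rather than set globally; a sketch:

x = jnp.arange(4)

with jax.transfer_guard('disallow'):
    y = x + 1          # Fine: computation stays on device
    # np.asarray(x)    # Would raise: implicit device-to-host transfer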

Array and Device Properties

Inspect array placement and device properties.

# Array device methods and properties
array.device -> Device  # Device holding a single-device array (property)
array.devices() -> set[Device]  # Get all devices for distributed array
array.sharding -> Sharding  # Get array's sharding specification
array.is_fully_replicated -> bool  # Check if array is replicated
array.is_fully_addressable -> bool  # Check if all shards are addressable from this process

# Device properties
class Device:
    """Device object representing compute accelerator."""
    
    platform: str  # Platform name ('cpu', 'gpu', 'tpu')
    device_kind: str  # Device kind string  
    id: int  # Device ID within platform
    host_id: int  # Host ID containing device (deprecated; use process_index)
    process_index: int  # Process index containing device
    
    def __str__(self) -> str: ...
    def __repr__(self) -> str: ...
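
Putting the properties together (attribute names as listed above; on recent JAX versions .device is a property rather than a method):

x = jnp.zeros((4, 4))
d = jax.devices()[0]
print(d.platform, d.device_kind, d.id, d.process_index)

print(x.devices())     # Set of devices holding the array
print(x.sharding)      # Sharding specification
print(x.is_fully_replicated, x.is_fully_addressable)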

Advanced Usage Patterns

Multi-Device Training

# Setup for data-parallel training
def create_train_setup(num_devices):
    # Create mesh for data parallelism
    mesh = jax.make_mesh((num_devices,), ('batch',))
    
    # Sharding specifications
    batch_sharding = NamedSharding(mesh, P('batch'))  # Batch dimension sharded
    replicated_sharding = NamedSharding(mesh, P())    # Parameters replicated
    
    return mesh, batch_sharding, replicated_sharding

def distributed_train_step(params, batch, optimizer_state):
    # All arrays should already have appropriate sharding;
    # loss_fn and optimizer are assumed to be defined elsewhere
    grads = jax.grad(loss_fn)(params, batch)
    
    # Update step automatically uses sharding from inputs
    new_params, new_state = optimizer.update(grads, optimizer_state, params)
    return new_params, new_state

# JIT compile with sharding (shardings come from create_train_setup above)
mesh, batch_sharding, replicated_sharding = create_train_setup(jax.device_count())
distributed_train_step = jax.jit(
    distributed_train_step,
    in_shardings=(replicated_sharding, batch_sharding, replicated_sharding),
    out_shardings=(replicated_sharding, replicated_sharding)
)
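
To invoke the compiled step, place the inputs accordingly first; a sketch in which params, batch, and opt_state are placeholders for your actual pytrees:

params = jax.device_put(params, replicated_sharding)
opt_state = jax.device_put(opt_state, replicated_sharding)
batch = jax.device_put(batch, batch_sharding)

params, opt_state = distributed_train_step(params, batch, opt_state)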

Model Parallelism

# Setup for model-parallel computation
def create_model_parallel_setup():
    # 2D mesh: batch x model dimensions
    mesh = jax.make_mesh((2, 4), ('batch', 'model'))
    
    # Different sharding strategies
    input_sharding = NamedSharding(mesh, P('batch', None))
    weight_sharding = NamedSharding(mesh, P(None, 'model'))  
    output_sharding = NamedSharding(mesh, P('batch', 'model'))
    
    return mesh, input_sharding, weight_sharding, output_sharding

def model_parallel_layer(x, weights):
    # Matrix multiply with different sharding patterns
    return x @ weights  # JAX handles the communication automatically

# Shard arrays according to strategy
mesh, input_sharding, weight_sharding, output_sharding = create_model_parallel_setup()
x = jax.device_put(x, input_sharding)
weights = jax.device_put(weights, weight_sharding)
result = model_parallel_layer(x, weights)  # Output typically lands as P('batch', 'model')

Memory-Efficient Inference

import numpy as np

def memory_efficient_inference(model_fn, large_input):
    # Process in chunks to manage memory
    chunk_size = 1000
    chunks = [large_input[i:i+chunk_size] for i in range(0, len(large_input), chunk_size)]
    
    results = []
    for chunk in chunks:
        # Move to device, compute, move back to host
        device_chunk = jax.device_put(chunk)
        device_result = model_fn(device_chunk)
        host_result = jax.device_get(device_result)
        results.append(host_result)
        
        # Optional: clear caches to free memory
        # (note: this forces model_fn to recompile on the next chunk)
        jax.clear_caches()
    
    # Concatenate on the host (NumPy) so results stay out of device memory
    return np.concatenate(results)

Cross-Device Communication Patterns

# Collective operations using pmap
# (the axis_name referenced by collectives must be bound in the pmap call)
from functools import partial

@partial(jax.pmap, axis_name='batch')
def allreduce_example(x):
    # Sum across all devices
    return jax.lax.psum(x, axis_name='batch')

@partial(jax.pmap, axis_name='batch')
def allgather_example(x):
    # Gather from all devices
    return jax.lax.all_gather(x, axis_name='batch')

# Use with replicated data
replicated_data = jax.device_put_replicated(data, jax.devices())
summed_result = allreduce_example(replicated_data)
gathered_result = allgather_example(replicated_data)
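
pmap is the older per-device API; the same collectives compose with jit when written with shard_map over a mesh. A sketch of the psum example in that style (assumes a JAX version that exports shard_map at the top level; older versions use jax.experimental.shard_map.shard_map):

from functools import partial

mesh = jax.make_mesh((jax.device_count(),), ('batch',))

@partial(jax.shard_map, mesh=mesh, in_specs=P('batch'), out_specs=P())
def allreduce(x):
    # Each device sums its shard elementwise with every other device's shard
    return jax.lax.psum(x, axis_name='batch')

x = jnp.arange(2.0 * jax.device_count())
print(allreduce(x))   # Replicated result with shape (2,)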