# Neural Network Functions
JAX provides a comprehensive set of neural network functions through `jax.nn`, including activation functions, normalization utilities, and attention mechanisms commonly used in machine learning and deep learning applications.
## Core Imports
```python
import jax.nn as jnn
from jax.nn import relu, sigmoid, softmax, gelu
```
## Capabilities
### ReLU and Variants
Rectified Linear Unit activations and their variants for introducing non-linearity while maintaining computational efficiency.
```python { .api }
def relu(x) -> Array:
    """
    Rectified Linear Unit activation: max(0, x).

    Args:
        x: Input array

    Returns:
        Array with ReLU applied element-wise
    """

def relu6(x) -> Array:
    """
    ReLU capped at 6: min(max(0, x), 6).

    Args:
        x: Input array

    Returns:
        Array with ReLU6 applied element-wise
    """

def leaky_relu(x, negative_slope=0.01) -> Array:
    """
    Leaky ReLU: x where x >= 0, otherwise negative_slope * x.

    Args:
        x: Input array
        negative_slope: Slope for negative values (default: 0.01)

    Returns:
        Array with Leaky ReLU applied element-wise
    """

def elu(x, alpha=1.0) -> Array:
    """
    Exponential Linear Unit: x if x > 0 else alpha * (exp(x) - 1).

    Args:
        x: Input array
        alpha: Scale for negative values (default: 1.0)

    Returns:
        Array with ELU applied element-wise
    """

def selu(x) -> Array:
    """
    Scaled Exponential Linear Unit: scale * elu(x, alpha) with fixed
    constants alpha ≈ 1.6733 and scale ≈ 1.0507 chosen for self-normalization.

    Args:
        x: Input array

    Returns:
        Array with SELU applied element-wise
    """

def celu(x, alpha=1.0) -> Array:
    """
    Continuously Differentiable Exponential Linear Unit.

    Args:
        x: Input array
        alpha: Scale parameter (default: 1.0)

    Returns:
        Array with CELU applied element-wise
    """
```
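For a concrete feel for how these variants differ, here is a minimal sketch applying them to the same inputs (the commented values follow directly from the formulas above):
```python
import jax.numpy as jnp
import jax.nn as jnn

x = jnp.array([-2.0, -0.5, 0.0, 0.5, 6.5])

jnn.relu(x)              # [0., 0., 0., 0.5, 6.5]
jnn.relu6(x)             # [0., 0., 0., 0.5, 6.]  (capped at 6)
jnn.leaky_relu(x)        # [-0.02, -0.005, 0., 0.5, 6.5]
jnn.elu(x)               # negative side saturates at -alpha
jnn.selu(x)              # scaled ELU with fixed self-normalizing constants
jnn.celu(x, alpha=0.5)   # continuously differentiable for any alpha > 0
```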
### Modern Activations
Contemporary activation functions that have shown improved performance in various architectures.
```python { .api }
def gelu(x, approximate=True) -> Array:
    """
    Gaussian Error Linear Unit: x * Φ(x) where Φ is the CDF of the standard normal.

    Args:
        x: Input array
        approximate: Whether to use the tanh approximation (default: True)

    Returns:
        Array with GELU applied element-wise
    """

def silu(x) -> Array:
    """
    Sigmoid Linear Unit (Swish): x * sigmoid(x).

    Args:
        x: Input array

    Returns:
        Array with SiLU applied element-wise
    """

def swish(x) -> Array:
    """
    Swish activation (alias for SiLU): x * sigmoid(x).

    Args:
        x: Input array

    Returns:
        Array with Swish applied element-wise
    """

def mish(x) -> Array:
    """
    Mish activation: x * tanh(softplus(x)).

    Args:
        x: Input array

    Returns:
        Array with Mish applied element-wise
    """

def hard_silu(x) -> Array:
    """
    Hard SiLU (Hard Swish variant): x * hard_sigmoid(x).

    Args:
        x: Input array

    Returns:
        Array with Hard SiLU applied element-wise
    """

def hard_swish(x) -> Array:
    """
    Hard Swish: x * relu6(x + 3) / 6.

    Args:
        x: Input array

    Returns:
        Array with Hard Swish applied element-wise
    """

def squareplus(x, b=4.0) -> Array:
    """
    Squareplus activation: (x + sqrt(x^2 + b)) / 2.

    Args:
        x: Input array
        b: Shape parameter (default: 4.0)

    Returns:
        Array with Squareplus applied element-wise
    """
```
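All of these are smooth, element-wise functions, which makes them common drop-in replacements for ReLU. A short comparison sketch:
```python
import jax.numpy as jnp
import jax.nn as jnn

x = jnp.linspace(-3.0, 3.0, 7)

jnn.gelu(x)                     # tanh approximation (default)
jnn.gelu(x, approximate=False)  # exact erf-based form
jnn.silu(x)                     # x * sigmoid(x); jnn.swish is the same function
jnn.mish(x)                     # x * tanh(softplus(x))
jnn.hard_silu(x)                # cheap piecewise-linear approximation of SiLU
jnn.squareplus(x)               # (x + sqrt(x**2 + 4)) / 2 with the default b=4
```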
### Sigmoid and Tanh Variants
Sigmoid-based activations and their approximations for bounded outputs.
```python { .api }
def sigmoid(x) -> Array:
    """
    Sigmoid activation: 1 / (1 + exp(-x)).

    Args:
        x: Input array

    Returns:
        Array with sigmoid applied element-wise
    """

def hard_sigmoid(x) -> Array:
    """
    Hard sigmoid, a piecewise-linear approximation: relu6(x + 3) / 6.

    Args:
        x: Input array

    Returns:
        Array with hard sigmoid applied element-wise
    """

def log_sigmoid(x) -> Array:
    """
    Log sigmoid: log(sigmoid(x)) computed in a numerically stable way.

    Args:
        x: Input array

    Returns:
        Array with log sigmoid applied element-wise
    """

def soft_sign(x) -> Array:
    """
    Soft sign activation: x / (1 + |x|).

    Args:
        x: Input array

    Returns:
        Array with soft sign applied element-wise
    """

def tanh(x) -> Array:
    """
    Hyperbolic tangent activation.

    Args:
        x: Input array

    Returns:
        Array with tanh applied element-wise
    """

def hard_tanh(x) -> Array:
    """
    Hard tanh activation: max(-1, min(1, x)).

    Args:
        x: Input array

    Returns:
        Array with hard tanh applied element-wise
    """
```
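The hard variants replace exponentials with piecewise-linear segments, trading smoothness for cheaper computation, while log_sigmoid avoids the underflow of computing log(sigmoid(x)) naively. For example:
```python
import jax.numpy as jnp
import jax.nn as jnn

x = jnp.array([-4.0, -1.0, 0.0, 1.0, 4.0])

jnn.sigmoid(x)       # smooth, bounded in (0, 1)
jnn.hard_sigmoid(x)  # relu6(x + 3) / 6; exactly 0 or 1 in the tails
jnn.log_sigmoid(x)   # stable even for large negative x
jnn.soft_sign(x)     # x / (1 + |x|), approaches ±1 slowly
jnn.tanh(x)          # smooth, bounded in (-1, 1)
jnn.hard_tanh(x)     # clips to [-1, 1]: [-1., -1., 0., 1., 1.]
```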
### Softmax and Normalization
Normalization functions for probability distributions and feature standardization.
```python { .api }
def softmax(x, axis=-1, where=None, initial=None) -> Array:
    """
    Softmax activation: exp(x_i) / sum(exp(x)) along an axis.

    Args:
        x: Input array
        axis: Axis to apply softmax along (default: -1)
        where: Mask for conditional computation
        initial: Initial value for the reduction

    Returns:
        Array with softmax applied along the specified axis
    """

def log_softmax(x, axis=-1, where=None, initial=None) -> Array:
    """
    Log softmax: log(softmax(x)) computed in a numerically stable way.

    Args:
        x: Input array
        axis: Axis to apply log softmax along (default: -1)
        where: Mask for conditional computation
        initial: Initial value for the reduction

    Returns:
        Array with log softmax applied along the specified axis
    """

def softplus(x) -> Array:
    """
    Softplus activation: log(1 + exp(x)).

    Args:
        x: Input array

    Returns:
        Array with softplus applied element-wise
    """

def standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-5, where=None) -> Array:
    """
    Standardize an array to zero mean and unit variance.

    Args:
        x: Input array to standardize
        axis: Axis to compute statistics along (default: -1)
        mean: Pre-computed mean (computed if None)
        variance: Pre-computed variance (computed if None)
        epsilon: Small value for numerical stability
        where: Mask for conditional computation

    Returns:
        Standardized array
    """

def glu(x, axis=-1) -> Array:
    """
    Gated Linear Unit: split x in half along axis, return a * sigmoid(b).

    Args:
        x: Input array (size along axis must be even)
        axis: Axis to split along (default: -1)

    Returns:
        Array with GLU applied
    """
```
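Unlike the element-wise activations above, softmax and log_softmax reduce along an axis, and glu halves the size of the axis it splits. A minimal sketch:
```python
import jax.numpy as jnp
import jax.nn as jnn

logits = jnp.array([[1.0, 2.0, 3.0],
                    [1.0, 1.0, 1.0]])

probs = jnn.softmax(logits, axis=-1)     # each row sums to 1
logp = jnn.log_softmax(logits, axis=-1)  # stable log(softmax), avoids overflow

x = jnp.arange(8.0).reshape(2, 4)
z = jnn.standardize(x, axis=-1)          # zero mean, unit variance per row
g = jnn.glu(x, axis=-1)                  # shape (2, 2): first half * sigmoid(second half)
```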
### Specialized Functions
Utility functions for neural network operations and transformations.
```python { .api }
def one_hot(x, num_classes, dtype=None, axis=-1) -> Array:
    """
    One-hot encode an array of integers.

    Args:
        x: Integer array to encode
        num_classes: Number of classes
        dtype: Output data type
        axis: Axis along which to insert the one-hot dimension

    Returns:
        One-hot encoded array
    """

def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, where=None) -> Array:
    """
    Compute log(sum(exp(a))) in a numerically stable way.

    Args:
        a: Input array
        axis: Axis to sum along
        b: Scaling factor array
        keepdims: Whether to keep reduced dimensions
        return_sign: Whether to return the sign separately
        where: Mask for conditional computation

    Returns:
        Log-sum-exp result
    """

def logmeanexp(a, axis=None, b=None, keepdims=False, where=None) -> Array:
    """
    Compute log(mean(exp(a))) in a numerically stable way.

    Args:
        a: Input array
        axis: Axis to average along
        b: Scaling factor array
        keepdims: Whether to keep reduced dimensions
        where: Mask for conditional computation

    Returns:
        Log-mean-exp result
    """

def log1mexp(x) -> Array:
    """
    Compute log(1 - exp(x)) in a numerically stable way.

    Args:
        x: Input array (should be <= 0)

    Returns:
        Array with log(1 - exp(x)) applied element-wise
    """

def sparse_plus(x) -> Array:
    """
    Sparse plus, a sparse alternative to softplus that is exactly zero for
    x <= -1: 0 if x <= -1, (x + 1)**2 / 4 if -1 < x < 1, x if x >= 1.

    Args:
        x: Input array

    Returns:
        Array with sparse plus applied element-wise
    """

def sparse_sigmoid(x) -> Array:
    """
    Sparse sigmoid, the derivative of sparse_plus; exactly 0 or 1 outside
    the interval (-1, 1): 0 if x <= -1, (x + 1) / 2 if -1 < x < 1, 1 if x >= 1.

    Args:
        x: Input array

    Returns:
        Array with sparse sigmoid applied element-wise
    """
```
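one_hot and logsumexp cover the most common use case here: a numerically stable cross-entropy loss. A minimal sketch:
```python
import jax.numpy as jnp
import jax.nn as jnn

labels = jnp.array([0, 2, 1])
logits = jnp.array([[2.0, 0.5, 0.1],
                    [0.2, 0.3, 3.0],
                    [0.1, 1.5, 0.4]])

targets = jnn.one_hot(labels, num_classes=3)   # shape (3, 3)

# Cross-entropy via log_softmax; logsumexp is the normalizer inside log_softmax
log_probs = jnn.log_softmax(logits, axis=-1)
loss = -jnp.mean(jnp.sum(targets * log_probs, axis=-1))

# The same normalizer, computed directly
norm = jnn.logsumexp(logits, axis=-1)          # log(sum(exp(logits))) per row
```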
### Attention Mechanisms
Attention functions for transformer and neural attention models.
```python { .api }
def dot_product_attention(
    query,
    key,
    value,
    bias=None,
    mask=None,
    *,
    scale=None,
    is_causal=False,
    query_seq_lengths=None,
    key_value_seq_lengths=None,
    local_window_size=None,
    implementation=None
) -> Array:
    """
    Scaled dot-product attention: softmax(query @ key^T * scale) @ value.

    Args:
        query: Query array (batch, length_q, num_heads, head_dim)
        key: Key array (batch, length_kv, num_heads, head_dim)
        value: Value array (batch, length_kv, num_heads, head_dim)
        bias: Optional bias added to the attention logits
        mask: Optional boolean mask applied to the attention logits
        scale: Logit scale factor (default: 1 / sqrt(head_dim))
        is_causal: Whether to apply a causal mask
        query_seq_lengths: Valid (unpadded) lengths of each query sequence
        key_value_seq_lengths: Valid (unpadded) lengths of each key/value sequence
        local_window_size: Window size for local (sliding-window) attention
        implementation: Backend to use ('xla', 'cudnn', or None for automatic)

    Returns:
        Attention output array (batch, length_q, num_heads, head_dim)
    """
def scaled_dot_general(
    lhs,
    rhs,
    dimension_numbers,
    alpha=1.0,
    precision=None,
    preferred_element_type=None
) -> Array:
    """
    Scaled general dot product for attention computations.

    Args:
        lhs: Left-hand side array
        rhs: Right-hand side array
        dimension_numbers: Contraction specification
        alpha: Scaling factor
        precision: Computation precision
        preferred_element_type: Preferred output type

    Returns:
        Scaled dot product result
    """

def scaled_matmul(
    a,
    b,
    alpha=1.0,
    precision=None,
    preferred_element_type=None
) -> Array:
    """
    Scaled matrix multiplication: alpha * (a @ b).

    Args:
        a: First matrix
        b: Second matrix
        alpha: Scaling factor
        precision: Computation precision
        preferred_element_type: Preferred output type

    Returns:
        Scaled matrix multiplication result
    """

def get_scaled_dot_general_config() -> dict:
    """
    Get configuration for scaled dot product attention.

    Returns:
        Configuration dictionary for attention operations
    """
```
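A minimal sketch of dot_product_attention, assuming the (batch, length, num_heads, head_dim) layout and the is_causal keyword described above:
```python
import jax
import jax.nn as jnn

kq, kk, kv = jax.random.split(jax.random.key(0), 3)

B, T, N, H = 2, 5, 4, 16   # batch, sequence length, heads, head dimension
q = jax.random.normal(kq, (B, T, N, H))
k = jax.random.normal(kk, (B, T, N, H))
v = jax.random.normal(kv, (B, T, N, H))

out = jnn.dot_product_attention(q, k, v)                     # shape (B, T, N, H)
causal = jnn.dot_product_attention(q, k, v, is_causal=True)  # masks future positions
```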
### Utility Functions
Additional utilities for neural network operations.
```python { .api }
def identity(x) -> Array:
    """
    Identity function that returns input unchanged.

    Args:
        x: Input array

    Returns:
        Input array unchanged
    """
```
## Neural Network Initializers
JAX provides weight initialization functions through `jax.nn.initializers`:
```python { .api }
import jax.nn.initializers as init
# Standard initializers
init.zeros(key, shape, dtype=jnp.float32) -> Array
init.ones(key, shape, dtype=jnp.float32) -> Array
init.constant(value, dtype=jnp.float32) -> Callable
# Random initializers
init.uniform(scale=1e-2, dtype=jnp.float32) -> Callable
init.normal(stddev=1e-2, dtype=jnp.float32) -> Callable
init.truncated_normal(stddev=1e-2, dtype=jnp.float32) -> Callable
# Variance scaling initializers
init.variance_scaling(scale, mode, distribution, dtype=jnp.float32) -> Callable
init.glorot_uniform(dtype=jnp.float32) -> Callable
init.glorot_normal(dtype=jnp.float32) -> Callable
init.lecun_uniform(dtype=jnp.float32) -> Callable
init.lecun_normal(dtype=jnp.float32) -> Callable
init.he_uniform(dtype=jnp.float32) -> Callable
init.he_normal(dtype=jnp.float32) -> Callable
# Orthogonal initializer
init.orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32) -> Callable
# Delta orthogonal initializer (for RNNs)
init.delta_orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32) -> Callable
```
Usage examples:
```python
import jax
import jax.numpy as jnp
import jax.nn as jnn
from jax.nn import initializers as init

# Initialize weights (use independent keys for independent parameters)
key = jax.random.key(42)
w_key, b_key = jax.random.split(key)
weights = init.glorot_uniform()(w_key, (784, 128))
biases = init.zeros(b_key, (128,))

# Apply activations in a simple neural network layer
def dense_layer(x, weights, biases):
    return jnn.relu(x @ weights + biases)

# Multi-layer example with different activations
def mlp(x, params):
    x = jnn.relu(x @ params['w1'] + params['b1'])
    x = jnn.gelu(x @ params['w2'] + params['b2'])
    x = jnn.softmax(x @ params['w3'] + params['b3'])
    return x

# Attention example
def simple_attention(q, k, v):
    # Scaled dot-product attention; returns the attended values, not raw scores
    return jnn.dot_product_attention(q, k, v)
```