# Neural Network Functions

JAX provides a comprehensive set of neural network functions through `jax.nn`, including activation functions, normalization utilities, and attention mechanisms commonly used in machine learning and deep learning applications.

## Core Imports

```python
import jax.nn as jnn
from jax.nn import relu, sigmoid, softmax, gelu
```

## Capabilities
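All of the functions below are ordinary array-in/array-out functions, so they compose directly with `jax.jit`, `jax.grad`, and `jax.vmap`. A minimal sketch (printed values are illustrative):

```python
import jax
import jax.numpy as jnp
import jax.nn as jnn

x = jnp.array([-2.0, -0.5, 0.0, 1.5])

# Activations are plain element-wise functions on arrays...
print(jnn.relu(x))      # [0.  0.  0.  1.5]
print(jnn.sigmoid(x))

# ...and compose with the usual JAX transformations.
fast_gelu = jax.jit(jnn.gelu)
grad_of_sum = jax.grad(lambda v: jnn.softplus(v).sum())
print(fast_gelu(x), grad_of_sum(x))
```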
### ReLU and Variants

Rectified Linear Unit activations and their variants for introducing non-linearity while maintaining computational efficiency.

```python { .api }
def relu(x) -> Array:
    """
    Rectified Linear Unit activation: max(0, x).

    Args:
        x: Input array

    Returns:
        Array with ReLU applied element-wise
    """

def relu6(x) -> Array:
    """
    ReLU capped at 6: min(max(0, x), 6).

    Args:
        x: Input array

    Returns:
        Array with ReLU6 applied element-wise
    """

def leaky_relu(x, negative_slope=0.01) -> Array:
    """
    Leaky ReLU: max(negative_slope * x, x).

    Args:
        x: Input array
        negative_slope: Slope for negative values (default: 0.01)

    Returns:
        Array with Leaky ReLU applied element-wise
    """

def elu(x, alpha=1.0) -> Array:
    """
    Exponential Linear Unit: x if x > 0 else alpha * (exp(x) - 1).

    Args:
        x: Input array
        alpha: Scale for negative values (default: 1.0)

    Returns:
        Array with ELU applied element-wise
    """

def selu(x) -> Array:
    """
    Scaled Exponential Linear Unit with fixed alpha and scale.

    Args:
        x: Input array

    Returns:
        Array with SELU applied element-wise
    """

def celu(x, alpha=1.0) -> Array:
    """
    Continuously Differentiable Exponential Linear Unit.

    Args:
        x: Input array
        alpha: Scale parameter (default: 1.0)

    Returns:
        Array with CELU applied element-wise
    """
```
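The variants differ mainly in how they treat negative inputs; a quick comparison (printed values rounded for readability):

```python
import jax.numpy as jnp
import jax.nn as jnn

x = jnp.array([-2.0, -0.5, 0.0, 3.0, 8.0])

print(jnn.relu(x))                            # [ 0.    0.    0.   3.   8. ]
print(jnn.relu6(x))                           # [ 0.    0.    0.   3.   6. ]
print(jnn.leaky_relu(x, negative_slope=0.1))  # [-0.2  -0.05  0.   3.   8. ]
print(jnn.elu(x))                             # [-0.86 -0.39  0.   3.   8. ]
```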
### Modern Activations

Contemporary activation functions that have shown improved performance in various architectures.

```python { .api }
def gelu(x, approximate=True) -> Array:
    """
    Gaussian Error Linear Unit: x * Φ(x) where Φ is CDF of standard normal.

    Args:
        x: Input array
        approximate: Whether to use tanh approximation (default: True)

    Returns:
        Array with GELU applied element-wise
    """

def silu(x) -> Array:
    """
    Sigmoid Linear Unit (Swish): x * sigmoid(x).

    Args:
        x: Input array

    Returns:
        Array with SiLU applied element-wise
    """

def swish(x) -> Array:
    """
    Swish activation (alias for SiLU): x * sigmoid(x).

    Args:
        x: Input array

    Returns:
        Array with Swish applied element-wise
    """

def mish(x) -> Array:
    """
    Mish activation: x * tanh(softplus(x)).

    Args:
        x: Input array

    Returns:
        Array with Mish applied element-wise
    """

def hard_silu(x) -> Array:
    """
    Hard SiLU (Hard Swish variant): x * hard_sigmoid(x).

    Args:
        x: Input array

    Returns:
        Array with Hard SiLU applied element-wise
    """

def hard_swish(x) -> Array:
    """
    Hard Swish: x * relu6(x + 3) / 6.

    Args:
        x: Input array

    Returns:
        Array with Hard Swish applied element-wise
    """

def squareplus(x, b=4.0) -> Array:
    """
    Squareplus activation: (x + sqrt(x^2 + b)) / 2.

    Args:
        x: Input array
        b: Shape parameter (default: 4.0)

    Returns:
        Array with Squareplus applied element-wise
    """
```
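A brief sketch contrasting the exact and tanh-approximate GELU forms, and confirming that `silu` and `swish` are the same function:

```python
import jax.numpy as jnp
import jax.nn as jnn

x = jnp.linspace(-3.0, 3.0, 7)

# The tanh approximation is the default; pass approximate=False for the exact form.
print(jnn.gelu(x))                      # tanh-approximate GELU
print(jnn.gelu(x, approximate=False))   # exact GELU via the normal CDF

# silu and swish are aliases for the same function.
print(jnp.allclose(jnn.silu(x), jnn.swish(x)))   # True
```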
### Sigmoid and Tanh Variants

Sigmoid-based activations and their approximations for bounded outputs.

```python { .api }
def sigmoid(x) -> Array:
    """
    Sigmoid activation: 1 / (1 + exp(-x)).

    Args:
        x: Input array

    Returns:
        Array with sigmoid applied element-wise
    """

def hard_sigmoid(x) -> Array:
    """
    Hard sigmoid approximation: relu6(x + 3) / 6.

    Args:
        x: Input array

    Returns:
        Array with hard sigmoid applied element-wise
    """

def log_sigmoid(x) -> Array:
    """
    Log sigmoid: log(sigmoid(x)) computed in a numerically stable way.

    Args:
        x: Input array

    Returns:
        Array with log sigmoid applied element-wise
    """

def soft_sign(x) -> Array:
    """
    Soft sign activation: x / (1 + |x|).

    Args:
        x: Input array

    Returns:
        Array with soft sign applied element-wise
    """

def tanh(x) -> Array:
    """
    Hyperbolic tangent activation.

    Args:
        x: Input array

    Returns:
        Array with tanh applied element-wise
    """

def hard_tanh(x) -> Array:
    """
    Hard tanh activation: max(-1, min(1, x)).

    Args:
        x: Input array

    Returns:
        Array with hard tanh applied element-wise
    """
```
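The hard variants are piecewise-linear approximations of their smooth counterparts, which can be cheaper on some accelerators; a quick look at how closely they track each other:

```python
import jax.numpy as jnp
import jax.nn as jnn

x = jnp.linspace(-4.0, 4.0, 9)

print(jnn.sigmoid(x))
print(jnn.hard_sigmoid(x))   # piecewise-linear; saturates exactly at 0 and 1 for large |x|
print(jnn.tanh(x))
print(jnn.hard_tanh(x))      # clips to [-1, 1]
```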
### Softmax and Normalization

Normalization functions for probability distributions and feature standardization.

```python { .api }
def softmax(x, axis=-1, where=None, initial=None) -> Array:
    """
    Softmax activation: exp(x_i) / sum(exp(x)) along axis.

    Args:
        x: Input array
        axis: Axis to apply softmax along (default: -1)
        where: Mask for conditional computation
        initial: Initial value for the reduction

    Returns:
        Array with softmax applied along the specified axis
    """

def log_softmax(x, axis=-1, where=None, initial=None) -> Array:
    """
    Log softmax: log(softmax(x)) computed in a numerically stable way.

    Args:
        x: Input array
        axis: Axis to apply log softmax along (default: -1)
        where: Mask for conditional computation
        initial: Initial value for the reduction

    Returns:
        Array with log softmax applied along the specified axis
    """

def softplus(x) -> Array:
    """
    Softplus activation: log(1 + exp(x)).

    Args:
        x: Input array

    Returns:
        Array with softplus applied element-wise
    """

def standardize(x, axis=None, mean=None, variance=None, epsilon=1e-5) -> Array:
    """
    Standardize array to zero mean and unit variance.

    Args:
        x: Input array to standardize
        axis: Axis to compute statistics along
        mean: Pre-computed mean (computed if None)
        variance: Pre-computed variance (computed if None)
        epsilon: Small value for numerical stability

    Returns:
        Standardized array
    """

def glu(x, axis=-1) -> Array:
    """
    Gated Linear Unit: split x in half along axis, return a * sigmoid(b).

    Args:
        x: Input array (size along axis must be even)
        axis: Axis to split along (default: -1)

    Returns:
        Array with GLU applied
    """
```
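Two behaviours worth a small sketch: masking logits before a softmax (a common pattern, shown here without relying on the `where` argument since its handling has varied across JAX versions), and the shape halving performed by `glu`:

```python
import jax.numpy as jnp
import jax.nn as jnn

logits = jnp.array([1.0, 2.0, 3.0, 4.0])
mask = jnp.array([True, True, True, False])

# Push masked logits to a large negative value before the softmax so they
# receive (numerically) zero probability; the rest renormalize to sum to 1.
masked_logits = jnp.where(mask, logits, -1e9)
probs = jnn.softmax(masked_logits)
print(probs, probs.sum())

# glu splits the chosen axis in half: the output is half the input size.
x = jnp.ones((2, 8))
print(jnn.glu(x).shape)   # (2, 4)
```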
### Specialized Functions

Utility functions for neural network operations and transformations.

```python { .api }
def one_hot(x, num_classes, dtype=None, axis=-1) -> Array:
    """
    One-hot encode an array of integers.

    Args:
        x: Integer array to encode
        num_classes: Number of classes
        dtype: Output data type
        axis: Axis to insert the one-hot dimension

    Returns:
        One-hot encoded array
    """

def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, where=None) -> Array:
    """
    Compute log(sum(exp(a))) in a numerically stable way.

    Args:
        a: Input array
        axis: Axis to sum along
        b: Scaling factor array
        keepdims: Whether to keep reduced dimensions
        return_sign: Whether to return sign separately
        where: Mask for conditional computation

    Returns:
        Log-sum-exp result
    """

def logmeanexp(a, axis=None, b=None, keepdims=False, where=None) -> Array:
    """
    Compute log(mean(exp(a))) in a numerically stable way.

    Args:
        a: Input array
        axis: Axis to average along
        b: Scaling factor array
        keepdims: Whether to keep reduced dimensions
        where: Mask for conditional computation

    Returns:
        Log-mean-exp result
    """

def log1mexp(x) -> Array:
    """
    Compute log(1 - exp(x)) in a numerically stable way.

    Args:
        x: Input array (should be <= 0)

    Returns:
        Array with log(1 - exp(x)) applied element-wise
    """

def sparse_plus(x) -> Array:
    """
    Sparse plus function: 0 for x <= -1, (x + 1)**2 / 4 for -1 < x < 1, and x for x >= 1.

    A smoothed ReLU whose derivative is sparse_sigmoid.

    Args:
        x: Input array

    Returns:
        Array with sparse plus applied element-wise
    """

def sparse_sigmoid(x) -> Array:
    """
    Sparse sigmoid: piecewise-linear sigmoid that is exactly 0 for x <= -1 and exactly 1 for x >= 1.

    Args:
        x: Input array

    Returns:
        Array with sparse sigmoid applied element-wise
    """
```
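As a small sketch, `one_hot` and `logsumexp` combine naturally when computing cross-entropy from raw logits without forming unstable intermediate exponentials (`logsumexp` is also available as `jax.scipy.special.logsumexp`); the array values below are illustrative:

```python
import jax.numpy as jnp
import jax.nn as jnn

labels = jnp.array([0, 2, 1])
logits = jnp.array([[2.0, 0.5, -1.0],
                    [0.1, 0.2,  3.0],
                    [1.5, 1.4,  0.0]])

targets = jnn.one_hot(labels, num_classes=3)   # shape (3, 3)

# log_softmax(logits) == logits - logsumexp(logits, axis=-1, keepdims=True)
log_probs = logits - jnn.logsumexp(logits, axis=-1, keepdims=True)
cross_entropy = -(targets * log_probs).sum(axis=-1)
print(cross_entropy)
```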
### Attention Mechanisms

Attention functions for transformer and neural attention models.

```python { .api }
def dot_product_attention(
    query,
    key,
    value,
    bias=None,
    mask=None,
    broadcast_dropout=True,
    dropout_rng=None,
    dropout_rate=0.0,
    deterministic=False,
    dtype=None,
    precision=None,
) -> Array:
    """
    Dot-product attention mechanism.

    Args:
        query: Query array (..., length_q, depth_q)
        key: Key array (..., length_kv, depth_q)
        value: Value array (..., length_kv, depth_v)
        bias: Optional attention bias
        mask: Optional attention mask
        broadcast_dropout: Whether to broadcast dropout
        dropout_rng: Random key for dropout
        dropout_rate: Dropout probability
        deterministic: Whether to use deterministic mode
        dtype: Output data type
        precision: Computation precision

    Returns:
        Attention output array (..., length_q, depth_v)
    """

def scaled_dot_general(
    lhs,
    rhs,
    dimension_numbers,
    alpha=1.0,
    precision=None,
    preferred_element_type=None,
) -> Array:
    """
    Scaled general dot product for attention computations.

    Args:
        lhs: Left-hand side array
        rhs: Right-hand side array
        dimension_numbers: Contraction specification
        alpha: Scaling factor
        precision: Computation precision
        preferred_element_type: Preferred output type

    Returns:
        Scaled dot product result
    """

def scaled_matmul(
    a,
    b,
    alpha=1.0,
    precision=None,
    preferred_element_type=None,
) -> Array:
    """
    Scaled matrix multiplication: alpha * (a @ b).

    Args:
        a: First matrix
        b: Second matrix
        alpha: Scaling factor
        precision: Computation precision
        preferred_element_type: Preferred output type

    Returns:
        Scaled matrix multiplication result
    """

def get_scaled_dot_general_config() -> dict:
    """
    Get configuration for scaled dot product attention.

    Returns:
        Configuration dictionary for attention operations
    """
```
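A minimal sketch of calling `dot_product_attention` directly on batched query/key/value arrays. The available keyword arguments (dropout, masking, scaling) vary between JAX versions, so only the positional arguments are used here; the shapes follow the common (batch, length, heads, head_dim) layout and are purely illustrative:

```python
import jax
import jax.numpy as jnp
import jax.nn as jnn

key = jax.random.key(0)
kq, kk, kv = jax.random.split(key, 3)

# batch=2, sequence length=5, heads=4, head dim=8
q = jax.random.normal(kq, (2, 5, 4, 8))
k = jax.random.normal(kk, (2, 5, 4, 8))
v = jax.random.normal(kv, (2, 5, 4, 8))

out = jnn.dot_product_attention(q, k, v)
print(out.shape)   # same leading shape as the query: (2, 5, 4, 8)
```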
### Utility Functions

Additional utilities for neural network operations.

```python { .api }
def identity(x) -> Array:
    """
    Identity function that returns input unchanged.

    Args:
        x: Input array

    Returns:
        Input array unchanged
    """
```
## Neural Network Initializers

JAX provides weight initialization functions through `jax.nn.initializers`:

```python { .api }
import jax.nn.initializers as init

# Standard initializers
init.zeros(key, shape, dtype=jnp.float32) -> Array
init.ones(key, shape, dtype=jnp.float32) -> Array
init.constant(value, dtype=jnp.float32) -> Callable

# Random initializers
init.uniform(scale=1e-2, dtype=jnp.float32) -> Callable
init.normal(stddev=1e-2, dtype=jnp.float32) -> Callable
init.truncated_normal(stddev=1e-2, dtype=jnp.float32) -> Callable

# Variance scaling initializers
init.variance_scaling(scale, mode, distribution, dtype=jnp.float32) -> Callable
init.glorot_uniform(dtype=jnp.float32) -> Callable
init.glorot_normal(dtype=jnp.float32) -> Callable
init.lecun_uniform(dtype=jnp.float32) -> Callable
init.lecun_normal(dtype=jnp.float32) -> Callable
init.he_uniform(dtype=jnp.float32) -> Callable
init.he_normal(dtype=jnp.float32) -> Callable

# Orthogonal initializer
init.orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32) -> Callable

# Delta orthogonal initializer (for deep convolutional networks)
init.delta_orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32) -> Callable
```
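Each factory above returns a callable taking `(key, shape)` (and optionally a dtype). As a sketch of how the families relate, `glorot_uniform()` behaves like the `variance_scaling(1.0, "fan_avg", "uniform")` special case, and `orthogonal()` produces (approximately) orthonormal matrices:

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers as init

key = jax.random.key(0)

w1 = init.glorot_uniform()(key, (256, 128))
w2 = init.variance_scaling(1.0, "fan_avg", "uniform")(key, (256, 128))
print(jnp.allclose(w1, w2))   # expected True: Glorot is a variance_scaling special case

w3 = init.orthogonal()(key, (128, 128))
print(jnp.allclose(w3 @ w3.T, jnp.eye(128), atol=1e-4))   # rows/columns are orthonormal
```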
Usage examples:

```python
import jax
import jax.numpy as jnp
import jax.nn as jnn
from jax.nn import initializers as init

# Initialize weights
key = jax.random.key(42)
weights = init.glorot_uniform()(key, (784, 128))
biases = init.zeros(key, (128,))

# Apply activations in a simple neural network layer
def dense_layer(x, weights, biases):
    return jnn.relu(x @ weights + biases)

# Multi-layer example with different activations
def mlp(x, params):
    x = jnn.relu(x @ params['w1'] + params['b1'])
    x = jnn.gelu(x @ params['w2'] + params['b2'])
    x = jnn.softmax(x @ params['w3'] + params['b3'])
    return x

# Attention example
def simple_attention(q, k, v):
    # Scaled dot-product attention
    scores = jnn.dot_product_attention(q, k, v)
    return scores
```