mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
23 KiB
23 KiB
SciPy Compatibility
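JAX provides SciPy-compatible functions through jax.scipy for scientific computing, including linear algebra, signal processing, special functions, statistics, and sparse operations. These functions are differentiable and can be JIT-compiled. Coverage is only a subset of SciPy, and some signatures accept SciPy-only arguments (such as `overwrite_a` or `check_finite`) purely for compatibility. Confirm that a function exists in your installed JAX version before relying on the listings below.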
Core Imports
import jax.scipy as jsp
import jax.scipy.linalg as jla
import jax.scipy.special as jss
import jax.scipy.stats as jst
Capabilities
Linear Algebra (jax.scipy.linalg)
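Advanced linear algebra operations for matrix computations and decompositions. Several helpers listed in this section are provided by `jax.numpy.linalg` rather than `jax.scipy.linalg`, for example `matrix_rank`, `norm`, `cond`, `pinv`, `lstsq`, `slogdet` and `matrix_power`. Check the API reference for your JAX version before relying on a name here.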
# Matrix decompositions
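def cholesky(a, lower=False, overwrite_a=False, check_finite=True) -> Array:
"""
Cholesky decomposition of a Hermitian positive-definite matrix.
Args:
a: Hermitian positive-definite matrix to decompose
lower: Whether to return the lower triangular factor (default False, following SciPy)
overwrite_a: Accepted for SciPy compatibility; unused by JAX
check_finite: Accepted for SciPy compatibility; unused by JAX
Returns:
Upper factor U such that a = U.conj().T @ U by default, or lower factor L such that a = L @ L.conj().T when lower=True
"""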
def qr(a, mode='full') -> tuple[Array, Array]:
"""
QR decomposition of matrix.
Args:
a: Matrix to decompose
mode: 'full' (default), 'r', or 'economic', following scipy.linalg.qr conventions (pass by keyword)
Returns:
Tuple (Q, R) where Q has orthonormal columns and R is upper triangular; with mode='r' only R is returned
"""
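def svd(a, full_matrices=True, compute_uv=True, overwrite_a=False, check_finite=True, lapack_driver='gesdd') -> tuple[Array, Array, Array]:
"""
Singular Value Decomposition.
Args:
a: Matrix to decompose
full_matrices: Whether to compute full or reduced SVD
compute_uv: Whether to compute U and Vh in addition to the singular values
overwrite_a: Accepted for SciPy compatibility; unused by JAX
check_finite: Accepted for SciPy compatibility; unused by JAX
lapack_driver: LAPACK driver name, mirroring scipy.linalg.svd
Returns:
Tuple (U, s, Vh) with a = U[:, :k] @ diag(s) @ Vh[:k, :], where k = min(M, N); only s when compute_uv=False
(the `hermitian` flag belongs to jax.numpy.linalg.svd, not jax.scipy.linalg.svd)
"""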
def eig(a, b=None, left=False, right=True, overwrite_a=False, overwrite_b=False,
check_finite=True, homogeneous_eigvals=False) -> tuple[Array, Array]:
"""
Eigenvalues and eigenvectors of general matrix.
Args:
a: Square matrix
b: Optional matrix for generalized eigenvalue problem
left: Whether to compute left eigenvectors
right: Whether to compute right eigenvectors
overwrite_a: Whether input can be overwritten
overwrite_b: Whether b can be overwritten
check_finite: Whether to check for finite values
homogeneous_eigvals: Whether to return homogeneous eigenvalues
Returns:
Tuple (eigenvalues, eigenvectors)
"""
def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
overwrite_b=False, turbo=True, eigvals=None, type=1,
check_finite=True) -> tuple[Array, Array]:
"""
Eigenvalues and eigenvectors of Hermitian matrix.
Args:
a: Hermitian matrix
b: Optional matrix for generalized problem
lower: Whether to use lower triangle
eigvals_only: Whether to compute eigenvalues only
overwrite_a: Whether input can be overwritten
overwrite_b: Whether b can be overwritten
turbo: Whether to use turbo algorithm
eigvals: Range of eigenvalue indices to compute
type: Type of generalized eigenvalue problem
check_finite: Whether to check for finite values
Returns:
Eigenvalues (and eigenvectors if eigvals_only=False)
"""
def eigvals(a, b=None, overwrite_a=False, check_finite=True,
homogeneous_eigvals=False) -> Array:
"""Eigenvalues of general matrix."""
def eigvalsh(a, b=None, lower=True, overwrite_a=False, overwrite_b=False,
turbo=True, eigvals=None, type=1, check_finite=True) -> Array:
"""Eigenvalues of Hermitian matrix."""
# Matrix properties and functions
def det(a) -> Array:
"""Matrix determinant."""
def slogdet(a) -> tuple[Array, Array]:
"""Sign and log determinant of matrix."""
def logdet(a) -> Array:
"""Log determinant of matrix."""
def matrix_rank(M, tol=None, hermitian=False) -> Array:
"""Matrix rank computation."""
def trace(a, offset=0, axis1=0, axis2=1) -> Array:
"""Matrix trace."""
def norm(a, ord=None, axis=None, keepdims=False) -> Array:
"""Matrix or vector norm."""
def cond(x, p=None) -> Array:
"""Condition number of matrix."""
# Matrix solutions
def solve(a, b, assume_a='gen', lower=False, overwrite_a=False,
overwrite_b=False, debug=None, check_finite=True) -> Array:
"""
Solve linear system Ax = b.
Args:
a: Coefficient matrix
b: Right-hand side vector/matrix
assume_a: Properties of matrix a ('gen', 'sym', 'her', 'pos')
lower: Whether to use lower triangle for triangular matrices
overwrite_a: Whether input can be overwritten
overwrite_b: Whether b can be overwritten
debug: Debug information level
check_finite: Whether to check for finite values
Returns:
Solution x such that Ax = b
"""
def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
overwrite_b=False, debug=None, check_finite=True) -> Array:
"""Solve triangular linear system."""
def inv(a, overwrite_a=False, check_finite=True) -> Array:
"""Matrix inverse."""
def pinv(a, rcond=None, hermitian=False, return_rank=False) -> Array:
"""Moore-Penrose pseudoinverse."""
def lstsq(a, b, rcond=None, lapack_driver=None) -> tuple[Array, Array, Array, Array]:
"""
Least-squares solution to linear system.
Args:
a: Coefficient matrix
b: Dependent variable values
rcond: Cutoff ratio for small singular values
lapack_driver: LAPACK driver to use
Returns:
Tuple (solution, residuals, rank, singular_values)
"""
# Matrix functions
def expm(A) -> Array:
"""Matrix exponential."""
def funm(A, func, disp=True) -> Array:
"""General matrix function evaluation."""
def sqrtm(A, disp=True, blocksize=64) -> Array:
"""Matrix square root."""
def logm(A, disp=True) -> Array:
"""Matrix logarithm."""
def fractional_matrix_power(A, t) -> Array:
"""Fractional matrix power A^t."""
def matrix_power(A, n) -> Array:
"""Integer matrix power A^n."""
# Schur decomposition
def schur(a, output='real') -> tuple[Array, Array]:
"""Schur decomposition of matrix."""
def rsf2csf(T, Z) -> tuple[Array, Array]:
"""Convert real Schur form to complex Schur form."""
# Polar decomposition
def polar(a, side='right') -> tuple[Array, Array]:
"""Polar decomposition of matrix."""
Special Functions (jax.scipy.special)
Special mathematical functions including error functions, gamma functions, and Bessel functions.
# Error functions
def erf(z) -> Array:
"""Error function."""
def erfc(x) -> Array:
"""Complementary error function."""
def erfinv(y) -> Array:
"""Inverse error function."""
def erfcinv(y) -> Array:
"""Inverse complementary error function."""
def wofz(z) -> Array:
"""Faddeeva function."""
# Gamma functions
def gamma(z) -> Array:
"""Gamma function."""
def gammaln(x) -> Array:
"""Log gamma function."""
def digamma(x) -> Array:
"""Digamma (psi) function."""
def polygamma(n, x) -> Array:
"""Polygamma function."""
def gammainc(a, x) -> Array:
"""Lower incomplete gamma function."""
def gammaincc(a, x) -> Array:
"""Upper incomplete gamma function."""
def gammasgn(x) -> Array:
"""Sign of gamma function."""
def rgamma(x) -> Array:
"""Reciprocal gamma function."""
# Beta functions
def beta(a, b) -> Array:
"""Beta function."""
def betaln(a, b) -> Array:
"""Log beta function."""
def betainc(a, b, x) -> Array:
"""Incomplete beta function."""
# Bessel functions
def j0(x) -> Array:
"""Bessel function of the first kind of order 0."""
def j1(x) -> Array:
"""Bessel function of the first kind of order 1."""
def jn(n, x) -> Array:
"""Bessel function of the first kind of order n."""
def y0(x) -> Array:
"""Bessel function of the second kind of order 0."""
def y1(x) -> Array:
"""Bessel function of the second kind of order 1."""
def yn(n, x) -> Array:
"""Bessel function of the second kind of order n."""
def i0(x) -> Array:
"""Modified Bessel function of the first kind of order 0."""
def i0e(x) -> Array:
"""Exponentially scaled modified Bessel function i0."""
def i1(x) -> Array:
"""Modified Bessel function of the first kind of order 1."""
def i1e(x) -> Array:
"""Exponentially scaled modified Bessel function i1."""
def iv(v, z) -> Array:
"""Modified Bessel function of the first kind of real order."""
def k0(x) -> Array:
"""Modified Bessel function of the second kind of order 0."""
def k0e(x) -> Array:
"""Exponentially scaled modified Bessel function k0."""
def k1(x) -> Array:
"""Modified Bessel function of the second kind of order 1."""
def k1e(x) -> Array:
"""Exponentially scaled modified Bessel function k1."""
def kv(v, z) -> Array:
"""Modified Bessel function of the second kind of real order."""
# Exponential integrals
def expi(x) -> Array:
"""Exponential integral Ei."""
def expn(n, x) -> Array:
"""Generalized exponential integral."""
# Log-sum-exp and related
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False) -> Array:
"""
Compute log(sum(exp(a))) in numerically stable way.
Args:
a: Input array
axis: Axis to sum over
b: Multiplier for each element
keepdims: Whether to keep reduced dimensions
return_sign: Whether to return sign separately
Returns:
Log-sum-exp result
"""
def softmax(x, axis=None) -> Array:
"""Softmax function."""
def log_softmax(x, axis=None) -> Array:
"""Log softmax function."""
# Combinatorial functions
def factorial(n, exact=False) -> Array:
"""Factorial function."""
def factorial2(n, exact=False) -> Array:
"""Double factorial function."""
def factorialk(n, k, exact=False) -> Array:
"""Multifactorial function."""
def comb(N, k, exact=False, repetition=False) -> Array:
"""Binomial coefficient."""
def perm(N, k, exact=False) -> Array:
"""Permutation coefficient."""
# Elliptic integrals
def ellipk(m) -> Array:
"""Complete elliptic integral of the first kind."""
def ellipe(m) -> Array:
"""Complete elliptic integral of the second kind."""
def ellipkinc(phi, m) -> Array:
"""Incomplete elliptic integral of the first kind."""
def ellipeinc(phi, m) -> Array:
"""Incomplete elliptic integral of the second kind."""
# Zeta and related functions
def zeta(x, q=None) -> Array:
"""Riemann or Hurwitz zeta function."""
def zetac(x) -> Array:
"""Riemann zeta function minus 1."""
# Hypergeometric functions
def hyp1f1(a, b, x) -> Array:
"""Confluent hypergeometric function 1F1."""
def hyp2f1(a, b, c, z) -> Array:
"""Gaussian hypergeometric function 2F1."""
def hyperu(a, b, x) -> Array:
"""Confluent hypergeometric function U."""
# Legendre functions
def legendre(n, x) -> Array:
"""Legendre polynomial."""
def lpmv(m, v, x) -> Array:
"""Associated Legendre function."""
# Spherical functions
def sph_harm(m, n, theta, phi) -> Array:
"""Spherical harmonics."""
# Other special functions
def lambertw(z, k=0, tol=1e-8) -> Array:
"""Lambert W function."""
def spence(z) -> Array:
"""Spence function."""
def multigammaln(a, d) -> Array:
"""Log of multivariate gamma function."""
def entr(x) -> Array:
"""Elementwise function -x*log(x)."""
def kl_div(x, y) -> Array:
"""Elementwise function x*log(x/y) - x + y."""
def rel_entr(x, y) -> Array:
"""Elementwise function x*log(x/y)."""
def huber(delta, r) -> Array:
"""Huber loss function."""
def pseudo_huber(delta, r) -> Array:
"""Pseudo-Huber loss function."""
Statistics (jax.scipy.stats)
Statistical distributions and functions for probability and hypothesis testing.
# Continuous distributions
class norm:
"""Normal distribution."""
@staticmethod
def pdf(x, loc=0, scale=1) -> Array: ...
@staticmethod
def logpdf(x, loc=0, scale=1) -> Array: ...
@staticmethod
def cdf(x, loc=0, scale=1) -> Array: ...
@staticmethod
def logcdf(x, loc=0, scale=1) -> Array: ...
@staticmethod
def sf(x, loc=0, scale=1) -> Array: ...
@staticmethod
def logsf(x, loc=0, scale=1) -> Array: ...
@staticmethod
def ppf(q, loc=0, scale=1) -> Array: ...
@staticmethod
def isf(q, loc=0, scale=1) -> Array: ...
class multivariate_normal:
"""Multivariate normal distribution."""
@staticmethod
def pdf(x, mean=None, cov=1, allow_singular=False) -> Array: ...
@staticmethod
def logpdf(x, mean=None, cov=1, allow_singular=False) -> Array: ...
class uniform:
"""Uniform distribution."""
@staticmethod
def pdf(x, loc=0, scale=1) -> Array: ...
@staticmethod
def logpdf(x, loc=0, scale=1) -> Array: ...
@staticmethod
def cdf(x, loc=0, scale=1) -> Array: ...
@staticmethod
def logcdf(x, loc=0, scale=1) -> Array: ...
@staticmethod
def sf(x, loc=0, scale=1) -> Array: ...
@staticmethod
def logsf(x, loc=0, scale=1) -> Array: ...
@staticmethod
def ppf(q, loc=0, scale=1) -> Array: ...
class beta:
"""Beta distribution."""
@staticmethod
def pdf(x, a, b, loc=0, scale=1) -> Array: ...
@staticmethod
def logpdf(x, a, b, loc=0, scale=1) -> Array: ...
@staticmethod
def cdf(x, a, b, loc=0, scale=1) -> Array: ...
class gamma:
"""Gamma distribution."""
@staticmethod
def pdf(x, a, loc=0, scale=1) -> Array: ...
@staticmethod
def logpdf(x, a, loc=0, scale=1) -> Array: ...
@staticmethod
def cdf(x, a, loc=0, scale=1) -> Array: ...
class chi2:
"""Chi-square distribution."""
@staticmethod
def pdf(x, df, loc=0, scale=1) -> Array: ...
@staticmethod
def logpdf(x, df, loc=0, scale=1) -> Array: ...
@staticmethod
def cdf(x, df, loc=0, scale=1) -> Array: ...
class t:
"""Student's t-distribution."""
@staticmethod
def pdf(x, df, loc=0, scale=1) -> Array: ...
@staticmethod
def logpdf(x, df, loc=0, scale=1) -> Array: ...
@staticmethod
def cdf(x, df, loc=0, scale=1) -> Array: ...
class f:
"""F-distribution."""
@staticmethod
def pdf(x, dfn, dfd, loc=0, scale=1) -> Array: ...
@staticmethod
def logpdf(x, dfn, dfd, loc=0, scale=1) -> Array: ...
@staticmethod
def cdf(x, dfn, dfd, loc=0, scale=1) -> Array: ...
class laplace:
"""Laplace distribution."""
@staticmethod
def pdf(x, loc=0, scale=1) -> Array: ...
@staticmethod
def logpdf(x, loc=0, scale=1) -> Array: ...
@staticmethod
def cdf(x, loc=0, scale=1) -> Array: ...
class logistic:
"""Logistic distribution."""
@staticmethod
def pdf(x, loc=0, scale=1) -> Array: ...
@staticmethod
def logpdf(x, loc=0, scale=1) -> Array: ...
@staticmethod
def cdf(x, loc=0, scale=1) -> Array: ...
class pareto:
"""Pareto distribution."""
@staticmethod
def pdf(x, b, loc=0, scale=1) -> Array: ...
@staticmethod
def logpdf(x, b, loc=0, scale=1) -> Array: ...
@staticmethod
def cdf(x, b, loc=0, scale=1) -> Array: ...
class expon:
"""Exponential distribution."""
@staticmethod
def pdf(x, loc=0, scale=1) -> Array: ...
@staticmethod
def logpdf(x, loc=0, scale=1) -> Array: ...
@staticmethod
def cdf(x, loc=0, scale=1) -> Array: ...
class lognorm:
"""Log-normal distribution."""
@staticmethod
def pdf(x, s, loc=0, scale=1) -> Array: ...
@staticmethod
def logpdf(x, s, loc=0, scale=1) -> Array: ...
@staticmethod
def cdf(x, s, loc=0, scale=1) -> Array: ...
class truncnorm:
"""Truncated normal distribution."""
@staticmethod
def pdf(x, a, b, loc=0, scale=1) -> Array: ...
@staticmethod
def logpdf(x, a, b, loc=0, scale=1) -> Array: ...
@staticmethod
def cdf(x, a, b, loc=0, scale=1) -> Array: ...
# Discrete distributions
class bernoulli:
"""Bernoulli distribution."""
@staticmethod
def pmf(k, p, loc=0) -> Array: ...
@staticmethod
def logpmf(k, p, loc=0) -> Array: ...
@staticmethod
def cdf(k, p, loc=0) -> Array: ...
class binom:
"""Binomial distribution."""
@staticmethod
def pmf(k, n, p, loc=0) -> Array: ...
@staticmethod
def logpmf(k, n, p, loc=0) -> Array: ...
@staticmethod
def cdf(k, n, p, loc=0) -> Array: ...
class geom:
"""Geometric distribution."""
@staticmethod
def pmf(k, p, loc=0) -> Array: ...
@staticmethod
def logpmf(k, p, loc=0) -> Array: ...
@staticmethod
def cdf(k, p, loc=0) -> Array: ...
class nbinom:
"""Negative binomial distribution."""
@staticmethod
def pmf(k, n, p, loc=0) -> Array: ...
@staticmethod
def logpmf(k, n, p, loc=0) -> Array: ...
@staticmethod
def cdf(k, n, p, loc=0) -> Array: ...
class poisson:
"""Poisson distribution."""
@staticmethod
def pmf(k, mu, loc=0) -> Array: ...
@staticmethod
def logpmf(k, mu, loc=0) -> Array: ...
@staticmethod
def cdf(k, mu, loc=0) -> Array: ...
# Statistical functions
def mode(a, axis=0, nan_policy='propagate', keepdims=False) -> Array:
"""Mode of array values along axis."""
def rankdata(a, method='average', axis=None) -> Array:
"""Rank data along axis."""
def kendalltau(x, y, initial_lexsort=None, nan_policy='propagate', method='auto') -> tuple[Array, Array]:
"""Kendall's tau correlation coefficient."""
def pearsonr(x, y) -> tuple[Array, Array]:
"""Pearson correlation coefficient."""
def spearmanr(a, b=None, axis=0, nan_policy='propagate', alternative='two-sided') -> tuple[Array, Array]:
"""Spearman correlation coefficient."""
Signal Processing (jax.scipy.signal)
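Signal processing functions for filtering, convolution, and spectral analysis. `jax.scipy.signal` implements only a subset of `scipy.signal`: chiefly the convolution and correlation routines, plus `stft`, `istft`, `welch`, `csd` and `detrend`. The remaining signatures below mirror SciPy and may not be available in your JAX version.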
def convolve(in1, in2, mode='full', method='auto') -> Array:
"""N-dimensional convolution."""
def convolve2d(in1, in2, mode='full', boundary='fill', fillvalue=0) -> Array:
"""2D convolution."""
def correlate(in1, in2, mode='full', method='auto') -> Array:
"""Cross-correlation of two arrays."""
def correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0) -> Array:
"""2D cross-correlation."""
def fftconvolve(in1, in2, mode='full', axes=None) -> Array:
"""FFT-based convolution."""
def oaconvolve(in1, in2, mode='full', axes=None) -> Array:
"""Overlap-add convolution."""
def lfilter(b, a, x, axis=-1, zi=None) -> Array:
"""Linear digital filter."""
def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None, method='pad', irlen=None) -> Array:
"""Zero-phase digital filtering."""
def sosfilt(sos, x, axis=-1, zi=None) -> Array:
"""Filter using second-order sections."""
def sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None) -> Array:
"""Zero-phase filtering with second-order sections."""
def hilbert(x, N=None, axis=-1) -> Array:
"""Hilbert transform."""
def hilbert2(x, N=None) -> Array:
"""2D Hilbert transform."""
def decimate(x, q, n=None, ftype='iir', axis=-1, zero_phase=True) -> Array:
"""Downsample signal by integer factor."""
def resample(x, num, t=None, axis=0, window=None, domain='time') -> Array:
"""Resample signal to new sample rate."""
def resample_poly(x, up, down, axis=0, window='kaiser', padtype='constant', cval=None) -> Array:
"""Resample using polyphase filtering."""
def upfirdn(h, x, up=1, down=1, axis=-1, mode='constant', cval=0) -> Array:
"""Upsample, FIR filter, and downsample."""
def periodogram(x, fs=1.0, window='boxcar', nfft=None, detrend='constant',
return_onesided=True, scaling='density', axis=-1) -> tuple[Array, Array]:
"""Periodogram power spectral density."""
def welch(x, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
detrend='constant', return_onesided=True, scaling='density', axis=-1,
average='mean') -> tuple[Array, Array]:
"""Welch's method for power spectral density."""
def csd(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
detrend='constant', return_onesided=True, scaling='density', axis=-1,
average='mean') -> tuple[Array, Array]:
"""Cross power spectral density."""
def coherence(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
detrend='constant', axis=-1) -> tuple[Array, Array]:
"""Coherence between signals."""
def spectrogram(x, fs=1.0, window='tukey', nperseg=None, noverlap=None, nfft=None,
detrend='constant', return_onesided=True, scaling='density', axis=-1,
mode='psd') -> tuple[Array, Array, Array]:
"""Spectrogram using short-time Fourier transform."""
def stft(x, fs=1.0, window='hann', nperseg=256, noverlap=None, nfft=None,
detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=-1) -> tuple[Array, Array, Array]:
"""Short-time Fourier transform."""
def istft(Zxx, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
input_onesided=True, boundary=True, time_axis=-1, freq_axis=-2) -> tuple[Array, Array]:
"""Inverse short-time Fourier transform."""
def lombscargle(x, y, freqs, precenter=False, normalize=False) -> Array:
"""Lomb-Scargle periodogram."""
def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=False) -> Array:
"""Remove linear trend from data."""
def find_peaks(x, height=None, threshold=None, distance=None, prominence=None,
width=None, wlen=None, rel_height=0.5, plateau_size=None) -> tuple[Array, dict]:
"""Find peaks in 1D array."""
def peak_prominences(x, peaks, wlen=None) -> tuple[Array, Array, Array]:
"""Calculate peak prominences."""
def peak_widths(x, peaks, rel_height=0.5, prominence_data=None, wlen=None) -> tuple[Array, Array, Array, Array]:
"""Calculate peak widths."""
Other Submodules
# Fast Fourier Transform (jax.scipy.fft)
import jax.scipy.fft as jfft
# Discrete cosine transforms (dct, dctn, idct, idctn); the standard FFTs live in jax.numpy.fft

# N-dimensional image processing (jax.scipy.ndimage)
import jax.scipy.ndimage as jnd
# Coordinate-based interpolation (map_coordinates); most of scipy.ndimage is not provided — check your JAX version

# Sparse linear algebra (jax.scipy.sparse)
# Aliased as jsparse so it does not shadow jax.scipy.special, which is conventionally imported as jss.
import jax.scipy.sparse as jsparse
# Iterative solvers in jax.scipy.sparse.linalg (e.g. cg, gmres, bicgstab); sparse array formats are in jax.experimental.sparse

# Interpolation (jax.scipy.interpolate)
import jax.scipy.interpolate as jsi
# RegularGridInterpolator for N-dimensional grids; use jax.numpy.interp for simple 1-D interpolation

# Clustering (jax.scipy.cluster)
import jax.scipy.cluster as jsc
# Vector quantization (jax.scipy.cluster.vq); hierarchical clustering and k-means are not provided — check your JAX version

# Numerical integration (jax.scipy.integrate)
# Aliased as jsint so it does not rebind jsi (jax.scipy.interpolate) above.
import jax.scipy.integrate as jsint
# Quadrature such as trapezoid; for ODE solving see jax.experimental.ode.odeint
Usage Examples
# `jax` itself must be imported: the examples below use jax.jit and jax.grad,
# and `import jax.numpy as jnp` alone does not bind the name `jax`.
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.scipy.linalg as jla
import jax.scipy.special as jss
import jax.scipy.stats as jst

# Linear algebra example: a small symmetric positive-definite system
A = jnp.array([[4.0, 2.0], [2.0, 3.0]])
b = jnp.array([1.0, 2.0])

# Solve the linear system A @ x = b
x = jla.solve(A, b)

# Eigenvalues and eigenvectors of the symmetric matrix A
eigenvals, eigenvecs = jla.eigh(A)

# Cholesky factorization. jax.scipy.linalg.cholesky follows SciPy and returns
# the UPPER factor by default, so request the lower factor explicitly for the
# identity A == L @ L.T to hold.
L = jla.cholesky(A, lower=True)  # A = L @ L.T

# Special functions, evaluated on a grid (kept separate from the solution `x`)
x_grid = jnp.linspace(-3, 3, 100)
erf_vals = jss.erf(x_grid)
gamma_vals = jss.gamma(x_grid + 1)

# Statistical distributions: Gaussian log-likelihood of a small sample
data = jnp.array([1.2, 2.3, 1.8, 3.1, 2.7])
log_likelihood = jst.norm.logpdf(data, loc=2.0, scale=1.0).sum()

# Probability density functions
x_vals = jnp.linspace(0, 5, 100)
pdf_vals = jst.gamma.pdf(x_vals, a=2.0, scale=1.0)


# Use in optimization with JAX transformations
@jax.jit
def neg_log_likelihood(params, data):
    """Negative Gaussian log-likelihood of `data` for params = (mu, sigma)."""
    mu, sigma = params
    return -jst.norm.logpdf(data, mu, sigma).sum()


# Compute gradient for maximum likelihood estimation; the result mirrors the
# structure of `params` (a list of two scalars: d/dmu, d/dsigma).
grad_fn = jax.grad(neg_log_likelihood)
gradients = grad_fn([2.0, 1.0], data)