mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
456 lines
12 KiB
Python
456 lines
12 KiB
Python
from functools import partial
|
|
|
|
import jax.numpy as jnp
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import torch
|
|
from jax import lax
|
|
|
|
|
|
def tridiagonal_solve(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray:
|
|
n = len(b)
|
|
|
|
# Create working copies to avoid modifying input
|
|
c_prime = np.zeros(n - 1, dtype=np.float64)
|
|
d_prime = np.zeros(n, dtype=np.float64)
|
|
x = np.zeros(n, dtype=np.float64)
|
|
|
|
# Forward sweep - sequential dependency: c_prime[i] depends on c_prime[i-1]
|
|
c_prime[0] = c[0] / b[0]
|
|
d_prime[0] = d[0] / b[0]
|
|
|
|
for i in range(1, n - 1):
|
|
denom = b[i] - a[i - 1] * c_prime[i - 1]
|
|
c_prime[i] = c[i] / denom
|
|
d_prime[i] = (d[i] - a[i - 1] * d_prime[i - 1]) / denom
|
|
|
|
# Last row of forward sweep
|
|
denom = b[n - 1] - a[n - 2] * c_prime[n - 2]
|
|
d_prime[n - 1] = (d[n - 1] - a[n - 2] * d_prime[n - 2]) / denom
|
|
|
|
# Back substitution - sequential dependency: x[i] depends on x[i+1]
|
|
x[n - 1] = d_prime[n - 1]
|
|
for i in range(n - 2, -1, -1):
|
|
x[i] = d_prime[i] - c_prime[i] * x[i + 1]
|
|
|
|
return x
|
|
|
|
|
|
def leapfrog_integration(
|
|
positions: np.ndarray,
|
|
velocities: np.ndarray,
|
|
masses: np.ndarray,
|
|
dt: float,
|
|
n_steps: int,
|
|
softening: float = 0.01
|
|
) -> tuple[np.ndarray, np.ndarray]:
|
|
n_particles = len(masses)
|
|
pos = positions.copy()
|
|
vel = velocities.copy()
|
|
acc = np.zeros_like(pos)
|
|
|
|
G = 1.0
|
|
|
|
for step in range(n_steps):
|
|
acc.fill(0.0)
|
|
|
|
for i in range(n_particles):
|
|
for j in range(i + 1, n_particles):
|
|
dx = pos[j, 0] - pos[i, 0]
|
|
dy = pos[j, 1] - pos[i, 1]
|
|
dz = pos[j, 2] - pos[i, 2]
|
|
|
|
dist_sq = dx * dx + dy * dy + dz * dz + softening * softening
|
|
dist = np.sqrt(dist_sq)
|
|
dist_cubed = dist_sq * dist
|
|
|
|
force_over_dist = G / dist_cubed
|
|
|
|
acc[i, 0] += masses[j] * force_over_dist * dx
|
|
acc[i, 1] += masses[j] * force_over_dist * dy
|
|
acc[i, 2] += masses[j] * force_over_dist * dz
|
|
|
|
acc[j, 0] -= masses[i] * force_over_dist * dx
|
|
acc[j, 1] -= masses[i] * force_over_dist * dy
|
|
acc[j, 2] -= masses[i] * force_over_dist * dz
|
|
|
|
for i in range(n_particles):
|
|
vel[i, 0] += 0.5 * dt * acc[i, 0]
|
|
vel[i, 1] += 0.5 * dt * acc[i, 1]
|
|
vel[i, 2] += 0.5 * dt * acc[i, 2]
|
|
|
|
for i in range(n_particles):
|
|
pos[i, 0] += dt * vel[i, 0]
|
|
pos[i, 1] += dt * vel[i, 1]
|
|
pos[i, 2] += dt * vel[i, 2]
|
|
|
|
for i in range(n_particles):
|
|
vel[i, 0] += 0.5 * dt * acc[i, 0]
|
|
vel[i, 1] += 0.5 * dt * acc[i, 1]
|
|
vel[i, 2] += 0.5 * dt * acc[i, 2]
|
|
|
|
return pos, vel
|
|
|
|
|
|
def longest_increasing_subsequence_length(arr: np.ndarray) -> int:
|
|
n = len(arr)
|
|
if n == 0:
|
|
return 0
|
|
|
|
dp = np.ones(n, dtype=np.int64)
|
|
|
|
for i in range(1, n):
|
|
for j in range(i):
|
|
if arr[j] < arr[i]:
|
|
if dp[j] + 1 > dp[i]:
|
|
dp[i] = dp[j] + 1
|
|
|
|
max_length = dp[0]
|
|
for i in range(1, n):
|
|
if dp[i] > max_length:
|
|
max_length = dp[i]
|
|
|
|
return max_length
|
|
|
|
|
|
def _tridiagonal_forward_step_jax(carry, inputs):
|
|
c_prev, d_prev = carry
|
|
a_i, b_i, c_i, d_i = inputs
|
|
denom = b_i - a_i * c_prev
|
|
c_new = c_i / denom
|
|
d_new = (d_i - a_i * d_prev) / denom
|
|
return (c_new, d_new), (c_new, d_new)
|
|
|
|
|
|
def _tridiagonal_back_step_jax(x_next, inputs):
|
|
d_prime_i, c_prime_i = inputs
|
|
x_i = d_prime_i - c_prime_i * x_next
|
|
return x_i, x_i
|
|
|
|
|
|
def tridiagonal_solve_jax(a, b, c, d):
|
|
n = b.shape[0]
|
|
|
|
c_prime_0 = c[0] / b[0]
|
|
d_prime_0 = d[0] / b[0]
|
|
|
|
scan_inputs = (a[:-1], b[1:-1], c[1:], d[1:-1])
|
|
|
|
_, (c_prime_rest, d_prime_mid) = lax.scan(
|
|
_tridiagonal_forward_step_jax,
|
|
(c_prime_0, d_prime_0),
|
|
scan_inputs
|
|
)
|
|
|
|
c_prime = jnp.concatenate([jnp.array([c_prime_0]), c_prime_rest])
|
|
|
|
denom_last = b[n - 1] - a[n - 2] * c_prime[n - 2]
|
|
d_prime_last = (d[n - 1] - a[n - 2] * d_prime_mid[-1]) / denom_last
|
|
d_prime = jnp.concatenate([jnp.array([d_prime_0]), d_prime_mid, jnp.array([d_prime_last])])
|
|
|
|
x_last = d_prime[n - 1]
|
|
_, x_rest = lax.scan(
|
|
_tridiagonal_back_step_jax,
|
|
x_last,
|
|
(d_prime[:-1], c_prime),
|
|
reverse=True
|
|
)
|
|
|
|
x = jnp.concatenate([x_rest, jnp.array([x_last])])
|
|
return x
|
|
|
|
|
|
def _leapfrog_compute_accelerations_jax(pos, masses, softening):
|
|
G = 1.0
|
|
diff = pos[jnp.newaxis, :, :] - pos[:, jnp.newaxis, :]
|
|
|
|
dist_sq = jnp.sum(diff ** 2, axis=-1) + softening ** 2
|
|
dist = jnp.sqrt(dist_sq)
|
|
dist_cubed = dist_sq * dist
|
|
|
|
dist_cubed = jnp.where(dist_cubed == 0, 1.0, dist_cubed)
|
|
|
|
force_factor = G * masses[jnp.newaxis, :] / dist_cubed
|
|
|
|
acc = jnp.sum(force_factor[:, :, jnp.newaxis] * diff, axis=1)
|
|
return acc
|
|
|
|
|
|
def _leapfrog_step_jax(carry, _, masses, softening, dt):
|
|
pos, vel = carry
|
|
acc = _leapfrog_compute_accelerations_jax(pos, masses, softening)
|
|
|
|
vel = vel + 0.5 * dt * acc
|
|
pos = pos + dt * vel
|
|
vel = vel + 0.5 * dt * acc
|
|
|
|
return (pos, vel), None
|
|
|
|
|
|
def leapfrog_integration_jax(
|
|
positions,
|
|
velocities,
|
|
masses,
|
|
dt: float,
|
|
n_steps: int,
|
|
softening: float = 0.01
|
|
):
|
|
step_fn = partial(_leapfrog_step_jax, masses=masses, softening=softening, dt=dt)
|
|
(final_pos, final_vel), _ = lax.scan(step_fn, (positions, velocities), None, length=n_steps)
|
|
return final_pos, final_vel
|
|
|
|
|
|
def _lis_inner_body_jax(j, dp_inner, arr, i):
|
|
condition = (arr[j] < arr[i]) & (dp_inner[j] + 1 > dp_inner[i])
|
|
new_val = jnp.where(condition, dp_inner[j] + 1, dp_inner[i])
|
|
return dp_inner.at[i].set(new_val)
|
|
|
|
|
|
def _lis_outer_body_jax(i, dp, arr):
|
|
inner_fn = partial(_lis_inner_body_jax, arr=arr, i=i)
|
|
dp = lax.fori_loop(0, i, inner_fn, dp)
|
|
return dp
|
|
|
|
|
|
def longest_increasing_subsequence_length_jax(arr):
|
|
n = arr.shape[0]
|
|
|
|
if n == 0:
|
|
return 0
|
|
|
|
outer_fn = partial(_lis_outer_body_jax, arr=arr)
|
|
dp = jnp.ones(n, dtype=jnp.int32)
|
|
dp = lax.fori_loop(1, n, outer_fn, dp)
|
|
|
|
return int(jnp.max(dp))
|
|
|
|
|
|
def tridiagonal_solve_torch(a, b, c, d):
|
|
device = b.device
|
|
dtype = b.dtype
|
|
n = b.shape[0]
|
|
|
|
c_prime = torch.zeros(n - 1, device=device, dtype=dtype)
|
|
d_prime = torch.zeros(n, device=device, dtype=dtype)
|
|
x = torch.zeros(n, device=device, dtype=dtype)
|
|
|
|
c_prime[0] = c[0] / b[0]
|
|
d_prime[0] = d[0] / b[0]
|
|
|
|
for i in range(1, n - 1):
|
|
denom = b[i] - a[i - 1] * c_prime[i - 1]
|
|
c_prime[i] = c[i] / denom
|
|
d_prime[i] = (d[i] - a[i - 1] * d_prime[i - 1]) / denom
|
|
|
|
denom = b[n - 1] - a[n - 2] * c_prime[n - 2]
|
|
d_prime[n - 1] = (d[n - 1] - a[n - 2] * d_prime[n - 2]) / denom
|
|
|
|
x[n - 1] = d_prime[n - 1]
|
|
for i in range(n - 2, -1, -1):
|
|
x[i] = d_prime[i] - c_prime[i] * x[i + 1]
|
|
|
|
return x
|
|
|
|
|
|
def leapfrog_integration_torch(
|
|
positions,
|
|
velocities,
|
|
masses,
|
|
dt: float,
|
|
n_steps: int,
|
|
softening: float = 0.01
|
|
):
|
|
G = 1.0
|
|
|
|
pos = positions.clone()
|
|
vel = velocities.clone()
|
|
|
|
for _ in range(n_steps):
|
|
diff = pos.unsqueeze(0) - pos.unsqueeze(1)
|
|
|
|
dist_sq = torch.sum(diff ** 2, dim=-1) + softening ** 2
|
|
dist = torch.sqrt(dist_sq)
|
|
dist_cubed = dist_sq * dist
|
|
|
|
dist_cubed = torch.where(dist_cubed == 0, torch.ones_like(dist_cubed), dist_cubed)
|
|
|
|
force_factor = G * masses.unsqueeze(0) / dist_cubed
|
|
|
|
acc = torch.sum(force_factor.unsqueeze(-1) * diff, dim=1)
|
|
|
|
vel = vel + 0.5 * dt * acc
|
|
pos = pos + dt * vel
|
|
vel = vel + 0.5 * dt * acc
|
|
|
|
return pos, vel
|
|
|
|
|
|
def longest_increasing_subsequence_length_torch(arr):
|
|
n = arr.shape[0]
|
|
|
|
if n == 0:
|
|
return 0
|
|
|
|
device = arr.device
|
|
dp = torch.ones(n, device=device, dtype=torch.int64)
|
|
|
|
for i in range(1, n):
|
|
for j in range(i):
|
|
if arr[j] < arr[i]:
|
|
if dp[j] + 1 > dp[i]:
|
|
dp[i] = dp[j] + 1
|
|
|
|
return int(torch.max(dp).item())
|
|
|
|
|
|
def _tridiagonal_forward_cond_tf(i, _c_prime, _d_prime, n, _a, _b, _c, _d):
|
|
return i < n - 1
|
|
|
|
|
|
def _tridiagonal_forward_body_tf(i, c_prime, d_prime, n, a, b, c, d):
|
|
c_prev = c_prime[i - 1]
|
|
d_prev = d_prime[i - 1]
|
|
denom = b[i] - a[i - 1] * c_prev
|
|
c_val = c[i] / denom
|
|
d_val = (d[i] - a[i - 1] * d_prev) / denom
|
|
c_prime = tf.tensor_scatter_nd_update(c_prime, tf.reshape(i, [1, 1]), tf.reshape(c_val, [1]))
|
|
d_prime = tf.tensor_scatter_nd_update(d_prime, tf.reshape(i, [1, 1]), tf.reshape(d_val, [1]))
|
|
return i + 1, c_prime, d_prime, n, a, b, c, d
|
|
|
|
|
|
def _tridiagonal_back_cond_tf(i, _x, _c_prime, _d_prime):
|
|
return i >= 0
|
|
|
|
|
|
def _tridiagonal_back_body_tf(i, x, c_prime, d_prime):
|
|
x_next = x[i + 1]
|
|
x_val = d_prime[i] - c_prime[i] * x_next
|
|
x = tf.tensor_scatter_nd_update(x, tf.reshape(i, [1, 1]), tf.reshape(x_val, [1]))
|
|
return i - 1, x, c_prime, d_prime
|
|
|
|
|
|
def tridiagonal_solve_tf(a, b, c, d):
|
|
n = tf.shape(b)[0]
|
|
dtype = b.dtype
|
|
|
|
c_prime = tf.zeros([n - 1], dtype=dtype)
|
|
d_prime = tf.zeros([n], dtype=dtype)
|
|
|
|
c_prime = tf.tensor_scatter_nd_update(c_prime, [[0]], tf.reshape(c[0] / b[0], [1]))
|
|
d_prime = tf.tensor_scatter_nd_update(d_prime, [[0]], tf.reshape(d[0] / b[0], [1]))
|
|
|
|
_, c_prime, d_prime, _, _, _, _, _ = tf.while_loop(
|
|
_tridiagonal_forward_cond_tf,
|
|
_tridiagonal_forward_body_tf,
|
|
[1, c_prime, d_prime, n, a, b, c, d]
|
|
)
|
|
|
|
c_last = c_prime[n - 2]
|
|
d_prev = d_prime[n - 2]
|
|
denom = b[n - 1] - a[n - 2] * c_last
|
|
d_last = (d[n - 1] - a[n - 2] * d_prev) / denom
|
|
d_prime = tf.tensor_scatter_nd_update(d_prime, tf.reshape(n - 1, [1, 1]), tf.reshape(d_last, [1]))
|
|
|
|
x = tf.zeros([n], dtype=dtype)
|
|
x = tf.tensor_scatter_nd_update(x, tf.reshape(n - 1, [1, 1]), tf.reshape(d_prime[n - 1], [1]))
|
|
|
|
_, x, _, _ = tf.while_loop(
|
|
_tridiagonal_back_cond_tf,
|
|
_tridiagonal_back_body_tf,
|
|
[n - 2, x, c_prime, d_prime]
|
|
)
|
|
|
|
return x
|
|
|
|
|
|
def _leapfrog_compute_accelerations_tf(pos, masses, softening, G):
|
|
diff = tf.expand_dims(pos, 0) - tf.expand_dims(pos, 1)
|
|
|
|
dist_sq = tf.reduce_sum(diff ** 2, axis=-1) + softening ** 2
|
|
dist = tf.sqrt(dist_sq)
|
|
dist_cubed = dist_sq * dist
|
|
|
|
dist_cubed = tf.where(dist_cubed == 0, tf.ones_like(dist_cubed), dist_cubed)
|
|
|
|
force_factor = G * tf.expand_dims(masses, 0) / dist_cubed
|
|
|
|
acc = tf.reduce_sum(tf.expand_dims(force_factor, -1) * diff, axis=1)
|
|
return acc
|
|
|
|
|
|
def _leapfrog_step_body_tf(i, pos, vel, masses, softening, dt, n_steps):
|
|
G = 1.0
|
|
acc = _leapfrog_compute_accelerations_tf(pos, masses, softening, G)
|
|
|
|
vel = vel + 0.5 * dt * acc
|
|
pos = pos + dt * vel
|
|
vel = vel + 0.5 * dt * acc
|
|
|
|
return i + 1, pos, vel, masses, softening, dt, n_steps
|
|
|
|
|
|
def _leapfrog_step_cond_tf(i, _pos, _vel, _masses, _softening, _dt, n_steps):
|
|
return i < n_steps
|
|
|
|
|
|
def leapfrog_integration_tf(
|
|
positions,
|
|
velocities,
|
|
masses,
|
|
dt: float,
|
|
n_steps: int,
|
|
softening: float = 0.01
|
|
):
|
|
dt = tf.constant(dt, dtype=positions.dtype)
|
|
softening = tf.constant(softening, dtype=positions.dtype)
|
|
|
|
_, final_pos, final_vel, _, _, _, _ = tf.while_loop(
|
|
_leapfrog_step_cond_tf,
|
|
_leapfrog_step_body_tf,
|
|
[0, positions, velocities, masses, softening, dt, n_steps]
|
|
)
|
|
|
|
return final_pos, final_vel
|
|
|
|
|
|
def _lis_inner_body_tf(j, dp_inner, arr, i):
|
|
condition = tf.logical_and(arr[j] < arr[i], dp_inner[j] + 1 > dp_inner[i])
|
|
new_val = tf.where(condition, dp_inner[j] + 1, dp_inner[i])
|
|
indices = tf.reshape(i, [1, 1])
|
|
updates = tf.reshape(new_val, [1])
|
|
dp_updated = tf.tensor_scatter_nd_update(dp_inner, indices, updates)
|
|
return j + 1, dp_updated, arr, i
|
|
|
|
|
|
def _lis_inner_cond_tf(j, _dp_inner, _arr, i):
|
|
return j < i
|
|
|
|
|
|
def _lis_outer_body_tf(i, dp, arr, n):
|
|
_, dp, _, _ = tf.while_loop(
|
|
_lis_inner_cond_tf,
|
|
_lis_inner_body_tf,
|
|
[0, dp, arr, i]
|
|
)
|
|
return i + 1, dp, arr, n
|
|
|
|
|
|
def _lis_outer_cond_tf(i, _dp, _arr, n):
|
|
return i < n
|
|
|
|
|
|
def longest_increasing_subsequence_length_tf(arr):
|
|
n = tf.shape(arr)[0]
|
|
|
|
if n == 0:
|
|
return 0
|
|
|
|
dp = tf.ones(n, dtype=tf.int32)
|
|
|
|
_, dp, _, _ = tf.while_loop(
|
|
_lis_outer_cond_tf,
|
|
_lis_outer_body_tf,
|
|
[1, dp, arr, n]
|
|
)
|
|
|
|
return int(tf.reduce_max(dp))
|