almost ready

This commit is contained in:
aseembits93 2026-01-15 18:58:13 -08:00
parent 3a0e41861c
commit 4d28c1779f
5 changed files with 20 additions and 40 deletions

View file

@ -10,48 +10,28 @@ from jax import lax
def tridiagonal_solve(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray:
    """Solve a tridiagonal linear system with the Thomas algorithm.

    Row ``i`` of the system reads
    ``a[i - 1] * x[i - 1] + b[i] * x[i] + c[i] * x[i + 1] = d[i]``; the ``a``
    term is absent in the first row and the ``c`` term is absent in the last.

    Args:
        a: Sub-diagonal coefficients, read at indices ``0 .. n - 2``.
        b: Main-diagonal coefficients; ``len(b)`` defines the system size ``n``.
        c: Super-diagonal coefficients, read at indices ``0 .. n - 2``.
        d: Right-hand side, read at indices ``0 .. n - 1``.

    Returns:
        A newly allocated float64 array of length ``n`` holding the solution.
        The input arrays are only read, never written.

    Note:
        No pivoting is performed and a zero pivot is not detected, so the
        caller is responsible for passing a system for which the plain
        elimination is well defined (e.g. a diagonally dominant one).
    """
    n = len(b)
    x = np.zeros(n, dtype=np.float64)

    # Degenerate sizes: the general path below indexes c[0] and a[n - 2],
    # neither of which exists for systems with fewer than two rows.
    if n == 0:
        return x
    if n == 1:
        x[0] = d[0] / b[0]
        return x

    # Scratch buffers for the modified coefficients; inputs stay untouched.
    c_prime = np.zeros(n - 1, dtype=np.float64)
    d_prime = np.zeros(n, dtype=np.float64)

    # Forward sweep - sequential dependency: entry i depends on entry i - 1.
    c_prime[0] = c[0] / b[0]
    d_prime[0] = d[0] / b[0]
    for i in range(1, n - 1):
        denom = b[i] - a[i - 1] * c_prime[i - 1]
        c_prime[i] = c[i] / denom
        d_prime[i] = (d[i] - a[i - 1] * d_prime[i - 1]) / denom

    # Last row of the forward sweep: there is no super-diagonal entry, so
    # only d_prime is updated.
    denom = b[n - 1] - a[n - 2] * c_prime[n - 2]
    d_prime[n - 1] = (d[n - 1] - a[n - 2] * d_prime[n - 2]) / denom

    # Back substitution - sequential dependency: x[i] depends on x[i + 1].
    x[n - 1] = d_prime[n - 1]
    for i in range(n - 2, -1, -1):
        x[i] = d_prime[i] - c_prime[i] * x[i + 1]

    return x

View file

@ -10,7 +10,7 @@ import pytest
jax = pytest.importorskip("jax")
import jax.numpy as jnp
from code_to_optimize.sample_jit_code import (
from code_to_optimize.sample_code import (
leapfrog_integration_jax,
longest_increasing_subsequence_length_jax,
tridiagonal_solve_jax,

View file

@ -1,7 +1,7 @@
import numpy as np
import pytest
from code_to_optimize.sample_jit_code import (
from code_to_optimize.sample_code import (
leapfrog_integration,
longest_increasing_subsequence_length,
tridiagonal_solve,

View file

@ -9,7 +9,7 @@ import pytest
tf = pytest.importorskip("tensorflow")
from code_to_optimize.sample_jit_code import (
from code_to_optimize.sample_code import (
leapfrog_integration_tf,
longest_increasing_subsequence_length_tf,
tridiagonal_solve_tf,

View file

@ -9,7 +9,7 @@ import pytest
torch = pytest.importorskip("torch")
from code_to_optimize.sample_jit_code import (
from code_to_optimize.sample_code import (
leapfrog_integration_torch,
longest_increasing_subsequence_length_torch,
tridiagonal_solve_torch,