mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
almost ready
This commit is contained in:
parent
3a0e41861c
commit
4d28c1779f
5 changed files with 20 additions and 40 deletions
|
|
@ -10,48 +10,28 @@ from jax import lax
|
|||
def tridiagonal_solve(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray:
|
||||
n = len(b)
|
||||
|
||||
c_prime = np.empty(n - 1, dtype=np.float64)
|
||||
d_prime = np.empty(n, dtype=np.float64)
|
||||
x = np.empty(n, dtype=np.float64)
|
||||
# Create working copies to avoid modifying input
|
||||
c_prime = np.zeros(n - 1, dtype=np.float64)
|
||||
d_prime = np.zeros(n, dtype=np.float64)
|
||||
x = np.zeros(n, dtype=np.float64)
|
||||
|
||||
# Alias arrays to local variables to avoid repeated attribute lookups
|
||||
a_arr = a
|
||||
b_arr = b
|
||||
c_arr = c
|
||||
d_arr = d
|
||||
cp = c_prime
|
||||
dp = d_prime
|
||||
x_arr = x
|
||||
|
||||
# First element
|
||||
prev_cprime = c_arr[0] / b_arr[0]
|
||||
cp[0] = prev_cprime
|
||||
prev_dprime = d_arr[0] / b_arr[0]
|
||||
dp[0] = prev_dprime
|
||||
|
||||
# Forward sweep (compute c_prime and d_prime)
|
||||
# Forward sweep - sequential dependency: c_prime[i] depends on c_prime[i-1]
|
||||
c_prime[0] = c[0] / b[0]
|
||||
d_prime[0] = d[0] / b[0]
|
||||
|
||||
for i in range(1, n - 1):
|
||||
ai_1 = a_arr[i - 1]
|
||||
denom = b_arr[i] - ai_1 * prev_cprime
|
||||
curr_cprime = c_arr[i] / denom
|
||||
curr_dprime = (d_arr[i] - ai_1 * prev_dprime) / denom
|
||||
cp[i] = curr_cprime
|
||||
dp[i] = curr_dprime
|
||||
prev_cprime = curr_cprime
|
||||
prev_dprime = curr_dprime
|
||||
denom = b[i] - a[i - 1] * c_prime[i - 1]
|
||||
c_prime[i] = c[i] / denom
|
||||
d_prime[i] = (d[i] - a[i - 1] * d_prime[i - 1]) / denom
|
||||
|
||||
# Last d_prime entry
|
||||
denom = b_arr[n - 1] - a_arr[n - 2] * prev_cprime
|
||||
dp[n - 1] = (d_arr[n - 1] - a_arr[n - 2] * prev_dprime) / denom
|
||||
# Last row of forward sweep
|
||||
denom = b[n - 1] - a[n - 2] * c_prime[n - 2]
|
||||
d_prime[n - 1] = (d[n - 1] - a[n - 2] * d_prime[n - 2]) / denom
|
||||
|
||||
# Back substitution using a scalar for the "next x" value
|
||||
prev_x = dp[n - 1]
|
||||
x_arr[n - 1] = prev_x
|
||||
# Back substitution - sequential dependency: x[i] depends on x[i+1]
|
||||
x[n - 1] = d_prime[n - 1]
|
||||
for i in range(n - 2, -1, -1):
|
||||
xi = dp[i] - cp[i] * prev_x
|
||||
x_arr[i] = xi
|
||||
prev_x = xi
|
||||
x[i] = d_prime[i] - c_prime[i] * x[i + 1]
|
||||
|
||||
return x
|
||||
|
||||
|
|
@ -10,7 +10,7 @@ import pytest
|
|||
jax = pytest.importorskip("jax")
|
||||
import jax.numpy as jnp
|
||||
|
||||
from code_to_optimize.sample_jit_code import (
|
||||
from code_to_optimize.sample_code import (
|
||||
leapfrog_integration_jax,
|
||||
longest_increasing_subsequence_length_jax,
|
||||
tridiagonal_solve_jax,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from code_to_optimize.sample_jit_code import (
|
||||
from code_to_optimize.sample_code import (
|
||||
leapfrog_integration,
|
||||
longest_increasing_subsequence_length,
|
||||
tridiagonal_solve,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import pytest
|
|||
|
||||
tf = pytest.importorskip("tensorflow")
|
||||
|
||||
from code_to_optimize.sample_jit_code import (
|
||||
from code_to_optimize.sample_code import (
|
||||
leapfrog_integration_tf,
|
||||
longest_increasing_subsequence_length_tf,
|
||||
tridiagonal_solve_tf,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import pytest
|
|||
|
||||
torch = pytest.importorskip("torch")
|
||||
|
||||
from code_to_optimize.sample_jit_code import (
|
||||
from code_to_optimize.sample_code import (
|
||||
leapfrog_integration_torch,
|
||||
longest_increasing_subsequence_length_torch,
|
||||
tridiagonal_solve_torch,
|
||||
|
|
|
|||
Loading…
Reference in a new issue