tridiax implements solvers for tridiagonal systems in JAX. All solvers support CPU and GPU, are compatible with jit compilation, and can be differentiated with grad.
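In general, the Thomas algorithm is faster on CPU, whereas the divide-and-conquer algorithm and Stone's algorithm are faster on GPU. The Thomas algorithm processes the rows one after another, while the other two expose work that can run in parallel.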
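To make the CPU/GPU trade-off concrete, here is a minimal NumPy sketch of the textbook Thomas algorithm (forward elimination followed by back substitution). It is not tridiax's implementation. `thomas_reference` is a hypothetical name, and the sketch assumes that `lower[i]` is the entry below the diagonal in row `i + 1` and `upper[i]` the entry above the diagonal in row `i`, matching the `dim - 1` lengths used in the examples. Both loops carry a dependency from one row to the next, so the algorithm needs O(n) sequential steps.

```python
import numpy as np


def thomas_reference(lower, diag, upper, rhs):
    """Solve a tridiagonal system with the Thomas algorithm (NumPy sketch).

    Assumed convention: lower[i] = A[i + 1, i], upper[i] = A[i, i + 1],
    so lower and upper have length n - 1 (n >= 2). No pivoting is done,
    so the sketch assumes a well-behaved (e.g. diagonally dominant) matrix.
    """
    n = diag.shape[0]
    c = np.zeros(n - 1)  # modified super-diagonal
    d = np.zeros(n)      # modified right-hand side
    c[0] = upper[0] / diag[0]
    d[0] = rhs[0] / diag[0]
    # Forward sweep: row i needs the result of row i - 1, so this is serial.
    for i in range(1, n):
        denom = diag[i] - lower[i - 1] * c[i - 1]
        if i < n - 1:
            c[i] = upper[i] / denom
        d[i] = (rhs[i] - lower[i - 1] * d[i - 1]) / denom
    # Back substitution, again strictly sequential.
    x = np.zeros(n)
    x[-1] = d[-1]
    for i in range(n - 2, -1, -1):
        x[i] = d[i] - c[i] * x[i + 1]
    return x
```

For a diagonally dominant system, the result agrees with a dense `np.linalg.solve` up to rounding.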
Currently, the divide_conquer solver only supports systems whose dimension is a power of 2.
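One way to use the divide-and-conquer solver for other sizes is to pad the system. The sketch below uses a hypothetical helper, `pad_to_power_of_two`, which is not part of tridiax. It appends identity rows that are decoupled from the original equations by zero off-diagonal entries, so the first `n` entries of the padded solution solve the original system. It uses the same assumed `lower`/`upper` convention as above.

```python
import numpy as np


def pad_to_power_of_two(lower, diag, upper, rhs):
    """Pad a tridiagonal system to the next power-of-2 size (hypothetical helper).

    Assumed convention: lower[i] = A[i + 1, i], upper[i] = A[i, i + 1].
    The added rows form an identity block with zero coupling to the
    original rows, so they do not change the original solution.
    """
    n = diag.shape[0]
    m = 1 << (n - 1).bit_length()  # smallest power of 2 that is >= n
    pad = m - n
    diag_p = np.concatenate([diag, np.ones(pad)])    # identity on the new rows
    lower_p = np.concatenate([lower, np.zeros(pad)])  # no coupling to new rows
    upper_p = np.concatenate([upper, np.zeros(pad)])
    rhs_p = np.concatenate([rhs, np.zeros(pad)])
    return lower_p, diag_p, upper_p, rhs_p
```

The padded arrays can then be converted with `jnp.asarray` and passed to `divide_conquer_solve`, keeping only the first `n` entries of the result. In the worst case (a size just above a power of 2), padding roughly doubles the problem size.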
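import jax.numpy as jnp
import numpy as np
from tridiax import thomas_solve, divide_conquer_solve, stone_solve
dim = 1024
# Main diagonal (dim entries), super- and sub-diagonal (dim - 1 entries each)
diag = jnp.asarray(np.random.randn(dim))
upper = jnp.asarray(np.random.randn(dim - 1))
lower = jnp.asarray(np.random.randn(dim - 1))
# Right-hand side of the system
solve = jnp.asarray(np.random.randn(dim))
solution = thomas_solve(lower, diag, upper, solve)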
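A tridiagonal matrix-vector product costs only O(n), so checking a solution through its residual is cheap. The example above draws the diagonal from a standard normal distribution, so the matrix is not guaranteed to be diagonally dominant. Solvers that do not pivot can lose accuracy on such systems, and a residual check is a quick safeguard. The helper `tridiag_matvec` is a hypothetical NumPy sketch, not part of tridiax, and again assumes `lower[i] = A[i + 1, i]` and `upper[i] = A[i, i + 1]`.

```python
import numpy as np


def tridiag_matvec(lower, diag, upper, x):
    """Compute A @ x for a tridiagonal A without forming the dense matrix.

    Assumed convention: lower[i] = A[i + 1, i], upper[i] = A[i, i + 1].
    Accepts NumPy or JAX arrays and returns a NumPy array.
    """
    lower, diag, upper, x = (np.asarray(v) for v in (lower, diag, upper, x))
    y = diag * x                 # main-diagonal contribution (new array)
    y[:-1] += upper * x[1:]      # row i picks up A[i, i + 1] * x[i + 1]
    y[1:] += lower * x[:-1]      # row i picks up A[i, i - 1] * x[i - 1]
    return y
```

With the names from the example, `np.max(np.abs(tridiag_matvec(lower, diag, upper, solution) - np.asarray(solve)))` should be small relative to the magnitude of `solve` when the solve was accurate.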
If many systems of the same size are solved with the divide-and-conquer algorithm, it helps to precompute the reordering indices once and reuse them:
import jax.numpy as jnp
import numpy as np
from tridiax import divide_conquer_solve, divide_conquer_index
dim = 1024
diag = jnp.asarray(np.random.randn(dim))
upper = jnp.asarray(np.random.randn(dim - 1))
lower = jnp.asarray(np.random.randn(dim - 1))
solve = jnp.asarray(np.random.randn(dim))
# Depends only on dim, so it can be computed once and reused for every system of this size
indexing = divide_conquer_index(dim)
solution = divide_conquer_solve(lower, diag, upper, solve, indexing=indexing)
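This README does not describe the internals of `divide_conquer_solve`. One common divide-and-conquer scheme for tridiagonal systems is cyclic reduction. Each level eliminates every other unknown and halves the system, and every row at a level can be updated independently, which suits GPUs. Halving cleanly at every level is one reason such schemes are often restricted to power-of-2 sizes. The even/odd split depends only on the size, so it is the kind of reordering that can be precomputed. Whether tridiax uses exactly this scheme, or what `divide_conquer_index` returns, is an assumption. The NumPy sketch below (`cyclic_reduction_solve` is a hypothetical name) uses the same assumed `lower`/`upper` convention as above and does not pivot.

```python
import numpy as np


def _cyclic_reduction(a, b, c, d):
    """Recursive cyclic reduction on full-length bands.

    a[i] = A[i, i - 1] (a[0] == 0), b[i] = A[i, i], c[i] = A[i, i + 1]
    (c[-1] == 0). len(b) must be a power of 2.
    """
    n = b.shape[0]
    if n == 1:
        return d / b
    # Append a dummy decoupled row so the last odd row has an "upper" neighbor.
    a_ext, b_ext = np.append(a, 0.0), np.append(b, 1.0)
    c_ext, d_ext = np.append(c, 0.0), np.append(d, 0.0)
    # Odd rows are kept; their even neighbors i - 1 and i + 1 are eliminated.
    a_o, b_o, c_o, d_o = a[1::2], b[1::2], c[1::2], d[1::2]
    a_lo, b_lo, c_lo, d_lo = a[0::2], b[0::2], c[0::2], d[0::2]
    a_hi, b_hi, c_hi, d_hi = a_ext[2::2], b_ext[2::2], c_ext[2::2], d_ext[2::2]
    alpha = -a_o / b_lo  # multiplier for row i - 1
    gamma = -c_o / b_hi  # multiplier for row i + 1
    # Reduced tridiagonal system over the odd unknowns (half the size).
    x_odd = _cyclic_reduction(
        alpha * a_lo,
        b_o + alpha * c_lo + gamma * a_hi,
        gamma * c_hi,
        d_o + alpha * d_lo + gamma * d_hi,
    )
    # Recover the even unknowns from their own rows; these are independent.
    x_prev = np.concatenate([[0.0], x_odd[:-1]])
    x_even = (d[0::2] - a[0::2] * x_prev - c[0::2] * x_odd) / b[0::2]
    x = np.empty(n)
    x[0::2] = x_even
    x[1::2] = x_odd
    return x


def cyclic_reduction_solve(lower, diag, upper, rhs):
    """Hypothetical wrapper using lower[i] = A[i + 1, i], upper[i] = A[i, i + 1]."""
    n = diag.shape[0]
    if n == 0 or n & (n - 1):
        raise ValueError("cyclic_reduction_solve needs a power-of-2 size")
    a = np.concatenate([[0.0], lower])  # pad sub-diagonal to length n
    c = np.concatenate([upper, [0.0]])  # pad super-diagonal to length n
    return _cyclic_reduction(a, diag, c, rhs)
```

The recursion depth is log2(n), for example 10 levels for `dim = 1024`.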
tridiax is available on PyPI:
pip install tridiax
This installs tridiax with CPU support. For GPU support, follow the instructions in the JAX GitHub repository to install JAX with GPU support (in addition to installing tridiax). For example, for NVIDIA GPUs, run:
pip install -U "jax[cuda12]"