
jaxleyverse/tridiax


tridiax

tridiax implements solvers for tridiagonal systems in JAX. All solvers run on CPU and GPU, are compatible with jax.jit, and can be differentiated with jax.grad.

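Implemented solvers

tridiax provides three solvers: the Thomas algorithm (thomas_solve), a divide and conquer algorithm (divide_conquer_solve), and Stone's algorithm (stone_solve).

Generally, the Thomas algorithm is faster on CPU, whereas the divide and conquer algorithm and Stone's algorithm are faster on GPU.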
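To illustrate what the Thomas algorithm does, here is a minimal pure-JAX sketch built on jax.lax.scan. It is not tridiax's implementation, and the name thomas_sketch is hypothetical. It assumes the convention lower[i] = A[i+1, i] and upper[i] = A[i, i+1] and a size n >= 2. Because it performs no pivoting, it is only safe for well-conditioned systems such as diagonally dominant ones. The sketch also exercises the jit and grad compatibility mentioned above.

```python
import jax
import jax.numpy as jnp


def thomas_sketch(lower, diag, upper, rhs):
    """Solve a tridiagonal system with the Thomas algorithm (no pivoting).

    Assumed convention: lower[i] = A[i+1, i], upper[i] = A[i, i+1].
    """
    # Pad the super-diagonal so every row has an "upper" entry (the last one is unused).
    upper_p = jnp.concatenate([upper, jnp.zeros((1,), upper.dtype)])
    c0 = upper_p[0] / diag[0]
    d0 = rhs[0] / diag[0]

    def forward(carry, row):
        # Forward elimination: remove the sub-diagonal entry of row i.
        c_prev, d_prev = carry
        l, b, u, r = row
        denom = b - l * c_prev
        c = u / denom
        d = (r - l * d_prev) / denom
        return (c, d), (c, d)

    _, (c_rest, d_rest) = jax.lax.scan(
        forward, (c0, d0), (lower, diag[1:], upper_p[1:], rhs[1:])
    )
    c = jnp.concatenate([c0[None], c_rest])
    d = jnp.concatenate([d0[None], d_rest])

    def backward(x_next, row):
        # Back substitution: x_i = d_i - c_i * x_{i+1}.
        c_i, d_i = row
        x_i = d_i - c_i * x_next
        return x_i, x_i

    # reverse=True walks from the last row upwards; outputs stay in row order.
    _, x_head = jax.lax.scan(backward, d[-1], (c[:-1], d[:-1]), reverse=True)
    return jnp.concatenate([x_head, d[-1:]])


# A small, diagonally dominant example system.
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
n = 8
diag = 4.0 + jax.random.uniform(k1, (n,))
lower = jax.random.uniform(k2, (n - 1,))
upper = jax.random.uniform(k3, (n - 1,))
rhs = jax.random.normal(k4, (n,))

x = jax.jit(thomas_sketch)(lower, diag, upper, rhs)

# Gradient of a scalar loss with respect to the main diagonal.
loss = lambda d: jnp.sum(thomas_sketch(lower, d, upper, rhs) ** 2)
g = jax.grad(loss)(diag)
```

The forward and backward passes are both sequential scans over the rows. The algorithm does O(n) work but offers no parallelism across rows, which is in line with the observation that the other solvers tend to be faster on GPU.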

Known limitations

Currently, divide_conquer_solve only supports systems whose size (the number of unknowns) is a power of 2.
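One possible workaround is to embed the system in a larger one whose size is the next power of 2, padding it with decoupled identity rows. This is an assumption on our part, not something tridiax provides. The helper pad_to_power_of_two is hypothetical. It assumes the diagonal convention lower[i] = A[i+1, i] and upper[i] = A[i, i+1].

```python
import jax.numpy as jnp


def pad_to_power_of_two(lower, diag, upper, rhs):
    """Embed an n-by-n tridiagonal system in a 2**k-by-2**k one (hypothetical helper).

    The extra rows are decoupled identity rows (diag = 1, off-diagonals = 0,
    rhs = 0), so the first n entries of the padded solution equal the
    solution of the original system.
    """
    n = diag.shape[0]
    m = 1 << (n - 1).bit_length()  # smallest power of two >= n
    pad = m - n
    diag_p = jnp.concatenate([diag, jnp.ones((pad,), diag.dtype)])
    # The zero entries at index n-1 decouple the last original row from the padding.
    lower_p = jnp.concatenate([lower, jnp.zeros((pad,), lower.dtype)])
    upper_p = jnp.concatenate([upper, jnp.zeros((pad,), upper.dtype)])
    rhs_p = jnp.concatenate([rhs, jnp.zeros((pad,), rhs.dtype)])
    return lower_p, diag_p, upper_p, rhs_p, n
```

The padded arrays can then be passed to divide_conquer_solve in the same argument order as in the usage example, keeping only the first n entries of the result. The padded system is always less than twice the original size.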

Usage

import numpy as np
import jax.numpy as jnp
from tridiax import thomas_solve, divide_conquer_solve, stone_solve

dim = 1024
diag = jnp.asarray(np.random.randn(dim))        # main diagonal
upper = jnp.asarray(np.random.randn(dim - 1))   # super-diagonal
lower = jnp.asarray(np.random.randn(dim - 1))   # sub-diagonal
solve = jnp.asarray(np.random.randn(dim))       # right-hand side
solution = thomas_solve(lower, diag, upper, solve)
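A cheap O(n) way to sanity-check a computed solution is the relative residual ||A x - b|| / ||b||, computed directly from the three diagonals without forming the dense matrix. The helper names are hypothetical. The sketch assumes lower[i] = A[i+1, i] and upper[i] = A[i, i+1], so confirm this convention against tridiax before relying on it.

```python
import jax.numpy as jnp


def tridiagonal_matvec(lower, diag, upper, x):
    """Compute A @ x for a tridiagonal A given by its three diagonals.

    Assumed convention: lower[i] = A[i+1, i] and upper[i] = A[i, i+1].
    """
    y = diag * x
    y = y.at[:-1].add(upper * x[1:])   # super-diagonal contribution
    y = y.at[1:].add(lower * x[:-1])   # sub-diagonal contribution
    return y


def relative_residual(lower, diag, upper, rhs, x):
    """||A x - rhs|| / ||rhs||, a cheap check of a computed solution."""
    r = tridiagonal_matvec(lower, diag, upper, x) - rhs
    return jnp.linalg.norm(r) / jnp.linalg.norm(rhs)
```

For the example above, this check would be relative_residual(lower, diag, upper, solve, solution). The random matrices in that example are not diagonally dominant and can be ill-conditioned, so a somewhat larger residual there does not necessarily indicate a solver bug.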

If many systems of the same size are solved with the divide and conquer algorithm, it helps to precompute the reordering indices once and pass them to every call:

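import numpy as np
import jax.numpy as jnp
from tridiax import divide_conquer_solve, divide_conquer_index

dim = 1024  # must be a power of 2 for the divide and conquer solver
diag = jnp.asarray(np.random.randn(dim))        # main diagonal
upper = jnp.asarray(np.random.randn(dim - 1))   # super-diagonal
lower = jnp.asarray(np.random.randn(dim - 1))   # sub-diagonal
solve = jnp.asarray(np.random.randn(dim))       # right-hand side

# Compute the reordering indices once per system size and reuse them.
indexing = divide_conquer_index(dim)
solution = divide_conquer_solve(lower, diag, upper, solve, indexing=indexing)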
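Precomputing pays off because the reordering depends only on the system size. The sketch below illustrates that pattern with a hypothetical odd-even permutation, in the spirit of odd-even reduction. It is not the index format returned by divide_conquer_index. The index arrays are built once with NumPy and passed as arguments to jitted functions, mirroring how indexing is passed to divide_conquer_solve.

```python
import numpy as np
import jax
import jax.numpy as jnp


def odd_even_permutation(dim):
    """Hypothetical size-dependent index arrays: even positions first, then odd.

    Built once with NumPy, outside of any jit-traced code.
    """
    perm = np.concatenate([np.arange(0, dim, 2), np.arange(1, dim, 2)])
    inverse = np.argsort(perm)  # maps reordered positions back to the original ones
    return jnp.asarray(perm), jnp.asarray(inverse)


@jax.jit
def reorder(x, perm):
    # Gather with precomputed indices; no index construction inside traced code.
    return x[perm]


@jax.jit
def restore(y, inverse):
    # Undo the reordering with the precomputed inverse permutation.
    return y[inverse]


dim = 1024
perm, inverse = odd_even_permutation(dim)  # computed once per system size

# Many systems of the same size reuse the same index arrays.
systems = [jnp.arange(dim, dtype=jnp.float32) * s for s in (1.0, 2.0, 3.0)]
reordered = [reorder(b, perm) for b in systems]
```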

Installation

tridiax is available on PyPI:

pip install tridiax

This installs tridiax with CPU support only. For GPU support, additionally install a GPU-enabled build of JAX by following the instructions in the JAX GitHub repository. For example, for NVIDIA GPUs with CUDA 12, run

pip install -U "jax[cuda12]"
