tridiax

tridiax implements solvers for tridiagonal linear systems in JAX. All solvers run on CPU and GPU, are compatible with `jax.jit`, and can be differentiated with `jax.grad`.
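To check results from any of the solvers, a slow dense reference can be useful. The sketch below assumes that `lower[i]` is the entry below the diagonal in row `i + 1` and `upper[i]` is the entry above the diagonal in row `i`. This is the usual convention, but it is an assumption here. The sketch also shows the `jax.jit` / `jax.grad` pattern that the solvers are described as supporting. `dense_tridiagonal`, `reference_solve` and `loss` are illustrative names, not part of tridiax.

```python
import jax
import jax.numpy as jnp


def dense_tridiagonal(lower, diag, upper):
    """Assemble the dense n x n matrix (assumed convention:
    lower[i] = A[i + 1, i], upper[i] = A[i, i + 1])."""
    return jnp.diag(diag) + jnp.diag(lower, k=-1) + jnp.diag(upper, k=1)


@jax.jit
def reference_solve(lower, diag, upper, rhs):
    # O(n^3) dense solve: fine for checking small systems, too slow for real use
    return jnp.linalg.solve(dense_tridiagonal(lower, diag, upper), rhs)


def loss(diag, lower, upper, rhs):
    # Any scalar function of the solution can be differentiated with jax.grad
    return jnp.sum(reference_solve(lower, diag, upper, rhs) ** 2)


diag = jnp.array([4.0, 5.0, 6.0, 7.0])
lower = jnp.array([1.0, -1.0, 2.0])
upper = jnp.array([0.5, 1.0, -2.0])
rhs = jnp.array([1.0, 2.0, 3.0, 4.0])

x = reference_solve(lower, diag, upper, rhs)
g = jax.grad(loss)(diag, lower, upper, rhs)  # d loss / d diag, shape (4,)
```

Per the description above, the same `jax.jit(...)` and `jax.grad(...)` wrapping should apply when `reference_solve` is replaced by one of the tridiax solvers.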

Implemented solvers

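tridiax provides three solvers: the Thomas algorithm (`thomas_solve`), a divide and conquer algorithm (`divide_conquer_solve`), and Stone's algorithm (`stone_solve`). Generally, the Thomas algorithm is faster on CPU, whereas the divide and conquer algorithm and Stone's algorithm are faster on GPU.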
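The CPU/GPU split follows from how the algorithms are structured. The Thomas algorithm is Gaussian elimination specialized to tridiagonal matrices. It performs an O(n) forward sweep followed by back substitution, and every step depends on the previous one, so the work cannot be spread across rows. Below is a minimal JAX sketch using `jax.lax.scan`, under the same assumed `lower`/`upper` convention. `thomas_sketch` is a hypothetical name. It does no pivoting, so it assumes a well-behaved (e.g. diagonally dominant) matrix, and it is not tridiax's own implementation.

```python
import jax
import jax.numpy as jnp
from jax import lax


def thomas_sketch(lower, diag, upper, rhs):
    """Thomas algorithm without pivoting. Assumes lower[i] = A[i + 1, i],
    upper[i] = A[i, i + 1], and that all inputs share one float dtype."""
    dtype = diag.dtype
    # Pad the off-diagonals so every row has a sub- and super-diagonal entry.
    a = jnp.concatenate([jnp.zeros(1, dtype), lower])  # a[i] = A[i, i - 1]
    c = jnp.concatenate([upper, jnp.zeros(1, dtype)])  # c[i] = A[i, i + 1]

    # Forward sweep: eliminate the sub-diagonal one row at a time.
    def forward(carry, row):
        c_prev, d_prev = carry
        a_i, b_i, c_i, d_i = row
        denom = b_i - a_i * c_prev
        c_new = c_i / denom
        d_new = (d_i - a_i * d_prev) / denom
        return (c_new, d_new), (c_new, d_new)

    init = (jnp.zeros((), dtype), jnp.zeros((), dtype))
    _, (c_star, d_star) = lax.scan(forward, init, (a, diag, c, rhs))

    # Back substitution, last row first: x[i] = d*[i] - c*[i] * x[i + 1].
    def backward(x_next, row):
        c_i, d_i = row
        x_i = d_i - c_i * x_next
        return x_i, x_i

    # reverse=True iterates from the end but returns x in the original order.
    _, x = lax.scan(backward, jnp.zeros((), dtype), (c_star, d_star), reverse=True)
    return x


diag = jnp.array([4.0, 5.0, 6.0, 7.0])
lower = jnp.array([1.0, -1.0, 2.0])
upper = jnp.array([0.5, 1.0, -2.0])
rhs = jnp.array([1.0, 2.0, 3.0, 4.0])
x = jax.jit(thomas_sketch)(lower, diag, upper, rhs)
```

Each `scan` step has to wait for the previous one, so a GPU gains little from this loop. Divide and conquer schemes and Stone's algorithm instead reorganize the elimination into roughly log2(n) stages of independent work, which suits massively parallel hardware.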

Known limitations

Currently, `divide_conquer_solve` only supports systems whose size is a power of 2.
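One common workaround for a power-of-2 restriction, which this README does not describe as a tridiax feature, is to pad the system. Append rows with 1 on the diagonal, 0 on both off-diagonals, and 0 in the right-hand side. These extra unknowns are decoupled from the original ones, so the first n entries of the padded solution solve the original system. `pad_to_power_of_two` is a hypothetical helper that uses the same assumed `lower`/`upper` convention as above.

```python
import jax.numpy as jnp


def pad_to_power_of_two(lower, diag, upper, rhs):
    """Hypothetical helper: embed an n x n tridiagonal system in the smallest
    m x m system with m a power of 2 by appending decoupled identity rows."""
    n = diag.shape[0]
    m = 1 << (n - 1).bit_length()  # next power of two >= n (for n >= 1)
    pad = m - n
    diag_p = jnp.concatenate([diag, jnp.ones(pad, diag.dtype)])  # 1 on the diagonal
    lower_p = jnp.concatenate([lower, jnp.zeros(pad, lower.dtype)])  # no coupling
    upper_p = jnp.concatenate([upper, jnp.zeros(pad, upper.dtype)])  # no coupling
    rhs_p = jnp.concatenate([rhs, jnp.zeros(pad, rhs.dtype)])  # extra unknowns = 0
    return lower_p, diag_p, upper_p, rhs_p


diag = jnp.array([4.0, 5.0, 6.0, 7.0, 8.0])  # n = 5, not a power of 2
lower = jnp.array([1.0, -1.0, 2.0, 0.5])
upper = jnp.array([0.5, 1.0, -2.0, 1.0])
rhs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
lower_p, diag_p, upper_p, rhs_p = pad_to_power_of_two(lower, diag, upper, rhs)
```

With tridiax, the padded arrays would then be passed to `divide_conquer_solve`, and the result sliced back to its first n entries. This approach has not been checked against the library itself.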

Usage

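```python
import jax.numpy as jnp
import numpy as np
from tridiax import thomas_solve, divide_conquer_solve, stone_solve

dim = 1024
diag = jnp.asarray(np.random.randn(dim))       # main diagonal, length dim
upper = jnp.asarray(np.random.randn(dim - 1))  # super-diagonal, length dim - 1
lower = jnp.asarray(np.random.randn(dim - 1))  # sub-diagonal, length dim - 1
solve = jnp.asarray(np.random.randn(dim))      # right-hand side, length dim
solution = thomas_solve(lower, diag, upper, solve)
```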
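A quick way to sanity-check a solution without building the dense matrix is to compute the residual A x - b in O(n). `tridiagonal_matvec` below is a hypothetical helper. It is not part of tridiax, and it assumes `lower[i] = A[i + 1, i]` and `upper[i] = A[i, i + 1]`.

```python
import jax.numpy as jnp


def tridiagonal_matvec(lower, diag, upper, x):
    """Compute A @ x in O(n) without materializing A."""
    y = diag * x
    y = y.at[1:].add(lower * x[:-1])  # row i gets A[i, i - 1] * x[i - 1]
    y = y.at[:-1].add(upper * x[1:])  # row i gets A[i, i + 1] * x[i + 1]
    return y


lower = jnp.array([1.0, 2.0])
diag = jnp.array([3.0, 4.0, 5.0])
upper = jnp.array([6.0, 7.0])
x = jnp.array([1.0, 1.0, 1.0])
y = tridiagonal_matvec(lower, diag, upper, x)  # → [9., 12., 7.]
```

With the variables from the usage example, `jnp.max(jnp.abs(tridiagonal_matvec(lower, diag, upper, solution) - solve))` should be small. The random matrices there are not guaranteed to be well conditioned, so an unlucky draw can give a larger residual.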

If many systems of the same size are solved with the divide and conquer algorithm, it helps to precompute the reordering indices once with `divide_conquer_index` and reuse them across calls:

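```python
import jax.numpy as jnp
import numpy as np
from tridiax import divide_conquer_solve, divide_conquer_index

dim = 1024  # divide_conquer_solve requires a power of 2
diag = jnp.asarray(np.random.randn(dim))
upper = jnp.asarray(np.random.randn(dim - 1))
lower = jnp.asarray(np.random.randn(dim - 1))
solve = jnp.asarray(np.random.randn(dim))

# Compute the reordering indices once and reuse them for every system of size dim
indexing = divide_conquer_index(dim)
solution = divide_conquer_solve(lower, diag, upper, solve, indexing=indexing)
```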

Installation

tridiax is available on PyPI:

```bash
pip install tridiax
```

This installs tridiax with CPU support only. For GPU support, additionally install a GPU-enabled JAX by following the instructions in the JAX GitHub repository. For example, for NVIDIA GPUs, run:

```bash
pip install -U "jax[cuda12]"
```
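After installing a GPU-enabled JAX, you can confirm which backend JAX (and therefore tridiax, which runs on JAX) will use by default. Both calls are part of JAX's public API.

```python
import jax

# Name of the default backend, e.g. "cpu" or "gpu"
print(jax.default_backend())
# Devices JAX can see; CUDA devices should appear after installing jax[cuda12]
print(jax.devices())
```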