diff --git a/Makefile b/Makefile index 6a41619..9a9e927 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ # GLOBALS # ################################################################################# -PROJECT_NAME = jaxman +PROJECT_NAME = kaxman PYTHON_VERSION = 3.10 PYTHON_INTERPRETER = python diff --git a/README.md b/README.md index 01a5c84..5362dcf 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,22 @@ # About the project -jaxman is yet another library that implements the [Kalman filter](https://en.wikipedia.org/wiki/Kalman_filter). This library is mainly intended as a self-study project in which to learn to [JAX](https://github.com/google/jax). The implementation is inspired by both [pykalman](https://pykalman.github.io/) and [simdkalman](https://github.com/oseiskar/simdkalman), and like simdkalman implements support for vectorized inference across multiple timeseries. As it's implemented in JAX and utilizes its JIT capabilities, we achieve blazing fast inference for 1,000s (sometimes even 100,000s) of parallel timeseries. +kaxman is yet another library that implements the [Kalman filter](https://en.wikipedia.org/wiki/Kalman_filter). The +library is built on top of JAX and is designed to be fast and efficient. The library is still in its early stages and +is not yet feature complete. Some of the features include: +- JIT:able Kalman filter class. +- Support for fully/partially missing observations via inflation of variance. +- Support for time-varying state transition and observation matrices. +- Support for time-varying process and observation noise covariance matrices. +- Support for noise transform, e.g. having the same noise for multiple states. +- Rauch-Tung-Striebel smoother. + # Getting started -Follow the below instructions in order to get started with jaxman. +Follow the below instructions in order to get started with kaxman. ## Installation -The library is currently not available on pypi, and there are currently no plans on releasing it there, so install it via +The library is currently not available on pypi, so install it via ``` -https://github.com/tingiskhan/jaxman +https://github.com/tingiskhan/kaxman ``` # Usage diff --git a/jaxman/__init__.py b/kaxman/__init__.py similarity index 66% rename from jaxman/__init__.py rename to kaxman/__init__.py index 107ec21..13516ea 100644 --- a/jaxman/__init__.py +++ b/kaxman/__init__.py @@ -1,3 +1,3 @@ from .kalman_filter import * # noqa: F403 -__version__ = "0.1.0" +__version__ = "0.2.0" diff --git a/jaxman/kalman_filter.py b/kaxman/kalman_filter.py similarity index 92% rename from jaxman/kalman_filter.py rename to kaxman/kalman_filter.py index 2294b03..402e3d2 100644 --- a/jaxman/kalman_filter.py +++ b/kaxman/kalman_filter.py @@ -10,23 +10,21 @@ from .results import FilterResult, SmoothingResult -def _inflate_missing(non_valid_mask: jnp.ndarray, r: jnp.ndarray, missing_value: float = 1e12) -> jnp.ndarray: +def _inflate_missing(non_valid_mask: jnp.ndarray, r: jnp.ndarray, inflation: float = 1e12) -> jnp.ndarray: """ - Masks missing dimensions by zeroing corresponding rows of H and inflating the diagonal of R. + Masks missing dimensions by inflating the diagonal of R. Args: non_valid_mask: Boolean mask of shape (obs_dim,) indicating missing dimensions. r: Observation covariance matrix of shape (obs_dim, obs_dim). - missing_value: Large scalar to add on the diagonal for missing dimensions. + inflation: Large scalar to add on the diagonal for missing dimensions. Returns: A tuple of: - - obs_masked: Same shape as obs, with missing entries replaced by 0.0. - - H_masked: Same shape as H, rows zeroed out for missing dimensions. - r_masked: Same shape as R, diagonal entries inflated for missing dimensions. """ - diag_inflation = jnp.where(non_valid_mask, missing_value, 0.0) + diag_inflation = jnp.where(non_valid_mask, inflation, 0.0) r_masked = r + jnp.diag(diag_inflation) return r_masked @@ -171,20 +169,14 @@ def _get_observation_offset(self, t: int) -> jnp.ndarray: return self.observation_offset - def _get_noise_transform(self, t: int) -> jnp.ndarray: - if callable(self.noise_transform): - return self.noise_transform(t) - - return self.noise_transform - def _predict(self, mean: jnp.ndarray, cov: jnp.ndarray, t: int) -> Tuple[jnp.ndarray, jnp.ndarray]: - F_t = self._get_transition_matrix(t, mean) - Q_t = self._get_transition_cov(t) + f_t = self._get_transition_matrix(t, mean) + q_t = self._get_transition_cov(t) b_t = self._get_transition_offset(t) - G_t = self._get_noise_transform(t) + g_t = self.noise_transform - mean_pred = F_t @ mean + b_t - cov_pred = F_t @ cov @ F_t.T + G_t @ Q_t @ G_t.T + mean_pred = f_t @ mean + b_t + cov_pred = f_t @ cov @ f_t.T + g_t @ q_t @ g_t.T return mean_pred, cov_pred @@ -219,7 +211,7 @@ def scan_fn(carry, obs_t): nan_mask = jnp.isnan(obs_t) # TODO: need to verify this... - r_t = _inflate_missing(nan_mask, r_t, missing_value=self.variance_inflation) + r_t = _inflate_missing(nan_mask, r_t, inflation=self.variance_inflation) y_pred_mean = h_t @ x_pred_mean + d_t y_pred_cov = h_t @ x_pred_cov @ h_t.T + r_t @@ -312,7 +304,7 @@ def sample_step(carry, _): f_t = self._get_transition_matrix(t, x_prev) q_t = self._get_transition_cov(t) b_t = self._get_transition_offset(t) - g_t = self._get_noise_transform(t) + g_t = self.noise_transform h_t = self._get_observation_matrix(t) r_t = self._get_observation_cov(t) d_t = self._get_observation_offset(t) @@ -338,3 +330,6 @@ def init_state(key): _, (xs, ys) = lax.scan(sample_step, init_carry, None, length=num_timesteps) return xs, ys + + +__all__ = ["KalmanFilter"] diff --git a/jaxman/results.py b/kaxman/results.py similarity index 100% rename from jaxman/results.py rename to kaxman/results.py diff --git a/pyproject.toml b/pyproject.toml index ef29388..322f6b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ test = [ [tool.setuptools.packages.find] -include = ["jaxman*"] +include = ["kaxman*"] [tool.black] line-length = 120 @@ -58,7 +58,7 @@ exclude = ''' line-length = 120 [tool.bumpver] -current_version = "0.1.0" +current_version = "0.2.0" version_pattern = "MAJOR.MINOR.PATCH" commit_message = "bump version {old_version} -> {new_version}" commit = true @@ -72,12 +72,12 @@ tag_scope = "default" 'current_version = "{version}"', ] -"jaxman/__init__.py" = [ +"kaxman/__init__.py" = [ '__version__ = "{version}"' ] [tool.setuptools.dynamic] -version = {attr = "jaxman.__version__"} +version = {attr = "kaxman.__version__"} [tool.pytest.ini_options] pythonpath = ["."] \ No newline at end of file diff --git a/static/filtering.jpg b/static/filtering.jpg deleted file mode 100644 index b1b7369..0000000 Binary files a/static/filtering.jpg and /dev/null differ diff --git a/tests/test_kalman.py b/tests/test_kalman.py index 487d8ad..98b2596 100644 --- a/tests/test_kalman.py +++ b/tests/test_kalman.py @@ -4,7 +4,7 @@ import numpy as np from numpy.testing import assert_allclose from pykalman import KalmanFilter as PyKalman -from jaxman import KalmanFilter +from kaxman import KalmanFilter @pytest.fixture diff --git a/tests/test_numpyro.py b/tests/test_numpyro.py index 86de424..1695d4e 100644 --- a/tests/test_numpyro.py +++ b/tests/test_numpyro.py @@ -4,7 +4,7 @@ import numpyro import numpyro.distributions as dist from numpyro.infer import MCMC, NUTS -from jaxman import KalmanFilter +from kaxman import KalmanFilter @pytest.fixture