Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/fnal touches #5

Merged
merged 9 commits into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# GLOBALS #
#################################################################################

PROJECT_NAME = jaxman
PROJECT_NAME = kaxman
PYTHON_VERSION = 3.10
PYTHON_INTERPRETER = python

Expand Down
17 changes: 13 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
# About the project
jaxman is yet another library that implements the [Kalman filter](https://en.wikipedia.org/wiki/Kalman_filter). This library is mainly intended as a self-study project in which to learn to [JAX](https://github.com/google/jax). The implementation is inspired by both [pykalman](https://pykalman.github.io/) and [simdkalman](https://github.com/oseiskar/simdkalman), and like simdkalman implements support for vectorized inference across multiple timeseries. As it's implemented in JAX and utilizes its JIT capabilities, we achieve blazing fast inference for 1,000s (sometimes even 100,000s) of parallel timeseries.
kaxman is yet another library that implements the [Kalman filter](https://en.wikipedia.org/wiki/Kalman_filter). The
library is built on top of JAX and is designed to be fast and efficient. The library is still in its early stages and
is not yet feature complete. Some of the features include:
- JIT:able Kalman filter class.
- Support for fully/partially missing observations via inflation of variance.
- Support for time-varying state transition and observation matrices.
- Support for time-varying process and observation noise covariance matrices.
- Support for noise transform, e.g. having the same noise for multiple states.
- Rauch-Tung-Striebel smoother.


# Getting started
Follow the below instructions in order to get started with jaxman.
Follow the below instructions in order to get started with kaxman.

## Installation
The library is currently not available on pypi, and there are currently no plans on releasing it there, so install it via
The library is currently not available on pypi, so install it via
```
https://github.com/tingiskhan/jaxman
https://github.com/tingiskhan/kaxman
```

# Usage
Expand Down
2 changes: 1 addition & 1 deletion jaxman/__init__.py → kaxman/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .kalman_filter import * # noqa: F403

__version__ = "0.1.0"
__version__ = "0.2.0"
33 changes: 14 additions & 19 deletions jaxman/kalman_filter.py → kaxman/kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,21 @@
from .results import FilterResult, SmoothingResult


def _inflate_missing(non_valid_mask: jnp.ndarray, r: jnp.ndarray, missing_value: float = 1e12) -> jnp.ndarray:
def _inflate_missing(non_valid_mask: jnp.ndarray, r: jnp.ndarray, inflation: float = 1e12) -> jnp.ndarray:
"""
Masks missing dimensions by zeroing corresponding rows of H and inflating the diagonal of R.
Masks missing dimensions by inflating the diagonal of R.

Args:
non_valid_mask: Boolean mask of shape (obs_dim,) indicating missing dimensions.
r: Observation covariance matrix of shape (obs_dim, obs_dim).
missing_value: Large scalar to add on the diagonal for missing dimensions.
inflation: Large scalar to add on the diagonal for missing dimensions.

Returns:
A tuple of:
- obs_masked: Same shape as obs, with missing entries replaced by 0.0.
- H_masked: Same shape as H, rows zeroed out for missing dimensions.
- r_masked: Same shape as R, diagonal entries inflated for missing dimensions.
"""

diag_inflation = jnp.where(non_valid_mask, missing_value, 0.0)
diag_inflation = jnp.where(non_valid_mask, inflation, 0.0)
r_masked = r + jnp.diag(diag_inflation)

return r_masked
Expand Down Expand Up @@ -171,20 +169,14 @@ def _get_observation_offset(self, t: int) -> jnp.ndarray:

return self.observation_offset

def _get_noise_transform(self, t: int) -> jnp.ndarray:
if callable(self.noise_transform):
return self.noise_transform(t)

return self.noise_transform

def _predict(self, mean: jnp.ndarray, cov: jnp.ndarray, t: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
F_t = self._get_transition_matrix(t, mean)
Q_t = self._get_transition_cov(t)
f_t = self._get_transition_matrix(t, mean)
q_t = self._get_transition_cov(t)
b_t = self._get_transition_offset(t)
G_t = self._get_noise_transform(t)
g_t = self.noise_transform

mean_pred = F_t @ mean + b_t
cov_pred = F_t @ cov @ F_t.T + G_t @ Q_t @ G_t.T
mean_pred = f_t @ mean + b_t
cov_pred = f_t @ cov @ f_t.T + g_t @ q_t @ g_t.T

return mean_pred, cov_pred

Expand Down Expand Up @@ -219,7 +211,7 @@ def scan_fn(carry, obs_t):
nan_mask = jnp.isnan(obs_t)

# TODO: need to verify this...
r_t = _inflate_missing(nan_mask, r_t, missing_value=self.variance_inflation)
r_t = _inflate_missing(nan_mask, r_t, inflation=self.variance_inflation)

y_pred_mean = h_t @ x_pred_mean + d_t
y_pred_cov = h_t @ x_pred_cov @ h_t.T + r_t
Expand Down Expand Up @@ -312,7 +304,7 @@ def sample_step(carry, _):
f_t = self._get_transition_matrix(t, x_prev)
q_t = self._get_transition_cov(t)
b_t = self._get_transition_offset(t)
g_t = self._get_noise_transform(t)
g_t = self.noise_transform
h_t = self._get_observation_matrix(t)
r_t = self._get_observation_cov(t)
d_t = self._get_observation_offset(t)
Expand All @@ -338,3 +330,6 @@ def init_state(key):
_, (xs, ys) = lax.scan(sample_step, init_carry, None, length=num_timesteps)

return xs, ys


__all__ = ["KalmanFilter"]
File renamed without changes.
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ test = [


[tool.setuptools.packages.find]
include = ["jaxman*"]
include = ["kaxman*"]

[tool.black]
line-length = 120
Expand All @@ -58,7 +58,7 @@ exclude = '''
line-length = 120

[tool.bumpver]
current_version = "0.1.0"
current_version = "0.2.0"
version_pattern = "MAJOR.MINOR.PATCH"
commit_message = "bump version {old_version} -> {new_version}"
commit = true
Expand All @@ -72,12 +72,12 @@ tag_scope = "default"
'current_version = "{version}"',
]

"jaxman/__init__.py" = [
"kaxman/__init__.py" = [
'__version__ = "{version}"'
]

[tool.setuptools.dynamic]
version = {attr = "jaxman.__version__"}
version = {attr = "kaxman.__version__"}

[tool.pytest.ini_options]
pythonpath = ["."]
Binary file removed static/filtering.jpg
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/test_kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from numpy.testing import assert_allclose
from pykalman import KalmanFilter as PyKalman
from jaxman import KalmanFilter
from kaxman import KalmanFilter


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion tests/test_numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from jaxman import KalmanFilter
from kaxman import KalmanFilter


@pytest.fixture
Expand Down
Loading