Skip to content

Commit

Permalink
Feature/fnal touches (#5)
Browse files Browse the repository at this point in the history
* Move and fx docs

* Rename

* Fix all

* REname project

* Readme

* Remove irrelevant picture

* Fix

* Rename

* bump version 0.1.0 -> 0.2.0
  • Loading branch information
tingiskhan authored Jan 4, 2025
1 parent 5460cc5 commit cf05b25
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 31 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# GLOBALS #
#################################################################################

PROJECT_NAME = jaxman
PROJECT_NAME = kaxman
PYTHON_VERSION = 3.10
PYTHON_INTERPRETER = python

Expand Down
17 changes: 13 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
# About the project
jaxman is yet another library that implements the [Kalman filter](https://en.wikipedia.org/wiki/Kalman_filter). This library is mainly intended as a self-study project in which to learn to [JAX](https://github.com/google/jax). The implementation is inspired by both [pykalman](https://pykalman.github.io/) and [simdkalman](https://github.com/oseiskar/simdkalman), and like simdkalman implements support for vectorized inference across multiple timeseries. As it's implemented in JAX and utilizes its JIT capabilities, we achieve blazing fast inference for 1,000s (sometimes even 100,000s) of parallel timeseries.
kaxman is yet another library that implements the [Kalman filter](https://en.wikipedia.org/wiki/Kalman_filter). The
library is built on top of JAX and is designed to be fast and efficient. The library is still in its early stages and
is not yet feature complete. Some of the features include:
- JIT:able Kalman filter class.
- Support for fully/partially missing observations via inflation of variance.
- Support for time-varying state transition and observation matrices.
- Support for time-varying process and observation noise covariance matrices.
- Support for noise transform, e.g. having the same noise for multiple states.
- Rauch-Tung-Striebel smoother.


# Getting started
Follow the below instructions in order to get started with jaxman.
Follow the below instructions in order to get started with kaxman.

## Installation
The library is currently not available on pypi, and there are currently no plans on releasing it there, so install it via
The library is currently not available on pypi, so install it via
```
https://github.com/tingiskhan/jaxman
https://github.com/tingiskhan/kaxman
```

# Usage
Expand Down
2 changes: 1 addition & 1 deletion jaxman/__init__.py → kaxman/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .kalman_filter import * # noqa: F403

__version__ = "0.1.0"
__version__ = "0.2.0"
33 changes: 14 additions & 19 deletions jaxman/kalman_filter.py → kaxman/kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,21 @@
from .results import FilterResult, SmoothingResult


def _inflate_missing(non_valid_mask: jnp.ndarray, r: jnp.ndarray, missing_value: float = 1e12) -> jnp.ndarray:
def _inflate_missing(non_valid_mask: jnp.ndarray, r: jnp.ndarray, inflation: float = 1e12) -> jnp.ndarray:
"""
Masks missing dimensions by zeroing corresponding rows of H and inflating the diagonal of R.
Masks missing dimensions by inflating the diagonal of R.
Args:
non_valid_mask: Boolean mask of shape (obs_dim,) indicating missing dimensions.
r: Observation covariance matrix of shape (obs_dim, obs_dim).
missing_value: Large scalar to add on the diagonal for missing dimensions.
inflation: Large scalar to add on the diagonal for missing dimensions.
Returns:
A tuple of:
- obs_masked: Same shape as obs, with missing entries replaced by 0.0.
- H_masked: Same shape as H, rows zeroed out for missing dimensions.
- r_masked: Same shape as R, diagonal entries inflated for missing dimensions.
"""

diag_inflation = jnp.where(non_valid_mask, missing_value, 0.0)
diag_inflation = jnp.where(non_valid_mask, inflation, 0.0)
r_masked = r + jnp.diag(diag_inflation)

return r_masked
Expand Down Expand Up @@ -171,20 +169,14 @@ def _get_observation_offset(self, t: int) -> jnp.ndarray:

return self.observation_offset

def _get_noise_transform(self, t: int) -> jnp.ndarray:
if callable(self.noise_transform):
return self.noise_transform(t)

return self.noise_transform

def _predict(self, mean: jnp.ndarray, cov: jnp.ndarray, t: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
F_t = self._get_transition_matrix(t, mean)
Q_t = self._get_transition_cov(t)
f_t = self._get_transition_matrix(t, mean)
q_t = self._get_transition_cov(t)
b_t = self._get_transition_offset(t)
G_t = self._get_noise_transform(t)
g_t = self.noise_transform

mean_pred = F_t @ mean + b_t
cov_pred = F_t @ cov @ F_t.T + G_t @ Q_t @ G_t.T
mean_pred = f_t @ mean + b_t
cov_pred = f_t @ cov @ f_t.T + g_t @ q_t @ g_t.T

return mean_pred, cov_pred

Expand Down Expand Up @@ -219,7 +211,7 @@ def scan_fn(carry, obs_t):
nan_mask = jnp.isnan(obs_t)

# TODO: need to verify this...
r_t = _inflate_missing(nan_mask, r_t, missing_value=self.variance_inflation)
r_t = _inflate_missing(nan_mask, r_t, inflation=self.variance_inflation)

y_pred_mean = h_t @ x_pred_mean + d_t
y_pred_cov = h_t @ x_pred_cov @ h_t.T + r_t
Expand Down Expand Up @@ -312,7 +304,7 @@ def sample_step(carry, _):
f_t = self._get_transition_matrix(t, x_prev)
q_t = self._get_transition_cov(t)
b_t = self._get_transition_offset(t)
g_t = self._get_noise_transform(t)
g_t = self.noise_transform
h_t = self._get_observation_matrix(t)
r_t = self._get_observation_cov(t)
d_t = self._get_observation_offset(t)
Expand All @@ -338,3 +330,6 @@ def init_state(key):
_, (xs, ys) = lax.scan(sample_step, init_carry, None, length=num_timesteps)

return xs, ys


__all__ = ["KalmanFilter"]
File renamed without changes.
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ test = [


[tool.setuptools.packages.find]
include = ["jaxman*"]
include = ["kaxman*"]

[tool.black]
line-length = 120
Expand All @@ -58,7 +58,7 @@ exclude = '''
line-length = 120

[tool.bumpver]
current_version = "0.1.0"
current_version = "0.2.0"
version_pattern = "MAJOR.MINOR.PATCH"
commit_message = "bump version {old_version} -> {new_version}"
commit = true
Expand All @@ -72,12 +72,12 @@ tag_scope = "default"
'current_version = "{version}"',
]

"jaxman/__init__.py" = [
"kaxman/__init__.py" = [
'__version__ = "{version}"'
]

[tool.setuptools.dynamic]
version = {attr = "jaxman.__version__"}
version = {attr = "kaxman.__version__"}

[tool.pytest.ini_options]
pythonpath = ["."]
Binary file removed static/filtering.jpg
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/test_kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from numpy.testing import assert_allclose
from pykalman import KalmanFilter as PyKalman
from jaxman import KalmanFilter
from kaxman import KalmanFilter


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion tests/test_numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from jaxman import KalmanFilter
from kaxman import KalmanFilter


@pytest.fixture
Expand Down

0 comments on commit cf05b25

Please sign in to comment.