Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of rao-blackwellised particle filter (rbpf) and rbpf with optimal resampling. #323

Merged
merged 13 commits into from
Nov 1, 2023
84 changes: 19 additions & 65 deletions docs/notebooks/linear_gaussian_ssm/lgssm_learning.ipynb

Large diffs are not rendered by default.

299 changes: 299 additions & 0 deletions docs/notebooks/slds/rbpf_maneuver.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,299 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tracking a maneuvering target using the RBPF"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'dynamax.slds'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/docs/notebooks/slds/rbpf_maneuver.ipynb Cell 2\u001b[0m line \u001b[0;36m8\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/docs/notebooks/slds/rbpf_maneuver.ipynb#W1sZmlsZQ%3D%3D?line=5'>6</a>\u001b[0m sys\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mappend(\u001b[39m'\u001b[39m\u001b[39m/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/dynamax\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/docs/notebooks/slds/rbpf_maneuver.ipynb#W1sZmlsZQ%3D%3D?line=6'>7</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mjax\u001b[39;00m \u001b[39mimport\u001b[39;00m vmap, tree_map, jit\n\u001b[0;32m----> <a href='vscode-notebook-cell:/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/docs/notebooks/slds/rbpf_maneuver.ipynb#W1sZmlsZQ%3D%3D?line=7'>8</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mdynamax\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mslds\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39minference\u001b[39;00m \u001b[39mimport\u001b[39;00m ParamsSLDS, LGParamsSLDS, DiscreteParamsSLDS, rbpfilter, rbpfilter_optimal\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/docs/notebooks/slds/rbpf_maneuver.ipynb#W1sZmlsZQ%3D%3D?line=8'>9</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mdynamax\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mslds\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmodels\u001b[39;00m \u001b[39mimport\u001b[39;00m SLDS\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax/docs/notebooks/slds/rbpf_maneuver.ipynb#W1sZmlsZQ%3D%3D?line=9'>10</a>\u001b[0m \u001b[39m# import MVN from tfd\u001b[39;00m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'dynamax.slds'"
]
}
],
"source": [
"import dynamax\n",
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"from functools import partial\n",
"import sys \n",
"sys.path.append('/Users/kostastsampourakis/Desktop/code/Python/projects/dynamax')\n",
"from jax import vmap, tree_map, jit\n",
"from dynamax.slds.inference import ParamsSLDS, LGParamsSLDS, DiscreteParamsSLDS, rbpfilter, rbpfilter_optimal\n",
"from dynamax.slds.models import SLDS\n",
"# import MVN from tfd\n",
"from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN\n",
"\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"from functools import partial\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"from jax.scipy.special import logit\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"import jax"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Simulate Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"num_states = 3\n",
"num_particles = 1000\n",
"state_dim = 4\n",
"emission_dim = 4\n",
"\n",
"TT = 0.1\n",
"A = jnp.array([[1, TT, 0, 0],\n",
" [0, 1, 0, 0],\n",
" [0, 0, 1, TT],\n",
" [0, 0, 0, 1]])\n",
"\n",
"\n",
"B1 = jnp.array([0, 0, 0, 0])\n",
"B2 = jnp.array([-1.225, -0.35, 1.225, 0.35])\n",
"B3 = jnp.array([1.225, 0.35, -1.225, -0.35])\n",
"B = jnp.stack([B1, B2, B3], axis=0)\n",
"\n",
"Q = 0.2 * jnp.eye(4)\n",
"R = 10.0 * jnp.diag(jnp.array([2, 1, 2, 1]))\n",
"C = jnp.eye(4)\n",
"\n",
"transition_matrix = jnp.array([\n",
" [0.8, 0.1, 0.1],\n",
" [0.1, 0.8, 0.1],\n",
" [0.1, 0.1, 0.8]\n",
"])\n",
"\n",
"discr_params = DiscreteParamsSLDS(\n",
" initial_distribution=jnp.ones(num_states)/num_states,\n",
" transition_matrix=transition_matrix,\n",
" proposal_transition_matrix=transition_matrix\n",
")\n",
"\n",
"lg_params = LGParamsSLDS(\n",
" initial_mean=jnp.ones(state_dim),\n",
" initial_cov=jnp.eye(state_dim),\n",
" dynamics_weights=A,\n",
" dynamics_cov=Q,\n",
" dynamics_bias=jnp.array([B1, B2, B3]),\n",
" dynamics_input_weights=None,\n",
" emission_weights=C,\n",
" emission_cov=R,\n",
" emission_bias=None,\n",
" emission_input_weights=None\n",
")\n",
"\n",
"pre_params = ParamsSLDS(\n",
" discrete=discr_params,\n",
" linear_gaussian=lg_params\n",
")\n",
"\n",
"params = pre_params.initialize(num_states, state_dim, emission_dim)\n",
"\n",
"## Sample states and emissions\n",
"key, next_key = jr.split(jr.PRNGKey(1))\n",
"slds = SLDS(num_states, state_dim, emission_dim)\n",
"dstates, cstates, emissions = slds.sample(params, key, 100)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Filtering"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"out = rbpfilter_optimal(num_particles, params, emissions, next_key)\n",
"filtered_means = out['means']\n",
"weights = out['weights']\n",
"sampled_dstates = out['states']\n",
"post_mean = jnp.einsum(\"ts,tsm->tm\", weights, filtered_means)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"f = lambda t: jnp.array([jnp.sum(weights[t, jnp.where(sampled_dstates[t]==st)]) for st in range(num_states)])\n",
"p_est = jnp.array(list(map(f, jnp.arange(100))))\n",
"est_dstates = jnp.argmax(p_est, axis=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def plot_3d_belief_state(mu_hist, dim, ax, skip=3, npoints=2000, azimuth=-30, elevation=30, h=0.5):\n",
" nsteps = len(mu_hist)\n",
" xmin, xmax = mu_hist[..., dim].min(), mu_hist[..., dim].max()\n",
" xrange = np.linspace(xmin, xmax, npoints).reshape(-1, 1)\n",
" res = np.apply_along_axis(lambda X: kdeg(xrange, X[..., None], h), 1, mu_hist)\n",
" densities = res[..., dim]\n",
" for t in range(0, nsteps, skip):\n",
" tloc = t * np.ones(npoints)\n",
" px = densities[t]\n",
" ax.plot(tloc, xrange, px, c=\"tab:blue\", linewidth=1)\n",
" ax.set_zlim(0, 1)\n",
" style3d(ax, 1.8, 1.2, 0.7, 0.8)\n",
" ax.view_init(elevation, azimuth)\n",
" ax.set_xlabel(r\"$t$\", fontsize=13)\n",
" ax.set_ylabel(r\"$x_{\"f\"d={dim}\"\",t}$\", fontsize=13)\n",
" ax.set_zlabel(r\"$p(x_{d, t} \\vert y_{1:t})$\", fontsize=13)\n",
"\n",
"def scale_3d(ax, x_scale, y_scale, z_scale, factor): \n",
" scale=np.diag(np.array([x_scale, y_scale, z_scale, 1.0]))\n",
" scale=scale*(1.0/scale.max()) \n",
" scale[3,3]=factor\n",
" def short_proj(): \n",
" return np.dot(Axes3D.get_proj(ax), scale) \n",
" return short_proj \n",
"\n",
"def style3d(ax, x_scale, y_scale, z_scale, factor=0.62):\n",
" plt.gca().patch.set_facecolor('white')\n",
" ax.w_xaxis.set_pane_color((0, 0, 0, 0))\n",
" ax.w_yaxis.set_pane_color((0, 0, 0, 0))\n",
" ax.w_zaxis.set_pane_color((0, 0, 0, 0))\n",
" ax.get_proj = scale_3d(ax, x_scale, y_scale, z_scale, factor)\n",
"\n",
"def kdeg(x, X, h):\n",
" \"\"\"\n",
" KDE under a gaussian kernel\n",
"\n",
" Parameters\n",
" ----------\n",
" x: array(eval, D)\n",
" X: array(obs, D)\n",
" h: float\n",
"\n",
" Returns\n",
" -------\n",
" array(eval):\n",
" KDE around the observed values\n",
" \"\"\"\n",
" N, D = X.shape\n",
" nden, _ = x.shape\n",
"\n",
" Xhat = X.reshape(D, 1, N)\n",
" xhat = x.reshape(D, nden, 1)\n",
" u = xhat - Xhat\n",
" u = np.linalg.norm(u, ord=2, axis=0) ** 2 / (2 * h ** 2)\n",
" px = np.exp(-u).sum(axis=1) / (N * h * np.sqrt(2 * np.pi))\n",
" return px\n",
"\n",
"\n",
" # Plot target dataset\n",
"dict_figures = {}\n",
"color_dict = {0: \"tab:green\", 1: \"tab:red\", 2: \"tab:blue\"}\n",
"fig, ax = plt.subplots()\n",
"color_states_org = [color_dict[int(state)] for state in dstates]\n",
"ax.scatter(*cstates[:, [0, 2]].T, c=\"none\", edgecolors=color_states_org, s=10)\n",
"ax.scatter(*emissions[:, [0, 2]].T, s=5, c=\"black\", alpha=0.6)\n",
"ax.set_title(\"Data\")\n",
"dict_figures[\"rbpf-maneuver-data\"] = fig\n",
"\n",
"# Plot filtered dataset\n",
"fig, ax = plt.subplots()\n",
"rbpf_mse = ((post_mean - cstates)[:, [0, 2]] ** 2).mean(axis=0).sum()\n",
"color_states_est = [color_dict[int(state)] for state in np.array(est_dstates)]\n",
"ax.scatter(*post_mean[:, [0, 2]].T, c=\"none\", edgecolors=color_states_est, s=10)\n",
"ax.set_title(f\"RBPF MSE: {rbpf_mse:.2f}\")\n",
"dict_figures[\"rbpf-maneuver-trace\"] = fig\n",
"\n",
"# Plot belief state of discrete system\n",
"rbpf_error_rate = (dstates != est_dstates).mean()\n",
"fig, ax = plt.subplots(figsize=(2.5, 5))\n",
"sns.heatmap(p_est, cmap=\"viridis\", cbar=False)\n",
"plt.title(f\"RBPF, error rate: {rbpf_error_rate:0.3}\")\n",
"dict_figures[\"rbpf-maneuver-discrete-belief\"] = fig\n",
"\n",
"# Plot ground truth and MAP estimate\n",
"ohe = OneHotEncoder(sparse=False)\n",
"latent_hmap = ohe.fit_transform(dstates[:, None])\n",
"latent_hmap_est = ohe.fit_transform(p_est.argmax(axis=1)[:, None])\n",
"\n",
"fig, ax = plt.subplots(figsize=(2.5, 5))\n",
"sns.heatmap(latent_hmap, cmap=\"viridis\", cbar=False, ax=ax)\n",
"ax.set_title(\"Data\")\n",
"dict_figures[\"rbpf-maneuver-discrete-ground-truth.pdf\"] = fig\n",
"\n",
"fig, ax = plt.subplots(figsize=(2.5, 5))\n",
"sns.heatmap(latent_hmap_est, cmap=\"viridis\", cbar=False, ax=ax)\n",
"ax.set_title(f\"MAP (error rate: {rbpf_error_rate:0.4f})\")\n",
"dict_figures[\"rbpf-maneuver-discrete-map\"] = fig\n",
"\n",
"# Plot belief for state space\n",
"dims = [0, 2]\n",
"for dim in dims:\n",
" fig = plt.figure()\n",
" ax = plt.axes(projection=\"3d\")\n",
" plot_3d_belief_state(filtered_means, dim, ax, h=1.1)\n",
" # pml.savefig(f\"rbpf-maneuver-belief-states-dim{dim}.pdf\", pad_inches=0, bbox_inches=\"tight\")\n",
" dict_figures[f\"rbpf-maneuver-belief-states-dim{dim}.pdf\"] = fig\n",
"\n",
"\n",
"plt.rcParams[\"axes.spines.right\"] = False\n",
"plt.rcParams[\"axes.spines.top\"] = False\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
1 change: 0 additions & 1 deletion dynamax/linear_gaussian_ssm/inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ class TestFilteringAndSmoothing():
# Posteriors from full joint distribution
joint_means, joint_covs = joint_posterior_mvn(params, emissions)


## For sampling tests
lgssm, params = build_lgssm_for_sampling()

Expand Down
1 change: 0 additions & 1 deletion dynamax/linear_gaussian_ssm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,6 @@ def e_step(

return (init_stats, dynamics_stats, emission_stats), posterior.marginal_loglik


def initialize_m_step_state(
self,
params: ParamsLGSSM,
Expand Down
2 changes: 2 additions & 0 deletions dynamax/slds/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from dynamax.slds.inference import DiscreteParamsSLDS, LGParamsSLDS, ParamsSLDS, rbpfilter, rbpfilter_optimal
from dynamax.slds.models import SLDS
Loading
Loading