Skip to content

Commit

Permalink
Merge pull request #324 from calebweinreb/lgssm_parallel_inference_wi…
Browse files Browse the repository at this point in the history
…th_bias

Refactor of LGSSM parallel inference
  • Loading branch information
slinderman authored May 24, 2023
2 parents d7f283e + effaba3 commit fa4723f
Show file tree
Hide file tree
Showing 3 changed files with 312 additions and 140 deletions.
1 change: 1 addition & 0 deletions dynamax/linear_gaussian_ssm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@

from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_filter as parallel_lgssm_filter
from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_smoother as parallel_lgssm_smoother
from dynamax.linear_gaussian_ssm.parallel_inference import lgssm_posterior_sample as parallel_lgssm_posterior_sample

from dynamax.linear_gaussian_ssm.models import LinearGaussianConjugateSSM, LinearGaussianSSM
299 changes: 212 additions & 87 deletions dynamax/linear_gaussian_ssm/parallel_inference.py
Original file line number Diff line number Diff line change
@@ -1,77 +1,140 @@
# Parallel filtering and smoothing for a lgssm.
# This implementation is adapted from the work of Adrien Corenflos in
# https://github.com/EEA-sensors/sequential-parallelization-examples/
'''
Parallel filtering and smoothing for a lgssm.
This implementation is adapted from the work of Adrien Corenflos:
https://github.com/EEA-sensors/sequential-parallelization-examples/
Note that in the original implementation, the initial state distribution
applies to t=0, and the first emission occurs at time `t=1` (i.e. after
the initial state has been transformed by the dynamics), whereas here,
the first emission occurs at time `t=0` and is produced directly by the
untransformed initial state (see below).
Särkkä et al.

      F₀,Q₀          F₁,Q₁          F₂,Q₂
Z₀ ─────────── Z₁ ─────────── Z₂ ─────────── Z₃ ─────...
               |              |              |
               | H₁,R₁        | H₂,R₂        | H₃,R₃
               |              |              |
               Y₁             Y₂             Y₃

Dynamax

      F₀,Q₀          F₁,Q₁          F₂,Q₂
Z₀ ─────────── Z₁ ─────────── Z₂ ─────────── Z₃ ─────...
|              |              |              |
| H₀,R₀        | H₁,R₁        | H₂,R₂        | H₃,R₃
|              |              |              |
Y₀             Y₁             Y₂             Y₃
'''

import jax.numpy as jnp
import jax.scipy as jsc
from jax import vmap, lax
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
from jaxtyping import Array, Float
from typing import NamedTuple
from dynamax.types import PRNGKey
from functools import partial

from dynamax.utils.utils import psd_solve
from jax.scipy.linalg import cho_solve, cho_factor
from dynamax.utils.utils import symmetrize
from dynamax.linear_gaussian_ssm import PosteriorGSSMFiltered, PosteriorGSSMSmoothed, ParamsLGSSM


def _get_params(x, dim, t):
    """Return the value of a possibly time-varying parameter at time ``t``.

    Args:
        x: the parameter. Either a callable taking the time index, an array
            with one extra leading axis (``x.ndim == dim + 1``) that is
            indexed by ``t``, or a value that is used unchanged.
        dim: number of dimensions the parameter has at a single time step.
        t: time index.

    Returns:
        ``x(t)`` if ``x`` is callable, ``x[t]`` if ``x`` has the extra leading
        axis, and ``x`` itself otherwise.
    """
    if callable(x):
        return x(t)
    if x.ndim == dim + 1:
        return x[t]
    return x

#---------------------------------------------------------------------------#
# Filtering #
#---------------------------------------------------------------------------#

def _make_associative_filtering_elements(params, emissions):
class FilterMessage(NamedTuple):
    """
    Filtering associative scan elements.

    Attributes:
        A: P(z_j | y_{i:j}, z_{i-1}) weights.
        b: P(z_j | y_{i:j}, z_{i-1}) bias.
        C: P(z_j | y_{i:j}, z_{i-1}) covariance.
        J: P(z_{i-1} | y_{i:j}) covariance.
        eta: P(z_{i-1} | y_{i:j}) mean.
        logZ: one scalar per time step carried alongside the Gaussian terms
            (presumably a negative log-normalizer; confirm against the code
            that builds and combines these messages).
    """
    A: Float[Array, "ntime state_dim state_dim"]
    b: Float[Array, "ntime state_dim"]
    C: Float[Array, "ntime state_dim state_dim"]
    J: Float[Array, "ntime state_dim state_dim"]
    eta: Float[Array, "ntime state_dim"]
    logZ: Float[Array, "ntime"]


def _initialize_filtering_messages(params, emissions):
    """Preprocess observations to construct input for filtering associative scan.

    Args:
        params: model parameters. This function reads ``params.initial.mean``,
            ``params.initial.cov``, ``params.dynamics.{weights,bias,cov}`` and
            ``params.emissions.{weights,bias,cov}``; the dynamics and emission
            entries are looked up per time step through ``_get_params``.
        emissions: observations, indexed by time along the leading axis.

    Returns:
        A ``FilterMessage`` whose fields are stacked along a leading time
        axis: the message built from the initial distribution and
        ``emissions[0]`` comes first, followed by one message per remaining
        emission.
    """

    def _first_message(params, y):
        # Message for t=0: condition the initial distribution directly on
        # the first emission (no dynamics step is applied).
        H = _get_params(params.emissions.weights, 2, 0)
        R = _get_params(params.emissions.cov, 2, 0)
        d = _get_params(params.emissions.bias, 1, 0)
        m = params.initial.mean
        P = params.initial.cov

        S = H @ P @ H.T + R
        CF, low = cho_factor(S)
        K = cho_solve((CF, low), H @ P).T

        # The first message does not depend on a previous state, hence A = 0.
        A = jnp.zeros_like(P)
        b = m + K @ (y - H @ m - d)
        C = symmetrize(P - K @ S @ K.T)
        eta = jnp.zeros_like(b)
        J = jnp.eye(len(b))

        logZ = -MVN(loc=jnp.zeros_like(y), covariance_matrix=S).log_prob(y)
        return A, b, C, J, eta, logZ

    # params is shared across time steps; y and t are mapped over their
    # leading axis.
    @partial(vmap, in_axes=(None, 0, 0))
    def _generic_message(params, y, t):
        # Dynamics parameters are taken at index t, emission parameters at
        # index t+1 (the time step of the emission y).
        F = _get_params(params.dynamics.weights, 2, t)
        Q = _get_params(params.dynamics.cov, 2, t)
        b_dyn = _get_params(params.dynamics.bias, 1, t)
        H = _get_params(params.emissions.weights, 2, t+1)
        R = _get_params(params.emissions.cov, 2, t+1)
        d = _get_params(params.emissions.bias, 1, t+1)

        S = H @ Q @ H.T + R
        CF, low = cho_factor(S)
        K = cho_solve((CF, low), H @ Q).T

        # Emission residual after removing the dynamics and emission biases.
        resid = y - H @ b_dyn - d

        eta = F.T @ H.T @ cho_solve((CF, low), resid)
        J = symmetrize(F.T @ H.T @ cho_solve((CF, low), H @ F))

        A = F - K @ H @ F
        b = b_dyn + K @ resid
        C = symmetrize(Q - K @ H @ Q)

        logZ = -MVN(loc=jnp.zeros_like(y), covariance_matrix=S).log_prob(y)
        return A, b, C, J, eta, logZ

    A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0])
    At, bt, Ct, Jt, etat, logZt = _generic_message(
        params, emissions[1:], jnp.arange(len(emissions)-1)
    )

    # Prepend the first message to the stacked generic messages.
    return FilterMessage(
        A=jnp.concatenate([A0[None], At]),
        b=jnp.concatenate([b0[None], bt]),
        C=jnp.concatenate([C0[None], Ct]),
        J=jnp.concatenate([J0[None], Jt]),
        eta=jnp.concatenate([eta0[None], etat]),
        logZ=jnp.concatenate([logZ0[None], logZt])
    )



def lgssm_filter(
params: ParamsLGSSM,
Expand All @@ -83,71 +146,81 @@ def lgssm_filter(
Note: This function does not yet handle `inputs` to the system.
"""
#TODO: Add input handling.
initial_elements = _make_associative_filtering_elements(params, emissions)

@vmap
def filtering_operator(elem1, elem2):
def _operator(elem1, elem2):
A1, b1, C1, J1, eta1, logZ1 = elem1
A2, b2, C2, J2, eta2, logZ2 = elem2
dim = A1.shape[0]
I = jnp.eye(dim)
I = jnp.eye(A1.shape[0])

I_C1J2 = I + C1 @ J2
temp = jsc.linalg.solve(I_C1J2.T, A2.T).T
temp = jnp.linalg.solve(I_C1J2.T, A2.T).T
A = temp @ A1
b = temp @ (b1 + C1 @ eta2) + b2
C = temp @ C1 @ A2.T + C2
C = symmetrize(temp @ C1 @ A2.T + C2)

I_J2C1 = I + J2 @ C1
temp = jsc.linalg.solve(I_J2C1.T, A1).T

temp = jnp.linalg.solve(I_J2C1.T, A1).T
eta = temp @ (eta2 - J2 @ b1) + eta1
J = temp @ J2 @ A1 + J1

# mu = jsc.linalg.solve(J2, eta2)
# t2 = - eta2 @ mu + (b1 - mu) @ jsc.linalg.solve(I_J2C1, (J2 @ b1 - eta2))
J = symmetrize(temp @ J2 @ A1 + J1)

mu = jnp.linalg.solve(C1, b1)
t1 = (b1 @ mu - (eta2 + mu) @ jnp.linalg.solve(I_C1J2, C1 @ eta2 + b1))

logZ = (logZ1 + logZ2 + 0.5 * jnp.linalg.slogdet(I_C1J2)[1] + 0.5 * t1)
return FilterMessage(A, b, C, J, eta, logZ)

initial_messages = _initialize_filtering_messages(params, emissions)
final_messages = lax.associative_scan(_operator, initial_messages)

return PosteriorGSSMFiltered(
filtered_means=final_messages.b,
filtered_covariances=final_messages.C,
marginal_loglik=-final_messages.logZ[-1])

return A, b, C, J, eta, logZ

_, filtered_means, filtered_covs, _, _, logZ = lax.associative_scan(
filtering_operator, initial_elements
)
#---------------------------------------------------------------------------#
# Smoothing #
#---------------------------------------------------------------------------#

return PosteriorGSSMFiltered(marginal_loglik=-logZ[-1],
filtered_means=filtered_means, filtered_covariances=filtered_covs)
class SmoothMessage(NamedTuple):
    """
    Smoothing associative scan elements.

    Attributes:
        E: P(z_i | y_{1:j}, z_{j+1}) weights.
        g: P(z_i | y_{1:j}, z_{j+1}) bias.
        L: P(z_i | y_{1:j}, z_{j+1}) covariance.
    """
    E: Float[Array, "ntime state_dim state_dim"]
    g: Float[Array, "ntime state_dim"]
    L: Float[Array, "ntime state_dim state_dim"]


def _make_associative_smoothing_elements(params, filtered_means, filtered_covariances):
def _initialize_smoothing_messages(params, filtered_means, filtered_covariances):
    """Preprocess filtering output to construct input for smoothing associative scan.

    Args:
        params: model parameters. This function reads
            ``params.dynamics.{weights,bias,cov}``, looked up per time step
            through ``_get_params``.
        filtered_means: filtered means, indexed by time along the leading axis.
        filtered_covariances: filtered covariances, indexed by time along the
            leading axis.

    Returns:
        A ``SmoothMessage`` whose fields are stacked along a leading time
        axis: one message per time step except the last, followed by the
        message for the final time step.
    """

    def _last_message(m, P):
        # The final time step has no successor to condition on: zero weights,
        # with the filtered mean and covariance passed through unchanged.
        return jnp.zeros_like(P), m, P

    # params is shared across time steps; m, P and t are mapped over their
    # leading axis.
    @partial(vmap, in_axes=(None, 0, 0, 0))
    def _generic_message(params, m, P, t):
        F = _get_params(params.dynamics.weights, 2, t)
        Q = _get_params(params.dynamics.cov, 2, t)
        b = _get_params(params.dynamics.bias, 1, t)

        # Factor the one-step-ahead predicted covariance F P F^T + Q.
        CF, low = cho_factor(F @ P @ F.T + Q)
        E = cho_solve((CF, low), F @ P).T
        g = m - E @ (F @ m + b)
        L = symmetrize(P - E @ F @ P)
        return E, g, L

    En, gn, Ln = _last_message(filtered_means[-1], filtered_covariances[-1])
    Et, gt, Lt = _generic_message(
        params,
        filtered_means[:-1],
        filtered_covariances[:-1],
        jnp.arange(len(filtered_means)-1),
    )

    # Append the final-step message after the stacked generic messages.
    return SmoothMessage(
        E=jnp.concatenate([Et, En[None]]),
        g=jnp.concatenate([gt, gn[None]]),
        L=jnp.concatenate([Lt, Ln[None]])
    )


def lgssm_smoother(
Expand All @@ -163,26 +236,78 @@ def lgssm_smoother(
filtered_posterior = lgssm_filter(params, emissions)
filtered_means = filtered_posterior.filtered_means
filtered_covs = filtered_posterior.filtered_covariances
initial_elements = _make_associative_smoothing_elements(params, filtered_means, filtered_covs)


@vmap
def smoothing_operator(elem1, elem2):
def _operator(elem1, elem2):
E1, g1, L1 = elem1
E2, g2, L2 = elem2

E = E2 @ E1
g = E2 @ g1 + g2
L = E2 @ L1 @ E2.T + L2

L = symmetrize(E2 @ L1 @ E2.T + L2)
return E, g, L

_, smoothed_means, smoothed_covs, *_ = lax.associative_scan(
smoothing_operator, initial_elements, reverse=True
)
initial_messages = _initialize_smoothing_messages(params, filtered_means, filtered_covs)
final_messages = lax.associative_scan(_operator, initial_messages, reverse=True)

return PosteriorGSSMSmoothed(
marginal_loglik=filtered_posterior.marginal_loglik,
filtered_means=filtered_means,
filtered_covariances=filtered_covs,
smoothed_means=smoothed_means,
smoothed_covariances=smoothed_covs
smoothed_means=final_messages.g,
smoothed_covariances=final_messages.L
)


#---------------------------------------------------------------------------#
# Sampling #
#---------------------------------------------------------------------------#

class SampleMessage(NamedTuple):
    """
    Sampling associative scan elements.

    Attributes:
        E: z_i ~ z_{j+1} weights.
        h: z_i ~ z_{j+1} bias.
    """
    E: Float[Array, "ntime state_dim state_dim"]
    h: Float[Array, "ntime state_dim"]


def _initialize_sampling_messages(key, params, filtered_means, filtered_covariances):
    """Construct input for the sampling associative scan.

    Given parallel smoothing messages `z_i ~ N(E_i z_{i+1} + g_i, L_i)`,
    the parallel sampling messages are `(E_i,h_i)` where `h_i ~ N(g_i, L_i)`.

    Args:
        key: random key passed as ``seed`` when drawing the biases ``h``.
        params: model parameters, forwarded to
            ``_initialize_smoothing_messages``.
        filtered_means: filtered means, indexed by time along the leading axis.
        filtered_covariances: filtered covariances, indexed by time along the
            leading axis.

    Returns:
        A ``SampleMessage`` holding the smoothing weights ``E`` and the
        sampled biases ``h``.
    """
    E, g, L = _initialize_smoothing_messages(params, filtered_means, filtered_covariances)
    # A single call draws the bias for every time step.
    h = MVN(g, L).sample(seed=key)
    return SampleMessage(E=E, h=h)


def lgssm_posterior_sample(
    key: PRNGKey,
    params: ParamsLGSSM,
    emissions: Float[Array, "ntime emission_dim"]
) -> Float[Array, "ntime state_dim"]:
    """A parallel version of the lgssm sampling algorithm.

    See S. Särkkä and Á. F. García-Fernández (2021) - https://arxiv.org/abs/1905.13002.

    Note: This function does not yet handle `inputs` to the system.

    Args:
        key: random key used to draw the sample.
        params: model parameters.
        emissions: observations, indexed by time along the leading axis.

    Returns:
        The sampled latent states, one per time step.
    """
    filtered_posterior = lgssm_filter(params, emissions)
    filtered_means = filtered_posterior.filtered_means
    filtered_covs = filtered_posterior.filtered_covariances

    @vmap
    def _operator(elem1, elem2):
        # Compose two affine maps z -> E z + h, applying elem1 first and
        # then elem2.
        E1, h1 = elem1
        E2, h2 = elem2

        E = E2 @ E1
        h = E2 @ h1 + h2
        return E, h

    initial_messages = _initialize_sampling_messages(key, params, filtered_means, filtered_covs)
    # Scan backwards in time; the accumulated biases are the samples.
    _, samples = lax.associative_scan(_operator, initial_messages, reverse=True)
    return samples
Loading

0 comments on commit fa4723f

Please sign in to comment.