Skip to content

Commit

Permalink
Merge pull request #321 from calebweinreb/robust_lgssm_posterior_sample
Browse files Browse the repository at this point in the history
Robust lgssm posterior sample
  • Loading branch information
slinderman authored May 21, 2023
2 parents f40a2fd + ff0870d commit d7f283e
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion dynamax/linear_gaussian_ssm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,9 @@ def lgssm_posterior_sample(
key: PRNGKey,
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None
inputs: Optional[Float[Array, "ntime input_dim"]]=None,
jitter: Optional[Scalar]=0

) -> Float[Array, "ntime state_dim"]:
r"""Run forward-filtering, backward-sampling to draw samples from $p(z_{1:T} \mid y_{1:T}, u_{1:T})$.
Expand All @@ -502,6 +504,7 @@ def lgssm_posterior_sample(
params: parameters.
emissions: sequence of observations.
inputs: optional sequence of inptus.
jitter: padding to add to the diagonal of the covariance matrix before sampling.
Returns:
Float[Array, "ntime state_dim"]: one sample of $z_{1:T}$ from the posterior distribution on latent states.
Expand All @@ -527,6 +530,7 @@ def _step(carry, args):

# Condition on next state
smoothed_mean, smoothed_cov = _condition_on(filtered_mean, filtered_cov, F, B, b, Q, u, next_state)
smoothed_cov = smoothed_cov + jnp.eye(smoothed_cov.shape[-1]) * jitter
state = MVN(smoothed_mean, smoothed_cov).sample(seed=key)
return state, state

Expand Down

0 comments on commit d7f283e

Please sign in to comment.