From ff0870d40155f7224516e84b3be41f15e27c689b Mon Sep 17 00:00:00 2001 From: Caleb Weinreb Date: Sun, 21 May 2023 11:35:06 -0400 Subject: [PATCH] added optional jitter term (diagonal boost of cov) when sampling LGSSM posterior --- dynamax/linear_gaussian_ssm/inference.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dynamax/linear_gaussian_ssm/inference.py b/dynamax/linear_gaussian_ssm/inference.py index e866e491..2db9a260 100644 --- a/dynamax/linear_gaussian_ssm/inference.py +++ b/dynamax/linear_gaussian_ssm/inference.py @@ -493,7 +493,9 @@ def lgssm_posterior_sample( key: PRNGKey, params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]]=None, + jitter: Optional[Scalar]=0 + ) -> Float[Array, "ntime state_dim"]: r"""Run forward-filtering, backward-sampling to draw samples from $p(z_{1:T} \mid y_{1:T}, u_{1:T})$. @@ -502,6 +504,7 @@ def lgssm_posterior_sample( params: parameters. emissions: sequence of observations. inputs: optional sequence of inptus. + jitter: padding to add to the diagonal of the covariance matrix before sampling. Returns: Float[Array, "ntime state_dim"]: one sample of $z_{1:T}$ from the posterior distribution on latent states. @@ -527,6 +530,7 @@ def _step(carry, args): # Condition on next state smoothed_mean, smoothed_cov = _condition_on(filtered_mean, filtered_cov, F, B, b, Q, u, next_state) + smoothed_cov = smoothed_cov + jnp.eye(smoothed_cov.shape[-1]) * jitter state = MVN(smoothed_mean, smoothed_cov).sample(seed=key) return state, state