diff --git a/ssm_jax/hmm/models/base.py b/ssm_jax/hmm/models/base.py index b7e66771..db0a8221 100644 --- a/ssm_jax/hmm/models/base.py +++ b/ssm_jax/hmm/models/base.py @@ -154,7 +154,7 @@ def _single_expected_log_joint(emissions, posterior, **covariates): num_epochs=num_sgd_epochs_per_mstep) self.unconstrained_params = params - def fit_em(self, batch_emissions, num_iters=50, mstep_kwargs=dict(), **batch_covariates): + def fit_em(self, batch_emissions, num_iters=50, mstep_kwargs=dict(), verbose=True, **batch_covariates): """Fit this HMM with Expectation-Maximization (EM). Args: batch_emissions (_type_): _description_ @@ -173,7 +173,8 @@ def em_step(params): log_probs = [] params = self.unconstrained_params - for _ in trange(num_iters): + pbar = trange(num_iters) if verbose else range(num_iters) + for _ in pbar: params, lp = em_step(params) log_probs.append(lp)