diff --git a/pyknos/mdn/mdn.py b/pyknos/mdn/mdn.py index b70112c..ee7de44 100644 --- a/pyknos/mdn/mdn.py +++ b/pyknos/mdn/mdn.py @@ -81,7 +81,7 @@ def __init__( raise NotImplementedError # Constant for numerical stability. - self._epsilon = 1e-2 + self._epsilon = 1e-4 # Initialize mixture coefficients and precision factors sensibly. if custom_initialization: @@ -116,7 +116,7 @@ def get_mixture_components( # Elements of diagonal of precision factor must be positive # (recall precision factor A such that SIGMA^-1 = A^T A). - diagonal = F.softplus(unconstrained_diagonal) + self._epsilon + diagonal = F.softplus(unconstrained_diagonal) # Create empty precision factor matrix, and fill with appropriate quantities. precision_factors = torch.zeros( @@ -139,6 +139,10 @@ def get_mixture_components( precisions = torch.matmul( torch.transpose(precision_factors, 2, 3), precision_factors ) + # Add epsilon to diagnonal for numerical stability. + precisions[ + ..., torch.arange(self._features), torch.arange(self._features) + ] += self._epsilon # The sum of the log diagonal of A is used in the likelihood calculation. sumlogdiag = torch.sum(torch.log(diagonal), dim=-1)