Skip to content

Commit

Permalink
Merge pull request #22 from mackelab/numerical-stability
Browse files Browse the repository at this point in the history
Add epsilon after precision-factor-multiplication
  • Loading branch information
michaeldeistler authored Aug 24, 2022
2 parents 5ea3d5b + 76fe8ed commit db7d8e4
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions pyknos/mdn/mdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(
raise NotImplementedError

# Constant for numerical stability.
self._epsilon = 1e-2
self._epsilon = 1e-4

# Initialize mixture coefficients and precision factors sensibly.
if custom_initialization:
Expand Down Expand Up @@ -116,7 +116,7 @@ def get_mixture_components(

# Elements of diagonal of precision factor must be positive
# (recall precision factor A such that SIGMA^-1 = A^T A).
diagonal = F.softplus(unconstrained_diagonal) + self._epsilon
diagonal = F.softplus(unconstrained_diagonal)

# Create empty precision factor matrix, and fill with appropriate quantities.
precision_factors = torch.zeros(
Expand All @@ -139,6 +139,10 @@ def get_mixture_components(
precisions = torch.matmul(
torch.transpose(precision_factors, 2, 3), precision_factors
)
# Add epsilon to diagnonal for numerical stability.
precisions[
..., torch.arange(self._features), torch.arange(self._features)
] += self._epsilon

# The sum of the log diagonal of A is used in the likelihood calculation.
sumlogdiag = torch.sum(torch.log(diagonal), dim=-1)
Expand Down

0 comments on commit db7d8e4

Please sign in to comment.