diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 0a8092e15..976f0a5ea 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -193,7 +193,9 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_ prefix = string.ascii_lowercase[: max(fant_train_covar.dim() - self.mean_cache.dim() - 1, 0)] ftcm = torch.einsum(prefix + "...yz,...z->" + prefix + "...y", [fant_train_covar, self.mean_cache]) - small_system_rhs = targets - fant_mean - ftcm + targets_ = torch.reshape(targets, (-1, ftcm.shape[-1])) + fant_mean_ = torch.reshape(fant_mean, (-1, ftcm.shape[-1])) + small_system_rhs = targets_ - fant_mean_ - ftcm small_system_rhs = small_system_rhs.unsqueeze(-1) # Schur complement of a spd matrix is guaranteed to be positive definite schur_cholesky = psd_safe_cholesky(schur_complement)