From 02ac2c96c0d7e2b60ef774ae74053a996ea835d9 Mon Sep 17 00:00:00 2001 From: turquoisedragon2926 Date: Thu, 12 Sep 2024 15:43:27 -0400 Subject: [PATCH 1/2] Reshape Targets and Mean for RHS in Cholesky solver --- gpytorch/models/exact_prediction_strategies.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 0a8092e15..abd868ba7 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -193,7 +193,9 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_ prefix = string.ascii_lowercase[: max(fant_train_covar.dim() - self.mean_cache.dim() - 1, 0)] ftcm = torch.einsum(prefix + "...yz,...z->" + prefix + "...y", [fant_train_covar, self.mean_cache]) - small_system_rhs = targets - fant_mean - ftcm + targets_ = torch.reshape(targets, ftcm.shape) + fant_mean_ = torch.reshape(fant_mean, ftcm.shape) + small_system_rhs = targets_ - fant_mean_ - ftcm small_system_rhs = small_system_rhs.unsqueeze(-1) # Schur complement of a spd matrix is guaranteed to be positive definite schur_cholesky = psd_safe_cholesky(schur_complement) From f5216c358b74afbe4f97ccab1339244135edbe41 Mon Sep 17 00:00:00 2001 From: turquoisedragon2926 Date: Thu, 12 Sep 2024 16:56:21 -0400 Subject: [PATCH 2/2] Handle Batched Computation --- gpytorch/models/exact_prediction_strategies.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index abd868ba7..976f0a5ea 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -193,8 +193,8 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_ prefix = string.ascii_lowercase[: max(fant_train_covar.dim() - self.mean_cache.dim() - 1, 0)] ftcm = torch.einsum(prefix + "...yz,...z->" + prefix + "...y", [fant_train_covar, self.mean_cache]) - targets_ = torch.reshape(targets, ftcm.shape) - fant_mean_ = torch.reshape(fant_mean, ftcm.shape) + targets_ = torch.reshape(targets, (-1, ftcm.shape[-1])) + fant_mean_ = torch.reshape(fant_mean, (-1, ftcm.shape[-1])) small_system_rhs = targets_ - fant_mean_ - ftcm small_system_rhs = small_system_rhs.unsqueeze(-1) # Schur complement of a spd matrix is guaranteed to be positive definite