From 5afa9edc3117bc50936e8d84437cf3567ac54628 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 8 Jan 2025 11:10:40 -0700 Subject: [PATCH] fix: tensorboard issue with loss obs details --- sup3r/models/with_obs.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/sup3r/models/with_obs.py b/sup3r/models/with_obs.py index e8ce3e417..e15032e18 100644 --- a/sup3r/models/with_obs.py +++ b/sup3r/models/with_obs.py @@ -319,10 +319,10 @@ def get_single_grad( low_res, hi_res_true, **calc_loss_kwargs ) loss, loss_details = loss_out - if obs_data is not None: - loss_obs = self.calc_loss_obs(obs_data, hi_res_gen) + loss_obs = self.calc_loss_obs(obs_data, hi_res_gen) + if not tf.reduce_any(tf.math.is_nan(loss_obs)): loss += loss_obs - loss_details['loss_obs'] = loss_obs + loss_details.update({'loss_obs': loss_obs}) grad = tape.gradient(loss, training_weights) return grad, loss_details @@ -345,8 +345,13 @@ def calc_loss_obs(self, obs_data, hi_res_gen): loss : tf.Tensor 0D tensor of observation loss """ - mask = tf.math.is_nan(obs_data) - return MeanAbsoluteError()( - obs_data[~mask], - hi_res_gen[..., : len(self.hr_out_features)][~mask], - ) + obs_loss = tf.constant(np.nan) + if obs_data is not None: + mask = tf.math.is_nan(obs_data) + masked_obs = obs_data[~mask] + if len(masked_obs) > 0: + obs_loss = MeanAbsoluteError()( + masked_obs, + hi_res_gen[..., : len(self.hr_out_features)][~mask], + ) + return obs_loss