Skip to content

Commit

Permalink
fix: tensorboard issue with loss obs details
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Jan 19, 2025
1 parent d93a6b1 commit 5afa9ed
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions sup3r/models/with_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,10 +319,10 @@ def get_single_grad(
low_res, hi_res_true, **calc_loss_kwargs
)
loss, loss_details = loss_out
if obs_data is not None:
loss_obs = self.calc_loss_obs(obs_data, hi_res_gen)
loss_obs = self.calc_loss_obs(obs_data, hi_res_gen)
if not tf.reduce_any(tf.math.is_nan(loss_obs)):
loss += loss_obs
loss_details['loss_obs'] = loss_obs
loss_details.update({'loss_obs': loss_obs})
grad = tape.gradient(loss, training_weights)
return grad, loss_details

Expand All @@ -345,8 +345,13 @@ def calc_loss_obs(self, obs_data, hi_res_gen):
loss : tf.Tensor
0D tensor of observation loss
"""
mask = tf.math.is_nan(obs_data)
return MeanAbsoluteError()(
obs_data[~mask],
hi_res_gen[..., : len(self.hr_out_features)][~mask],
)
obs_loss = tf.constant(np.nan)
if obs_data is not None:
mask = tf.math.is_nan(obs_data)
masked_obs = obs_data[~mask]
if len(masked_obs) > 0:
obs_loss = MeanAbsoluteError()(
masked_obs,
hi_res_gen[..., : len(self.hr_out_features)][~mask],
)
return obs_loss

0 comments on commit 5afa9ed

Please sign in to comment.