Skip to content

Commit

Permalink
note on -derivative matching np.gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Feb 16, 2024
1 parent 1dcbed7 commit 6f7f877
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion sup3r/postprocessing/file_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def enforce_limits(features, data):

max = H5_ATTRS[dset_name].get('max', np.inf)
min = H5_ATTRS[dset_name].get('min', -np.inf)
logger.debug(f'Enforcing range of ({max}, {min} for "{fn}")')
logger.debug(f'Enforcing range of ({min}, {max} for "{fn}")')
maxs.append(max)
mins.append(min)

Expand Down
9 changes: 6 additions & 3 deletions sup3r/utilities/loss_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,10 @@ class MaterialDerivativeLoss(tf.keras.losses.Loss):
LOSS_METRIC = MeanAbsoluteError()

def _derivative(self, x, axis=1):
"""Custom derivative function for compatibility with tensorflow
"""Custom derivative function for compatibility with tensorflow.
NOTE: Matches np.gradient by using the central difference
approximation.
Parameters
----------
Expand Down Expand Up @@ -204,8 +207,8 @@ def __call__(self, x1, x2):
"""
hub_heights = x1.shape[-1] // 2

msg = (f'The {self.__class__} is meant to be used on spatiotemporal '
'data only. Received tensor(s) that are not 5D')
msg = (f'The {self.__class__.__name__} is meant to be used on '
'spatiotemporal data only. Received tensor(s) that are not 5D')
assert len(x1.shape) == 5 and len(x2.shape) == 5, msg

x1_div = tf.stack(
Expand Down

0 comments on commit 6f7f877

Please sign in to comment.