From 2d99131205f9e47f54aa73ab955c1da969b05e8a Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 29 Dec 2024 07:43:59 -0700 Subject: [PATCH] Optional run_qa flag in ``DualRasterizer``. Queue shape fix for queues with obs data --- sup3r/models/abstract.py | 1 + sup3r/models/base.py | 2 +- sup3r/preprocessing/batch_queues/dual.py | 18 +++++++++++++----- sup3r/preprocessing/rasterizers/dual.py | 9 +++++++-- 4 files changed, 22 insertions(+), 8 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 506f85b38..eed051547 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -538,6 +538,7 @@ def update_loss_details(loss_details, new_data, batch_len, prefix=None): if key in loss_details: saved_value = loss_details[key] + saved_value = 0 if np.isnan(saved_value) else saved_value saved_value *= prior_n_obs saved_value += batch_len * new_value saved_value /= new_n_obs diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 7cbe4af98..728576676 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -985,7 +985,7 @@ def train_epoch( disc_th_low = np.min(disc_loss_bounds) disc_th_high = np.max(disc_loss_bounds) - loss_details = {'n_obs': 0, 'train_loss_disc': 0} + loss_details = {'n_obs': 0, 'train_loss_disc': 0, 'train_loss_obs': 0} only_gen = train_gen and not train_disc only_disc = train_disc and not train_gen diff --git a/sup3r/preprocessing/batch_queues/dual.py b/sup3r/preprocessing/batch_queues/dual.py index b7e26bf94..56b2b08d4 100644 --- a/sup3r/preprocessing/batch_queues/dual.py +++ b/sup3r/preprocessing/batch_queues/dual.py @@ -28,11 +28,19 @@ def __init__(self, samplers, **kwargs): @property def queue_shape(self): - """Shape of objects stored in the queue.""" - queue_shapes = [(self.batch_size, *self.lr_shape)] - hr_mems = len(self.BATCH_MEMBERS) - 1 - queue_shapes += [(self.batch_size, *self.hr_shape)] * hr_mems - return queue_shapes + """Shape of objects stored in the queue. Optionally includes shape of + observation data which would be included in an extra content loss + term""" + obs_shape = ( + *self.hr_shape[:-1], + len(self.containers[0].hr_out_features), + ) + queue_shapes = [ + (self.batch_size, *self.lr_shape), + (self.batch_size, *self.hr_shape), + (self.batch_size, *obs_shape), + ] + return queue_shapes[: len(self.BATCH_MEMBERS)] def check_enhancement_factors(self): """Make sure each DualSampler has the same enhancment factors and they diff --git a/sup3r/preprocessing/rasterizers/dual.py b/sup3r/preprocessing/rasterizers/dual.py index 47706aa4e..a70f01b08 100644 --- a/sup3r/preprocessing/rasterizers/dual.py +++ b/sup3r/preprocessing/rasterizers/dual.py @@ -43,6 +43,7 @@ def __init__( ], regrid_workers=1, regrid_lr=True, + run_qa=False, s_enhance=1, t_enhance=1, lr_cache_kwargs=None, @@ -63,6 +64,9 @@ def __init__( Flag to regrid the low-res data to the high-res grid. This will take care of any minor inconsistencies in different projections. Disable this if the grids are known to be the same. + run_qa : bool + Flag to run qa on the regridded low-res data. This will check for + NaNs and fill them if there are not too many. s_enhance : int Spatial enhancement factor t_enhance : int @@ -135,7 +139,8 @@ def __init__( self.update_hr_data() super().__init__(data=(self.lr_data, self.hr_data)) - self.check_regridded_lr_data() + if run_qa: + self.check_regridded_lr_data() if lr_cache_kwargs is not None: Cacher(self.lr_data, lr_cache_kwargs) @@ -205,7 +210,7 @@ def update_lr_data(self): lr_coords_new = { Dimension.LATITUDE: self.lr_lat_lon[..., 0], Dimension.LONGITUDE: self.lr_lat_lon[..., 1], - Dimension.TIME: self.lr_data.indexes['time'][ + Dimension.TIME: self.lr_data.indexes[Dimension.TIME][ : self.lr_required_shape[2] ], }