
Commit

Merge branch 'bnb/training_with_obs' of github.com:NREL/sup3r into bnb/training_with_obs
bnb32 committed Dec 29, 2024
2 parents a266482 + 2d99131 commit 3ab2175
Showing 4 changed files with 22 additions and 8 deletions.
1 change: 1 addition & 0 deletions sup3r/models/abstract.py
@@ -539,6 +539,7 @@ def update_loss_details(loss_details, new_data, batch_len, prefix=None):
 
             if key in loss_details:
                 saved_value = loss_details[key]
+                saved_value = 0 if np.isnan(saved_value) else saved_value
                 saved_value *= prior_n_obs
                 saved_value += batch_len * new_value
                 saved_value /= new_n_obs
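For context, here is a minimal standalone sketch of the weighted running-mean update this hunk guards (the helper name and the example numbers are illustrative, not part of sup3r): resetting a NaN saved value to 0 before folding it back in keeps one NaN loss value from propagating through every later average.

import numpy as np

def update_running_loss(saved_value, prior_n_obs, new_value, batch_len):
    """Fold a batch-mean loss into a running mean over all observations seen."""
    saved_value = 0 if np.isnan(saved_value) else saved_value
    new_n_obs = prior_n_obs + batch_len
    return (saved_value * prior_n_obs + batch_len * new_value) / new_n_obs

# A NaN left over from a batch with no observation data no longer poisons
# the running mean; the next valid batch is averaged against a clean zero.
print(update_running_loss(np.nan, prior_n_obs=16, new_value=0.25, batch_len=16))  # 0.125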
2 changes: 1 addition & 1 deletion sup3r/models/base.py
@@ -977,7 +977,7 @@ def train_epoch(
 
         disc_th_low = np.min(disc_loss_bounds)
         disc_th_high = np.max(disc_loss_bounds)
-        loss_details = {'n_obs': 0, 'train_loss_disc': 0}
+        loss_details = {'n_obs': 0, 'train_loss_disc': 0, 'train_loss_obs': 0}
 
         only_gen = train_gen and not train_disc
         only_disc = train_disc and not train_gen
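A short, hedged illustration of why the accumulator is seeded here (the per-batch values and loop are made up; the real train_epoch pulls them from training batches): with 'train_loss_obs' present from the first batch, it can be accumulated with the same weighted running mean used for 'train_loss_disc'.

loss_details = {'n_obs': 0, 'train_loss_disc': 0, 'train_loss_obs': 0}

batch_len = 16
for batch_obs_loss in [0.30, 0.22, 0.18]:            # made-up per-batch losses
    prior_n_obs = loss_details['n_obs']
    new_n_obs = prior_n_obs + batch_len
    loss_details['train_loss_obs'] = (
        loss_details['train_loss_obs'] * prior_n_obs + batch_len * batch_obs_loss
    ) / new_n_obs
    loss_details['n_obs'] = new_n_obs

print(loss_details['train_loss_obs'])                # ~0.233, the mean of the three batches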
18 changes: 13 additions & 5 deletions sup3r/preprocessing/batch_queues/dual.py
@@ -28,11 +28,19 @@ def __init__(self, samplers, **kwargs):
 
     @property
     def queue_shape(self):
-        """Shape of objects stored in the queue."""
-        queue_shapes = [(self.batch_size, *self.lr_shape)]
-        hr_mems = len(self.BATCH_MEMBERS) - 1
-        queue_shapes += [(self.batch_size, *self.hr_shape)] * hr_mems
-        return queue_shapes
+        """Shape of objects stored in the queue. Optionally includes shape of
+        observation data which would be included in an extra content loss
+        term"""
+        obs_shape = (
+            *self.hr_shape[:-1],
+            len(self.containers[0].hr_out_features),
+        )
+        queue_shapes = [
+            (self.batch_size, *self.lr_shape),
+            (self.batch_size, *self.hr_shape),
+            (self.batch_size, *obs_shape),
+        ]
+        return queue_shapes[: len(self.BATCH_MEMBERS)]
 
     def check_enhancement_factors(self):
         """Make sure each DualSampler has the same enhancment factors and they
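As a rough illustration of the truncation at the end of the new queue_shape (the shapes and member names below are placeholders, not the sup3r implementation): all three candidate shapes are built, and slicing by the number of batch members drops the observation shape whenever no observation data is being queued.

batch_size = 4
lr_shape = (10, 10, 6, 2)                  # low-res sample shape
hr_shape = (20, 20, 12, 2)                 # high-res sample shape
obs_shape = (*hr_shape[:-1], 1)            # assumes one high-res output feature

queue_shapes = [
    (batch_size, *lr_shape),
    (batch_size, *hr_shape),
    (batch_size, *obs_shape),
]

BATCH_MEMBERS = ('low_res', 'high_res')    # hypothetical two-member batch
print(queue_shapes[: len(BATCH_MEMBERS)])  # observation shape is dropped

With a third batch member the same slice would keep the observation shape, so the queue layout adapts without any branching.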
9 changes: 7 additions & 2 deletions sup3r/preprocessing/rasterizers/dual.py
@@ -43,6 +43,7 @@ def __init__(
         ],
         regrid_workers=1,
         regrid_lr=True,
+        run_qa=False,
         s_enhance=1,
         t_enhance=1,
         lr_cache_kwargs=None,
@@ -63,6 +64,9 @@
             Flag to regrid the low-res data to the high-res grid. This will
             take care of any minor inconsistencies in different projections.
             Disable this if the grids are known to be the same.
+        run_qa : bool
+            Flag to run qa on the regridded low-res data. This will check for
+            NaNs and fill them if there are not too many.
         s_enhance : int
             Spatial enhancement factor
         t_enhance : int
@@ -135,7 +139,8 @@ def __init__(
         self.update_hr_data()
         super().__init__(data=(self.lr_data, self.hr_data))
 
-        self.check_regridded_lr_data()
+        if run_qa:
+            self.check_regridded_lr_data()
 
         if lr_cache_kwargs is not None:
             Cacher(self.lr_data, lr_cache_kwargs)
@@ -205,7 +210,7 @@ def update_lr_data(self):
         lr_coords_new = {
             Dimension.LATITUDE: self.lr_lat_lon[..., 0],
             Dimension.LONGITUDE: self.lr_lat_lon[..., 1],
-            Dimension.TIME: self.lr_data.indexes['time'][
+            Dimension.TIME: self.lr_data.indexes[Dimension.TIME][
                 : self.lr_required_shape[2]
             ],
         }
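A rough sketch of the kind of QA that run_qa now gates (this is not sup3r's check_regridded_lr_data; the threshold, fill strategy, and the assumed 'time' dimension are guesses for illustration): scan the regridded low-res variables for NaNs, fill small gaps, and fail loudly when too much data is missing.

import numpy as np
import xarray as xr

def qa_regridded_lr(lr_data: xr.Dataset, max_nan_frac: float = 0.01) -> xr.Dataset:
    """Check regridded low-res variables for NaNs and fill small gaps."""
    for name, var in lr_data.data_vars.items():
        nan_frac = float(np.isnan(var).mean())
        if nan_frac > max_nan_frac:
            raise ValueError(f'{name} has {nan_frac:.1%} NaNs after regridding')
        if nan_frac > 0:
            # fill isolated gaps along the assumed 'time' dimension
            lr_data[name] = var.interpolate_na(dim='time', method='nearest')
    return lr_data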
