Skip to content

Commit

Permalink
Merge pull request #186 from NREL/bnb/dev
Browse files Browse the repository at this point in the history
Bnb/dev
  • Loading branch information
bnb32 authored Feb 16, 2024
2 parents e5b3ab6 + 6f7f877 commit d799491
Show file tree
Hide file tree
Showing 17 changed files with 312 additions and 468 deletions.
2 changes: 1 addition & 1 deletion sup3r/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@ def update_adversarial_weights(self, history, adaptive_update_fraction,

if update_frac != 1:
logger.debug(
f'New discriminator weight: {weight_gen_advers:.3f}')
f'New discriminator weight: {weight_gen_advers:.4e}')

return weight_gen_advers

Expand Down
32 changes: 19 additions & 13 deletions sup3r/pipeline/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,19 +727,7 @@ def __init__(self,
https://github.com/tensorflow/tensorflow/issues/51870
"""
self._input_handler_kwargs = input_handler_kwargs or {}
target = self._input_handler_kwargs.get('target', None)
grid_shape = self._input_handler_kwargs.get('shape', None)
raster_file = self._input_handler_kwargs.get('raster_file', None)
raster_index = self._input_handler_kwargs.get('raster_index', None)
temporal_slice = self._input_handler_kwargs.get(
'temporal_slice', slice(None, None, 1))
InputMixIn.__init__(self,
target=target,
shape=grid_shape,
raster_file=raster_file,
raster_index=raster_index,
temporal_slice=temporal_slice)

self.init_mixin()
self.file_paths = file_paths
self.model_kwargs = model_kwargs
self.fwp_chunk_shape = fwp_chunk_shape
Expand Down Expand Up @@ -808,6 +796,23 @@ def __init__(self,

self.preflight()

def init_mixin(self):
"""Initialize InputMixIn class"""
target = self._input_handler_kwargs.get('target', None)
grid_shape = self._input_handler_kwargs.get('shape', None)
raster_file = self._input_handler_kwargs.get('raster_file', None)
raster_index = self._input_handler_kwargs.get('raster_index', None)
temporal_slice = self._input_handler_kwargs.get(
'temporal_slice', slice(None, None, 1))
res_kwargs = self._input_handler_kwargs.get('res_kwargs', None)
InputMixIn.__init__(self,
target=target,
shape=grid_shape,
raster_file=raster_file,
raster_index=raster_index,
temporal_slice=temporal_slice,
res_kwargs=res_kwargs)

def preflight(self):
"""Prelight path name formatting and sanity checks"""

Expand Down Expand Up @@ -1257,6 +1262,7 @@ def meta(self):
meta_data = {
'chunk_meta': self.chunk_specific_meta,
'gan_meta': self.model.meta,
'gan_params': self.model.model_params,
'model_kwargs': self.model_kwargs,
'model_class': self.model_class,
'spatial_enhance': int(self.s_enhance),
Expand Down
15 changes: 13 additions & 2 deletions sup3r/postprocessing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, file_paths):
"""
if not isinstance(file_paths, list):
file_paths = glob.glob(file_paths)
self.file_paths = file_paths
self.flist = sorted(file_paths)
self.data = None
self.file_attrs = {}
Expand Down Expand Up @@ -153,8 +154,7 @@ def collect(
os.remove(out_file)

if not os.path.exists(out_file):
res_kwargs = (res_kwargs
or {"concat_dim": "Time", "combine": "nested"})
res_kwargs = res_kwargs or {}
out = xr.open_mfdataset(collector.flist, **res_kwargs)
features = [feat for feat in out if feat in features
or feat.lower() in features]
Expand All @@ -174,6 +174,17 @@ def collect(

logger.info('Finished file collection.')

def group_spatial_chunks(self):
"""Group same spatial chunks together so each chunk has same spatial
footprint but different times"""
chunks = {}
for file in self.flist:
s_chunk = file.split('_')[0]
dirname = os.path.dirname(file)
s_file = os.path.join(dirname, f's_{s_chunk}.nc')
chunks[s_file] = [*chunks.get(s_file, []), s_file]
return chunks


class CollectorH5(BaseCollector):
"""Sup3r H5 file collection framework"""
Expand Down
22 changes: 9 additions & 13 deletions sup3r/postprocessing/file_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def enforce_limits(features, data):

max = H5_ATTRS[dset_name].get('max', np.inf)
min = H5_ATTRS[dset_name].get('min', -np.inf)
logger.debug(f'Enforcing range of ({max}, {min} for "{fn}")')
logger.debug(f'Enforcing range of ({min}, {max} for "{fn}")')
maxs.append(max)
mins.append(min)

Expand Down Expand Up @@ -576,13 +576,13 @@ def _write_output(cls, data, features, lat_lon, times, out_file,
List of coordinate indices used to label each lat lon pair and to
help with spatial chunk data collection
"""
coords = {'Times': (['Time'], [str(t).encode('utf-8') for t in times]),
'XLAT': (['south_north', 'east_west'], lat_lon[..., 0]),
'XLONG': (['south_north', 'east_west'], lat_lon[..., 1])}
coords = {'Time': [str(t).encode('utf-8') for t in times],
'south_north': lat_lon[:, 0, 0].astype(np.float32),
'west_east': lat_lon[0, :, 1].astype(np.float32)}

data_vars = {}
for i, f in enumerate(features):
data_vars[f] = (['Time', 'south_north', 'east_west'],
data_vars[f] = (['Time', 'south_north', 'west_east'],
np.transpose(data[..., i], (2, 0, 1)))

attrs = {}
Expand Down Expand Up @@ -631,11 +631,9 @@ def get_renamed_features(cls, features):
List of renamed features u/v -> windspeed/winddirection for each
height
"""
heights = []
heights = [Feature.get_height(f) for f in features
if re.match('U_(.*?)m'.lower(), f.lower())]
renamed_features = features.copy()
for f in features:
if re.match('U_(.*?)m'.lower(), f.lower()):
heights.append(Feature.get_height(f))

for height in heights:
u_idx = features.index(f'U_{height}m')
Expand Down Expand Up @@ -666,10 +664,8 @@ def invert_uv_features(cls, data, features, lat_lon, max_workers=None):
will be estimated based on memory limits.
"""

heights = []
for f in features:
if re.match('U_(.*?)m'.lower(), f.lower()):
heights.append(Feature.get_height(f))
heights = [Feature.get_height(f) for f in features if
re.match('U_(.*?)m'.lower(), f.lower())]
if heights:
logger.info('Converting u/v to windspeed/winddirection for h5'
' output')
Expand Down
Loading

0 comments on commit d799491

Please sign in to comment.