diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index f5ce52862..b9f2447e1 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -255,11 +255,8 @@ def _combine_fwp_input(self, low_res, exogenous_data=None): and exogenous_data is not None): exogenous_data = ExoData(exogenous_data) - training_features = ([] if self.training_features is None - else self.training_features) - fnum_diff = len(training_features) - low_res.shape[-1] - exo_feats = ([] if fnum_diff <= 0 - else self.training_features[-fnum_diff:]) + fnum_diff = len(self.lr_features) - low_res.shape[-1] + exo_feats = [] if fnum_diff <= 0 else self.lr_features[-fnum_diff:] msg = ('Provided exogenous_data is missing some required features ' f'({exo_feats})') assert all(feature in exogenous_data for feature in exo_feats), msg @@ -269,6 +266,7 @@ def _combine_fwp_input(self, low_res, exogenous_data=None): feature, 'input') if exo_input is not None: low_res = np.concatenate((low_res, exo_input), axis=-1) + return low_res def _combine_fwp_output(self, hi_res, exogenous_data=None): @@ -302,11 +300,9 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None): and exogenous_data is not None): exogenous_data = ExoData(exogenous_data) - output_features = ([] if self.output_features is None - else self.output_features) - fnum_diff = len(output_features) - hi_res.shape[-1] + fnum_diff = len(self.hr_out_features) - hi_res.shape[-1] exo_feats = ([] if fnum_diff <= 0 - else self.output_features[-fnum_diff:]) + else self.hr_out_features[-fnum_diff:]) msg = ('Provided exogenous_data is missing some required features ' f'({exo_feats})') assert all(feature in exogenous_data for feature in exo_feats), msg @@ -318,6 +314,7 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None): hi_res = np.concatenate((hi_res, exo_output), axis=-1) return hi_res + @tf.function def _combine_loss_input(self, high_res_true, high_res_gen): """Combine exogenous feature data from high_res_true with high_res_gen for loss calculation @@ -336,62 +333,48 @@ def _combine_loss_input(self, high_res_true, high_res_gen): Same as input with exogenous data combined with high_res input """ if high_res_true.shape[-1] > high_res_gen.shape[-1]: - for feature in self.exogenous_features: - f_idx = self.hr_features.index(feature) + for feature in self.hr_exo_features: + f_idx = self.hr_exo_features.index(feature) + f_idx += len(self.hr_out_features) exo_data = high_res_true[..., f_idx: f_idx + 1] high_res_gen = tf.concat((high_res_gen, exo_data), axis=-1) return high_res_gen - @property - def exogenous_features(self): - """Get list of exogenous filter names the model uses. If the model has - N concat or add layers this list will be the last N features in the - training features list. The ordering is assumed to be the same as the - order of concat or add layers. If training features is [..., topo, - sza], and the model has 2 concat or add layers, exo features will be - [topo, sza]. Topo will then be used in the first concat layer and sza - will be used in the second""" - # pylint: disable=E1101 - features = [] - if hasattr(self, '_gen'): - for layer in self._gen.layers: - if isinstance(layer, (Sup3rAdder, Sup3rConcat)): - features.append(layer.name) - return features - @property @abstractmethod def meta(self): """Get meta data dictionary that defines how the model was created""" @property - def training_features(self): - """Get the list of input feature names that the generative model was - trained on.""" - return self.meta.get('training_features', None) + def lr_features(self): + """Get a list of low-resolution features input to the generative model. + This includes low-resolution features that might be supplied + exogenously at inference time but that were in the low-res batches + during training""" + return self.meta.get('lr_features', []) @property - def train_only_features(self): - """Get the list of feature names used only for training (expected as - input but not included in output).""" - return self.meta.get('train_only_features', None) + def hr_out_features(self): + """Get the list of high-resolution output feature names that the + generative model outputs.""" + return self.meta.get('hr_out_features', []) @property - def hr_features(self): - """Get the list of features stored in batch.high_res. This is the same - as training_features but without train_only_features. This is used to - select the correct high res exogenous data.""" - hr_features = self.training_features - if self.train_only_features is not None: - hr_features = [f for f in self.training_features - if f not in self.train_only_features] - return hr_features - - @property - def output_features(self): - """Get the list of output feature names that the generative model - outputs and that the discriminator predicts on.""" - return self.meta.get('output_features', None) + def hr_exo_features(self): + """Get list of high-resolution exogenous filter names the model uses. + If the model has N concat or add layers this list will be the last N + features in the training features list. The ordering is assumed to be + the same as the order of concat or add layers. If training features is + [..., topo, sza], and the model has 2 concat or add layers, exo + features will be [topo, sza]. Topo will then be used in the first + concat layer and sza will be used in the second""" + # pylint: disable=E1101 + features = [] + if hasattr(self, '_gen'): + for layer in self._gen.layers: + if isinstance(layer, (Sup3rAdder, Sup3rConcat)): + features.append(layer.name) + return features @property def smoothing(self): @@ -403,7 +386,7 @@ def smoothing(self): def smoothed_features(self): """Get the list of smoothed input feature names that the generative model was trained on.""" - return self.meta.get('smoothed_features', None) + return self.meta.get('smoothed_features', []) @property def model_params(self): @@ -427,38 +410,6 @@ def version_record(self): """ return VERSION_RECORD - def _check_exo_features(self, **kwargs): - """Make sure exogenous features have the correct ordering and are - included in training_features - - Parameters - ---------- - kwargs : dict - Keyword arguments including 'training_features', 'output_features', - 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' - - Returns - ------- - kwargs : dict - Same as input but with exogenous_features removed from output - features - """ - if 'output_features' not in kwargs: - return kwargs - - output_features = kwargs['output_features'] - msg = (f'Last {len(self.exogenous_features)} output features from the ' - f'data handler must be {self.exogenous_features} ' - 'to train the Exo model, but received output features: {}'. - format(output_features)) - exo_features = ([] if len(self.exogenous_features) == 0 - else output_features[-len(self.exogenous_features):]) - assert exo_features == self.exogenous_features, msg - for f in self.exogenous_features: - output_features.remove(f) - kwargs['output_features'] = output_features - return kwargs - def set_model_params(self, **kwargs): """Set parameters used for training the model @@ -466,16 +417,21 @@ def set_model_params(self, **kwargs): ---------- kwargs : dict Keyword arguments including 'input_resolution', - 'training_features', 'output_features', 'smoothed_features', - 's_enhance', 't_enhance', 'smoothing' + 'lr_features', 'hr_exo_features', 'hr_out_features', + 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' """ - kwargs = self._check_exo_features(**kwargs) - keys = ('input_resolution', 'training_features', 'output_features', - 'train_only_features', 'smoothed_features', 's_enhance', + keys = ('input_resolution', 'lr_features', 'hr_exo_features', + 'hr_out_features', 'smoothed_features', 's_enhance', 't_enhance', 'smoothing') keys = [k for k in keys if k in kwargs] + hr_exo_feat = kwargs.get('hr_exo_features', []) + msg = (f'Expected high-res exo features {self.hr_exo_features} ' + f'based on model architecture but received "hr_exo_features" ' + f'from data handler: {hr_exo_feat}') + assert list(self.hr_exo_features) == list(hr_exo_feat), msg + for var in keys: val = self.meta.get(var, None) if val is None: @@ -599,43 +555,20 @@ def stdevs(self): """ return self._stdevs - @property - def output_stdevs(self): - """Get the data normalization standard deviation values for only the - output features - - Returns - ------- - np.ndarray - """ - indices = [ - self.training_features.index(f) for f in self.output_features - ] - return self._stdevs[indices] - - @property - def output_means(self): - """Get the data normalization mean values for only the output features - - Returns - ------- - np.ndarray - """ - indices = [ - self.training_features.index(f) for f in self.output_features - ] - return self._means[indices] - def set_norm_stats(self, new_means, new_stdevs): """Set the normalization statistics associated with a data batch handler to model attributes. Parameters ---------- - new_means : list | tuple | np.ndarray - 1D iterable of mean values with same length as number of features. - new_stdevs : list | tuple | np.ndarray - 1D iterable of stdev values with same length as number of features. + new_means : dict | None + Set of mean values for data normalization keyed by feature name. + Can be used to maintain a consistent normalization scheme between + transfer learning domains. + new_stdevs : dict | None + Set of stdev values for data normalization keyed by feature name. + Can be used to maintain a consistent normalization scheme between + transfer learning domains. """ if self._means is not None: @@ -648,10 +581,22 @@ def set_norm_stats(self, new_means, new_stdevs): self._means = new_means self._stdevs = new_stdevs - if not isinstance(self._means, np.ndarray): - self._means = np.array(self._means) - if not isinstance(self._stdevs, np.ndarray): - self._stdevs = np.array(self._stdevs) + if (not isinstance(self._means, dict) + or not isinstance(self._stdevs, dict)): + msg = ('Means and stdevs need to be dictionaries with keys as ' + 'feature names but received means of type ' + f'{type(self._means)} and ' + f'stdevs of type {type(self._stdevs)}') + logger.error(msg) + raise TypeError(msg) + + missing = [f for f in self.lr_features if f not in self._means] + missing += [f for f in self.hr_exo_features + if f not in self._means] + missing += [f for f in self.hr_out_features if f not in self._means] + if any(missing): + msg = (f'Need means for features "{missing}" but did not find ' + f'in new means array: {self._means}') logger.info('Set data normalization mean values: {}'.format( self._means)) @@ -681,14 +626,21 @@ def norm_input(self, low_res): if isinstance(low_res, tf.Tensor): low_res = low_res.numpy() - if any(self._stdevs == 0): - stdevs = np.where(self._stdevs == 0, 1, self._stdevs) + missing = [fn for fn in self.lr_features if fn not in self._means] + if any(missing): + msg = (f'Could not find low-res input features {missing} in ' + f'means/stdevs: {self._means}/{self._stdevs}') + logger.error(msg) + raise KeyError(msg) + + means = np.array([self._means[fn] for fn in self.lr_features]) + stdevs = np.array([self._stdevs[fn] for fn in self.lr_features]) + if any(stdevs == 0): + stdevs = np.where(stdevs == 0, 1, stdevs) msg = 'Some standard deviations are zero.' logger.warning(msg) warn(msg) - else: - stdevs = self._stdevs - low_res = (low_res.copy() - self._means) / stdevs + low_res = (low_res.copy() - means) / stdevs return low_res @@ -709,7 +661,20 @@ def un_norm_output(self, output): if isinstance(output, tf.Tensor): output = output.numpy() - output = (output * self.output_stdevs) + self.output_means + missing = [fn for fn in self.hr_out_features + if fn not in self._means] + if any(missing): + msg = (f'Could not find high-res output features {missing} in ' + f'means/stdevs: {self._means}/{self._stdevs}') + logger.error(msg) + raise KeyError(msg) + + means = [self._means[fn] for fn in self.hr_out_features] + stdevs = [self._stdevs[fn] for fn in self.hr_out_features] + means = np.array(means) + stdevs = np.array(stdevs) + + output = (output * stdevs) + means return output @@ -845,8 +810,9 @@ def get_high_res_exo_input(self, high_res): e.g. {'topography': tf.Tensor(...)} """ exo_data = {} - for feature in self.exogenous_features: - f_idx = self.hr_features.index(feature) + for feature in self.hr_exo_features: + f_idx = self.hr_exo_features.index(feature) + f_idx += len(self.hr_out_features) exo_fdata = high_res[..., f_idx: f_idx + 1] exo_data[feature] = exo_fdata return exo_data @@ -1245,9 +1211,8 @@ def _reshape_norm_exo(self, hi_res, hi_res_exo, exo_name, norm_in=True): return hi_res_exo if norm_in and self._means is not None: - idf = self.training_features.index(exo_name) - hi_res_exo = ((hi_res_exo.copy() - self._means[idf]) - / self._stdevs[idf]) + hi_res_exo = ((hi_res_exo.copy() - self._means[exo_name]) + / self._stdevs[exo_name]) if len(hi_res_exo.shape) == 3: hi_res_exo = np.expand_dims(hi_res_exo, axis=0) @@ -1385,7 +1350,7 @@ def _tf_generate(self, low_res, hi_res_exo=None): return hi_res - @tf.function() + @tf.function def get_single_grad(self, low_res, hi_res_true, diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 1c8c05bf8..80a68fed7 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -72,20 +72,21 @@ def __init__(self, history meta : dict | None Model meta data that describes how the model was created. - means : np.ndarray | list | None - Set of mean values for data normalization with the same length as - number of features. Can be used to maintain a consistent - normalization scheme between transfer learning domains. - stdevs : np.ndarray | list | None - Set of stdev values for data normalization with the same length as - number of features. Can be used to maintain a consistent - normalization scheme between transfer learning domains. + means : dict | None + Set of mean values for data normalization keyed by feature name. + Can be used to maintain a consistent normalization scheme between + transfer learning domains. + stdevs : dict | None + Set of stdev values for data normalization keyed by feature name. + Can be used to maintain a consistent normalization scheme between + transfer learning domains. default_device : str | None Option for default device placement of model weights. If None and a single GPU exists, that GPU will be the default device. If None and multiple GPUs exist, the CPU will be the default device (this was tested as most efficient given the custom multi-gpu strategy - developed in self.run_gradient_descent()) + developed in self.run_gradient_descent()). Examples: "/gpu:0" or + "/cpu:0" name : str | None Optional name for the GAN. """ @@ -116,10 +117,8 @@ def __init__(self, self._gen = self.load_network(gen_layers, 'generator') self._disc = self.load_network(disc_layers, 'discriminator') - self._means = (means if means is None else np.array(means).astype( - np.float32)) - self._stdevs = (stdevs if stdevs is None else np.array(stdevs).astype( - np.float32)) + self._means = means + self._stdevs = stdevs def save(self, out_dir): """Save the GAN with its sub-networks to a directory. @@ -224,8 +223,12 @@ def discriminate(self, hi_res, norm_in=False): hi_res = hi_res.numpy() if norm_in and self._means is not None: + mean_arr = [self._means[fn] for fn in self.hr_out_features] + std_arr = [self._stdevs[fn] for fn in self.hr_out_features] + mean_arr = np.array(mean_arr, dtype=np.float32) + std_arr = np.array(std_arr, dtype=np.float32) hi_res = hi_res if isinstance(hi_res, tf.Tensor) else hi_res.copy() - hi_res = (hi_res - self._means) / self._stdevs + hi_res = (hi_res - mean_arr) / std_arr out = self.discriminator.layers[0](hi_res) for i, layer in enumerate(self.discriminator.layers[1:]): @@ -324,10 +327,6 @@ def model_params(self): ------- dict """ - means = (self._means - if self._means is None else [float(m) for m in self._means]) - stdevs = (self._stdevs if self._stdevs is None else - [float(s) for s in self._stdevs]) config_optm_g = self.get_optimizer_config(self.optimizer) config_optm_d = self.get_optimizer_config(self.optimizer_disc) @@ -338,8 +337,8 @@ def model_params(self): 'version_record': self.version_record, 'optimizer': config_optm_g, 'optimizer_disc': config_optm_d, - 'means': means, - 'stdevs': stdevs, + 'means': self._means, + 'stdevs': self._stdevs, 'meta': self.meta, } @@ -383,7 +382,7 @@ def init_weights(self, lr_shape, hr_shape, device=None): with tf.device(device): hr_exo_data = {} - for feature in self.exogenous_features: + for feature in self.hr_exo_features: hr_exo_data[feature] = hr_exo _ = self._tf_generate(low_res, hr_exo_data) _ = self._tf_discriminate(hi_res) @@ -514,6 +513,7 @@ def calc_loss_disc(disc_out_true, disc_out_gen): return loss_disc + @tf.function def calc_loss(self, hi_res_true, hi_res_gen, @@ -645,7 +645,8 @@ def train_epoch(self, present, each batch from the batch_handler will be divided up between the GPUs and the resulting gradient from each GPU will constitute a single gradient descent step with the nominal learning - rate that the model was initialized with. + rate that the model was initialized with. If true and multiple gpus + are found, default_device device should be set to /gpu:0 Returns ------- @@ -845,7 +846,8 @@ def train(self, present, each batch from the batch_handler will be divided up between the GPUs and the resulting gradient from each GPU will constitute a single gradient descent step with the nominal learning - rate that the model was initialized with. + rate that the model was initialized with. If true and multiple gpus + are found, default_device device should be set to /gpu:0 """ self.set_norm_stats(batch_handler.means, batch_handler.stds) @@ -854,9 +856,9 @@ def train(self, s_enhance=batch_handler.s_enhance, t_enhance=batch_handler.t_enhance, smoothing=batch_handler.smoothing, - training_features=batch_handler.training_features, - train_only_features=batch_handler.train_only_features, - output_features=batch_handler.output_features, + lr_features=batch_handler.lr_features, + hr_exo_features=batch_handler.hr_exo_features, + hr_out_features=batch_handler.hr_out_features, smoothed_features=batch_handler.smoothed_features) epochs = list(range(n_epoch)) diff --git a/sup3r/models/conditional_moments.py b/sup3r/models/conditional_moments.py index 74d411005..ded29e251 100644 --- a/sup3r/models/conditional_moments.py +++ b/sup3r/models/conditional_moments.py @@ -45,14 +45,14 @@ def __init__(self, gen_layers, history meta : dict | None Model meta data that describes how the model was created. - means : np.ndarray | list | None - Set of mean values for data normalization with the same length as - number of features. Can be used to maintain a consistent - normalization scheme between transfer learning domains. - stdevs : np.ndarray | list | None - Set of stdev values for data normalization with the same length as - number of features. Can be used to maintain a consistent - normalization scheme between transfer learning domains. + means : dict | None + Set of mean values for data normalization keyed by feature name. + Can be used to maintain a consistent normalization scheme between + transfer learning domains. + stdevs : dict | None + Set of stdev values for data normalization keyed by feature name. + Can be used to maintain a consistent normalization scheme between + transfer learning domains. default_device : str | None Option for default device placement of model weights. If None and a single GPU exists, that GPU will be the default device. If None and @@ -84,10 +84,8 @@ def __init__(self, gen_layers, self._gen = self.load_network(gen_layers, 'generator') - self._means = (means if means is None - else np.array(means).astype(np.float32)) - self._stdevs = (stdevs if stdevs is None - else np.array(stdevs).astype(np.float32)) + self._means = means + self._stdevs = stdevs def save(self, out_dir): """Save the model with its sub-networks to a directory. @@ -174,10 +172,6 @@ def model_params(self): ------- dict """ - means = (self._means if self._means is None - else [float(m) for m in self._means]) - stdevs = (self._stdevs if self._stdevs is None - else [float(s) for s in self._stdevs]) config_optm_g = self.get_optimizer_config(self.optimizer) @@ -189,8 +183,8 @@ def model_params(self): 'num_par': num_par, 'version_record': self.version_record, 'optimizer': config_optm_g, - 'means': means, - 'stdevs': stdevs, + 'means': self._means, + 'stdevs': self._stdevs, 'meta': self.meta, } @@ -395,9 +389,9 @@ def train(self, batch_handler, s_enhance=batch_handler.s_enhance, t_enhance=batch_handler.t_enhance, smoothing=batch_handler.smoothing, - train_only_features=batch_handler.train_only_features, - training_features=batch_handler.training_features, - output_features=batch_handler.output_features, + lr_features=batch_handler.lr_features, + hr_exo_features=batch_handler.hr_exo_features, + hr_out_features=batch_handler.hr_out_features, smoothed_features=batch_handler.smoothed_features) epochs = list(range(n_epoch)) diff --git a/sup3r/models/linear.py b/sup3r/models/linear.py index 378b068c3..79e825f4b 100644 --- a/sup3r/models/linear.py +++ b/sup3r/models/linear.py @@ -17,11 +17,11 @@ class LinearInterp(AbstractInterface): """Simple model to do linear interpolation on the spatial and temporal axes """ - def __init__(self, features, s_enhance, t_enhance, t_centered=False): + def __init__(self, lr_features, s_enhance, t_enhance, t_centered=False): """ Parameters ---------- - features : list + lr_features : list List of feature names that this model will operate on for both input and output. This must match the feature axis ordering in the array input to generate(). @@ -35,7 +35,7 @@ def __init__(self, features, s_enhance, t_enhance, t_centered=False): time-centered (e.g. interp 01:00 02:00 to 00:45 01:15 01:45 02:15) """ - self._features = features + self._lr_features = lr_features self._s_enhance = s_enhance self._t_enhance = t_enhance self._t_centered = t_centered @@ -78,27 +78,31 @@ class init args. @property def meta(self): """Get meta data dictionary that defines the model params""" - return {'features': self._features, + return {'lr_features': self._lr_features, 's_enhance': self._s_enhance, 't_enhance': self._t_enhance, 't_centered': self._t_centered, - 'training_features': self.training_features, - 'output_features': self.output_features, + 'hr_out_features': self.hr_out_features, 'class': self.__class__.__name__, } @property - def training_features(self): + def lr_features(self): """Get the list of input feature names that the generative model was trained on. """ - return self._features + return self._lr_features @property - def output_features(self): + def hr_out_features(self): """Get the list of output feature names that the generative model outputs""" - return self._features + return self._lr_features + + @property + def hr_exo_features(self): + """Returns an empty list for LinearInterp model""" + return [] def save(self, out_dir): """ @@ -142,7 +146,7 @@ def generate(self, low_res, norm_in=False, un_norm_out=False, int(low_res.shape[1] * self._s_enhance), int(low_res.shape[2] * self._s_enhance), int(low_res.shape[3] * self._t_enhance), - len(self.output_features)) + len(self.hr_out_features)) logger.debug('LinearInterp model with s_enhance of {} ' 'and t_enhance of {} ' 'downscaling low-res shape {} to high-res shape {}' diff --git a/sup3r/models/multi_step.py b/sup3r/models/multi_step.py index 01cb40cfe..a500c20db 100644 --- a/sup3r/models/multi_step.py +++ b/sup3r/models/multi_step.py @@ -236,16 +236,27 @@ def meta(self): return tuple(model.meta for model in self.models) @property - def training_features(self): - """Get the list of input feature names that the first generative model - in this MultiStepGan requires as input.""" - return self.models[0].meta.get('training_features', None) + def lr_features(self): + """Get a list of low-resolution features input to the generative model. + This includes low-resolution features that might be supplied + exogenously at inference time but that were in the low-res batches + during training""" + return self.models[0].lr_features @property - def output_features(self): - """Get the list of output feature names that the last generative model - in this MultiStepGan outputs.""" - return self.models[-1].meta.get('output_features', None) + def hr_out_features(self): + """Get the list of high-resolution output feature names that the + generative model outputs.""" + return self.models[-1].hr_out_features + + @property + def hr_exo_features(self): + """Get list of high-resolution exogenous filter names the model uses. + For the multi-step model, each entry in this list corresponds to one of + the single-step models and is itself a list of hr_exo_features for that + model. + """ + return [model.hr_exo_features for model in self.models] @property def model_params(self): @@ -440,8 +451,8 @@ def preflight(self): '{} (solar) and {} (wind)'.format(s_enh, w_enh)) assert np.product(s_enh) == np.product(w_enh), msg - s_t_feat = self.spatial_solar_models.training_features - s_o_feat = self.spatial_solar_models.output_features + s_t_feat = self.spatial_solar_models.lr_features + s_o_feat = self.spatial_solar_models.hr_out_features msg = ('Solar spatial enhancement models need to take ' '"clearsky_ratio" as the only input and output feature but ' 'received models that need {} and output {}' @@ -449,14 +460,14 @@ def preflight(self): assert s_t_feat == ['clearsky_ratio'], msg assert s_o_feat == ['clearsky_ratio'], msg - temp_solar_feats = self.temporal_solar_models.training_features + temp_solar_feats = self.temporal_solar_models.lr_features msg = ('Input feature 0 for the temporal_solar_models should be ' '"clearsky_ratio" but received: {}' .format(temp_solar_feats)) assert temp_solar_feats[0] == 'clearsky_ratio', msg - spatial_out_features = (self.spatial_wind_models.output_features - + self.spatial_solar_models.output_features) + spatial_out_features = (self.spatial_wind_models.hr_out_features + + self.spatial_solar_models.hr_out_features) missing = [fn for fn in temp_solar_feats if fn not in spatial_out_features] msg = ('Solar temporal model needs features {} that were not ' @@ -506,26 +517,27 @@ def meta(self): + self.temporal_solar_models.meta) @property - def training_features(self): - """Get the list of input feature names that the first spatial - generative models in this SolarMultiStepGan requires as input. - This includes the solar + wind training features.""" - return (self.spatial_solar_models.training_features - + self.spatial_wind_models.training_features) + def lr_features(self): + """Get a list of low-resolution features input to the generative model. + This includes low-resolution features that might be supplied + exogenously at inference time but that were in the low-res batches + during training""" + return (self.spatial_solar_models.lr_features + + self.spatial_wind_models.lr_features) @property - def output_features(self): + def hr_out_features(self): """Get the list of output feature names that the last solar spatiotemporal generative model in this SolarMultiStepGan outputs.""" - return self.temporal_solar_models.output_features + return self.temporal_solar_models.hr_out_features @property def idf_wind(self): """Get an array of feature indices for the subset of features required for the spatial_wind_models. This excludes topography which is assumed to be provided as exogenous_data.""" - return np.array([self.training_features.index(fn) for fn in - self.spatial_wind_models.training_features + return np.array([self.lr_features.index(fn) for fn in + self.spatial_wind_models.lr_features if fn != 'topography']) @property @@ -534,8 +546,8 @@ def idf_wind_out(self): required for input to the temporal_solar_models. Typically this is the indices of U_200m + V_200m from the output features of spatial_wind_models""" - temporal_solar_features = self.temporal_solar_models.training_features - return np.array([self.spatial_wind_models.output_features.index(fn) + temporal_solar_features = self.temporal_solar_models.lr_features + return np.array([self.spatial_wind_models.hr_out_features.index(fn) for fn in temporal_solar_features[1:]]) @property @@ -543,8 +555,8 @@ def idf_solar(self): """Get an array of feature indices for the subset of features required for the spatial_solar_models. This excludes topography which is assumed to be provided as exogenous_data.""" - return np.array([self.training_features.index(fn) for fn in - self.spatial_solar_models.training_features + return np.array([self.lr_features.index(fn) for fn in + self.spatial_solar_models.lr_features if fn != 'topography']) def generate(self, low_res, norm_in=True, un_norm_out=True, @@ -557,7 +569,7 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, low_res : np.ndarray Low-resolution input data to the 1st step spatial GAN, which is a 4D array of shape: (temporal, spatial_1, spatial_2, n_features). - This should include all of the self.training_features which is a + This should include all of the self.lr_features which is a concatenation of both the solar and wind spatial model features. The topography feature might be removed from this input and present in the exogenous_data input. diff --git a/sup3r/models/surface.py b/sup3r/models/surface.py index f0b93b8a5..ff695846b 100644 --- a/sup3r/models/surface.py +++ b/sup3r/models/surface.py @@ -42,14 +42,14 @@ class SurfaceSpatialMetModel(LinearInterp): """Weight for the delta-topography feature for the relative humidity linear regression model.""" - def __init__(self, features, s_enhance, noise_adders=None, + def __init__(self, lr_features, s_enhance, noise_adders=None, temp_lapse=None, w_delta_temp=None, w_delta_topo=None, pres_div=None, pres_exp=None, interp_method='LANCZOS', fix_bias=True): """ Parameters ---------- - features : list + lr_features : list List of feature names that this model will operate on for both input and output. This must match the feature axis ordering in the array input to generate(). Typically this is a list containing: @@ -63,11 +63,11 @@ def __init__(self, features, s_enhance, noise_adders=None, Option to add gaussian noise to spatial model output. Noise will be normally distributed with mean of 0 and standard deviation = noise_adders. noise_adders can be a single value or a list - corresponding to the features list. None is no noise. The addition - of noise has been shown to help downstream temporal-only models - produce diurnal cycles in regions where there is minimal change in - topography. A noise_adders around 0.07C (temperature) and 0.1% - (relative humidity) have been shown to be effective. This is + corresponding to the lr_features list. None is no noise. The + addition of noise has been shown to help downstream temporal-only + models produce diurnal cycles in regions where there is minimal + change in topography. A noise_adders around 0.07C (temperature) and + 0.1% (relative humidity) have been shown to be effective. This is unnecessary if daily min/max temperatures are provided as low res training features. temp_lapse : None | float @@ -98,7 +98,7 @@ def __init__(self, features, s_enhance, noise_adders=None, low-resolution deviation from the input data """ - self._features = features + self._lr_features = lr_features self._s_enhance = s_enhance self._noise_adders = noise_adders self._temp_lapse = temp_lapse or self.TEMP_LAPSE @@ -111,7 +111,7 @@ def __init__(self, features, s_enhance, noise_adders=None, self._interp_method = getattr(Image.Resampling, interp_method) if isinstance(self._noise_adders, (int, float)): - self._noise_adders = [self._noise_adders] * len(self._features) + self._noise_adders = [self._noise_adders] * len(self._lr_features) def __len__(self): """Get number of model steps (match interface of MultiStepGan)""" @@ -162,21 +162,21 @@ def input_dims(self): @property def feature_inds_temp(self): """Get the feature index values for the temperature features.""" - inds = [i for i, name in enumerate(self._features) + inds = [i for i, name in enumerate(self._lr_features) if fnmatch(name, 'temperature_*')] return inds @property def feature_inds_pres(self): """Get the feature index values for the pressure features.""" - inds = [i for i, name in enumerate(self._features) + inds = [i for i, name in enumerate(self._lr_features) if fnmatch(name, 'pressure_*')] return inds @property def feature_inds_rh(self): """Get the feature index values for the relative humidity features.""" - inds = [i for i, name in enumerate(self._features) + inds = [i for i, name in enumerate(self._lr_features) if fnmatch(name, 'relativehumidity_*')] return inds @@ -195,14 +195,14 @@ def _get_temp_rh_ind(self, idf_rh): Index in the feature list for a temperature_*m feature with the same hub height as the idf_rh input. """ - name_rh = self._features[idf_rh] + name_rh = self._lr_features[idf_rh] hh_suffix = name_rh.split('_')[-1] idf_temp = None for i in self.feature_inds_temp: - same_hh = self._features[i].endswith(hh_suffix) + same_hh = self._lr_features[i].endswith(hh_suffix) not_minmax = not any(mm in name_rh for mm in ('_min_', '_max_')) - both_mins = '_min_' in name_rh and '_min_' in self._features[i] - both_maxs = '_max_' in name_rh and '_max_' in self._features[i] + both_mins = '_min_' in name_rh and '_min_' in self._lr_features[i] + both_maxs = '_max_' in name_rh and '_max_' in self._lr_features[i] if same_hh and (not_minmax or both_mins or both_maxs): idf_temp = i @@ -210,7 +210,8 @@ def _get_temp_rh_ind(self, idf_rh): if idf_temp is None: msg = ('Could not find temperature feature corresponding to ' - '"{}" in feature list: {}'.format(name_rh, self._features)) + '"{}" in feature list: {}' + .format(name_rh, self._lr_features)) logger.error(msg) raise KeyError(msg) @@ -522,7 +523,7 @@ def generate(self, low_res, norm_in=False, un_norm_out=False, hr_shape = (len(low_res), int(low_res.shape[1] * self._s_enhance), int(low_res.shape[2] * self._s_enhance), - len(self.output_features)) + len(self.hr_out_features)) logger.debug('SurfaceSpatialMetModel with s_enhance of {} ' 'downscaling low-res shape {} to high-res shape {}' .format(self._s_enhance, low_res.shape, hr_shape)) @@ -567,8 +568,8 @@ def meta(self): 'weight_for_delta_topo': self._w_delta_topo, 'pressure_divisor': self._pres_div, 'pressure_exponent': self._pres_exp, - 'training_features': self.training_features, - 'output_features': self.output_features, + 'lr_features': self.lr_features, + 'hr_out_features': self.hr_out_features, 'interp_method': str(self._interp_method), 'fix_bias': self._fix_bias, 'class': self.__class__.__name__, diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 722239ee2..2173d7e4c 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -778,7 +778,8 @@ def __init__(self, self.t_enhancements = [model.t_enhance for model in models] self.s_enhance = np.product(self.s_enhancements) self.t_enhance = np.product(self.t_enhancements) - self.output_features = model.output_features + self.output_features = model.hr_out_features + assert len(self.output_features) > 0, 'No output features!' self.fwp_slicer = ForwardPassSlicer(self.grid_shape, self.raw_tsteps, @@ -1084,8 +1085,10 @@ def __init__(self, strategy, chunk_index=0, node_index=0): raise KeyError(msg) self.model = model_class.load(**self.model_kwargs, verbose=False) - self.features = self.model.training_features - self.output_features = self.model.output_features + self.features = self.model.lr_features + self.output_features = self.model.hr_out_features + assert len(self.features) > 0, 'No input features!' + assert len(self.output_features) > 0, 'No output features!' self._file_paths = strategy.file_paths self.max_workers = strategy.max_workers @@ -1964,7 +1967,7 @@ def run_chunk(self): logger.info(f'Saving forward pass output to {self.out_file}.') self.output_handler_class._write_output( data=self.output_data, - features=self.model.output_features, + features=self.model.hr_out_features, lat_lon=self.hr_lat_lon, times=self.hr_times, out_file=self.out_file, diff --git a/sup3r/postprocessing/file_handling.py b/sup3r/postprocessing/file_handling.py index c1244366c..cebb87176 100644 --- a/sup3r/postprocessing/file_handling.py +++ b/sup3r/postprocessing/file_handling.py @@ -290,12 +290,17 @@ def enforce_limits(features, data): data : ndarray Array of feature data with physical limits enforced """ - maxes = [H5_ATTRS[Feature.get_basename(f)].get('max', np.inf) - for f in features] - mins = [H5_ATTRS[Feature.get_basename(f)].get('min', -np.inf) - for f in features] + maxs = [] + mins = [] + for fn in features: + max = H5_ATTRS[Feature.get_basename(fn)].get('max', np.inf) + min = H5_ATTRS[Feature.get_basename(fn)].get('min', -np.inf) + logger.debug(f'Enforcing range of ({max}, {min} for "{fn}")') + maxs.append(max) + mins.append(min) + data = np.maximum(data, mins) - data = np.minimum(data, maxes) + data = np.minimum(data, maxs) return data diff --git a/sup3r/preprocessing/batch_handling.py b/sup3r/preprocessing/batch_handling.py index 977fe4e05..9c10b85f5 100644 --- a/sup3r/preprocessing/batch_handling.py +++ b/sup3r/preprocessing/batch_handling.py @@ -2,9 +2,9 @@ Sup3r batch_handling module. @author: bbenton """ +import json import logging import os -import pickle from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime as dt @@ -72,26 +72,6 @@ def high_res(self): """Get the high-resolution data for the batch.""" return self._high_res - @staticmethod - def reduce_features(high_res, output_features_ind=None): - """Remove any feature channels that are only intended for the low-res - training input. - - Parameters - ---------- - high_res : np.ndarray - 4D | 5D array - (batch_size, spatial_1, spatial_2, features) - (batch_size, spatial_1, spatial_2, temporal, features) - output_features_ind : list | np.ndarray | None - List/array of feature channel indices that are used for generative - output, without any feature indices used only for training. - """ - if output_features_ind is None: - return high_res - else: - return high_res[..., output_features_ind] - # pylint: disable=W0613 @classmethod def get_coarse_batch(cls, @@ -99,9 +79,8 @@ def get_coarse_batch(cls, s_enhance, t_enhance=1, temporal_coarsening_method='subsample', - output_features_ind=None, - output_features=None, - training_features=None, + hr_features_ind=None, + features=None, smoothing=None, smoothing_ignore=None, ): @@ -123,12 +102,10 @@ def get_coarse_batch(cls, temporal_coarsening_method : str Method to use for temporal coarsening. Can be subsample, average, min, max, or total - output_features_ind : list | np.ndarray | None + hr_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. - output_features : list - List of Generative model output feature names - training_features : list | None + features : list | None Ordered list of training features input to the generative model smoothing : float | None Standard deviation to use for gaussian filtering of the coarse @@ -147,8 +124,11 @@ def get_coarse_batch(cls, """ low_res = spatial_coarsening(high_res, s_enhance) - if training_features is None: - training_features = [None] * low_res.shape[-1] + if features is None: + features = [None] * low_res.shape[-1] + + if hr_features_ind is None: + hr_features_ind = np.arange(high_res.shape[-1]) if smoothing_ignore is None: smoothing_ignore = [] @@ -157,9 +137,9 @@ def get_coarse_batch(cls, low_res = temporal_coarsening(low_res, t_enhance, temporal_coarsening_method) - low_res = smooth_data(low_res, training_features, smoothing_ignore, + low_res = smooth_data(low_res, features, smoothing_ignore, smoothing) - high_res = cls.reduce_features(high_res, output_features_ind) + high_res = high_res[..., hr_features_ind] batch = cls(low_res, high_res) return batch @@ -174,11 +154,10 @@ class ValidationData: def __init__(self, data_handlers, batch_size=8, - s_enhance=3, + s_enhance=1, t_enhance=1, temporal_coarsening_method='subsample', - output_features_ind=None, - output_features=None, + hr_features_ind=None, smoothing=None, smoothing_ignore=None): """ @@ -199,11 +178,9 @@ def __init__(self, Subsample will take every t_enhance-th time step, average will average over t_enhance time steps, total will sum over t_enhance time steps - output_features_ind : list | np.ndarray | None + hr_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. - output_features : list - List of Generative model output feature names smoothing : float | None Standard deviation to use for gaussian filtering of the coarse data. This can be tuned by matching the kinetic energy of a low @@ -228,8 +205,7 @@ def __init__(self, self._remaining_observations = len(self.val_indices) self.temporal_coarsening_method = temporal_coarsening_method self._i = 0 - self.output_features_ind = output_features_ind - self.output_features = output_features + self.hr_features_ind = hr_features_ind self.smoothing = smoothing self.smoothing_ignore = smoothing_ignore self.current_batch_indices = [] @@ -332,10 +308,9 @@ def batch_next(self, high_res): self.s_enhance, t_enhance=self.t_enhance, temporal_coarsening_method=self.temporal_coarsening_method, - output_features_ind=self.output_features_ind, + hr_features_ind=self.hr_features_ind, smoothing=self.smoothing, - smoothing_ignore=self.smoothing_ignore, - output_features=self.output_features) + smoothing_ignore=self.smoothing_ignore) def __next__(self): """Get validation data batch @@ -384,7 +359,7 @@ class BatchHandler: def __init__(self, data_handlers, batch_size=8, - s_enhance=3, + s_enhance=1, t_enhance=1, means=None, stds=None, @@ -410,15 +385,14 @@ def __init__(self, t_enhance : int Factor by which to coarsen temporal dimension of the high resolution data to generate low res data - means : np.ndarray - dimensions (features) - array of means for all features with same ordering as data - features. If not None and norm is True these will be used for - normalization - stds : np.ndarray - dimensions (features) - array of means for all features with same ordering as data - features. If not None and norm is True these will be used form + means : dict | none + Dictionary of means for all features with keys: feature names and + values: mean values. if None, this will be calculated. if norm is + true these will be used for data normalization + stds : dict | none + dictionary of standard deviation values for all features with keys: + feature names and values: standard deviations. if None, this will + be calculated. if norm is true these will be used for data normalization norm : bool Whether to normalize the data or not @@ -431,9 +405,11 @@ def __init__(self, average over t_enhance time steps, total will sum over t_enhance time steps stdevs_file : str | None - Path to stdevs data or where to save data after calling get_stats + Optional .json path to stdevs data or where to save data after + calling get_stats means_file : str | None - Path to means data or where to save data after calling get_stats + Optional .json path to means data or where to save data after + calling get_stats overwrite_stats : bool Whether to overwrite stats cache files. smoothing : float | None @@ -470,6 +446,9 @@ def __init__(self, self._norm_workers = worker_kwargs.get('norm_workers', norm_workers) self._load_workers = worker_kwargs.get('load_workers', load_workers) + data_handlers = (data_handlers + if isinstance(data_handlers, (list, tuple)) + else [data_handlers]) msg = 'All data handlers must have the same sample_shape' handler_shapes = np.array([d.sample_shape for d in data_handlers]) assert np.all(handler_shapes[0] == handler_shapes), msg @@ -495,7 +474,7 @@ def __init__(self, self.smoothing = smoothing self.smoothing_ignore = smoothing_ignore or [] self.smoothed_features = [ - f for f in self.training_features if f not in self.smoothing_ignore + f for f in self.features if f not in self.smoothing_ignore ] logger.info(f'Initializing BatchHandler with ' @@ -522,8 +501,7 @@ def __init__(self, s_enhance=s_enhance, t_enhance=t_enhance, temporal_coarsening_method=temporal_coarsening_method, - output_features_ind=self.output_features_ind, - output_features=self.output_features, + hr_features_ind=self.hr_features_ind, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore, ) @@ -576,38 +554,45 @@ def norm_workers(self): features""" proc_mem = 2 * self.feature_mem norm_workers = estimate_max_workers(self._norm_workers, proc_mem, - len(self.training_features)) + len(self.features)) return norm_workers @property - def training_features(self): + def features(self): """Get the ordered list of feature names held in this object's data handlers""" return self.data_handlers[0].features @property - def train_only_features(self): - """Get the ordered list of feature names used only for training which - will not be stored in batch.high_res""" - return self.data_handlers[0].train_only_features + def lr_features(self): + """Get a list of low-resolution features. All low-resolution features + are used for training.""" + return self.data_handlers[0].features @property - def output_features(self): - """Get the ordered list of feature names held in this object's - data handlers""" - return self.data_handlers[0].output_features + def hr_exo_features(self): + """Get a list of high-resolution features that are only used for + training e.g., mid-network high-res topo injection.""" + return self.data_handlers[0].hr_exo_features + + @property + def hr_out_features(self): + """Get a list of low-resolution features that are intended to be output + by the GAN.""" + return self.data_handlers[0].hr_out_features @property - def output_features_ind(self): - """Get the feature channel indices that should be used for the - generated output features""" - if self.training_features == self.output_features: - return None + def hr_features_ind(self): + """Get the high-resolution feature channel indices that should be + included for training. Any high-resolution features that are only + included in the data handler to be coarsened for the low-res input are + removed""" + hr_features = list(self.hr_out_features) + list(self.hr_exo_features) + if list(self.features) == hr_features: + return np.arange(len(self.features)) else: - out = [ - i for i, feature in enumerate(self.training_features) - if feature in self.output_features - ] + out = [i for i, feature in enumerate(self.features) + if feature in hr_features] return out @property @@ -626,20 +611,22 @@ def shape(self): n_lats = self.data_handlers[0].shape[0] return (n_lats, n_lons, time_steps, self.data_handlers[0].shape[-1]) - def parallel_normalization(self): - """Normalize data in all data handlers in parallel.""" + def _parallel_normalization(self): + """Normalize data in all data handlers in parallel or serial depending + on norm_workers.""" logger.info(f'Normalizing {len(self.data_handlers)} data handlers.') - max_workers = self.load_workers + max_workers = self.norm_workers if max_workers == 1: - for d in self.data_handlers: - d.normalize(self.means, self.stds) + for dh in self.data_handlers: + dh.normalize(self.means, self.stds) else: with ThreadPoolExecutor(max_workers=max_workers) as exe: futures = {} now = dt.now() - for i, d in enumerate(self.data_handlers): - future = exe.submit(d.normalize, self.means, self.stds) - futures[future] = i + for idh, dh in enumerate(self.data_handlers): + future = exe.submit(dh.normalize, self.means, self.stds, + max_workers=1) + futures[future] = idh logger.info(f'Started normalizing {len(self.data_handlers)} ' f'data handlers in {dt.now() - now}.') @@ -689,35 +676,27 @@ def load_handler_data(self): def _get_stats(self): """Get standard deviations and means for training features in parallel.""" - logger.info(f'Calculating stats for {len(self.training_features)} ' + logger.info(f'Calculating stats for {len(self.features)} ' 'features.') - max_workers = self.norm_workers - if max_workers == 1: - for f in self.training_features: - self.get_stats_for_feature(f) - else: - with ThreadPoolExecutor(max_workers=max_workers) as exe: - futures = {} - now = dt.now() - for i, f in enumerate(self.training_features): - future = exe.submit(self.get_stats_for_feature, f) - futures[future] = i - - logger.info('Started calculating stats for ' - f'{len(self.training_features)} features in ' - f'{dt.now() - now}.') - - for i, future in enumerate(as_completed(futures)): - try: - future.result() - except Exception as e: - msg = ('Error calculating stats for ' - f'{self.training_features[futures[future]]}') - logger.exception(msg) - raise RuntimeError(msg) from e - logger.debug(f'{i+1} out of ' - f'{len(self.training_features)} stats ' - 'calculated.') + for feature in self.features: + logger.debug(f'Calculating mean/stdev for "{feature}"') + self.means[feature] = 0 + self.stds[feature] = 0 + max_workers = self.stats_workers + + if max_workers is None or max_workers >= 1: + with ThreadPoolExecutor(max_workers=max_workers) as exe: + futures = {} + for idh, dh in enumerate(self.data_handlers): + future = exe.submit(dh._get_stats) + futures[future] = idh + + for i, _ in enumerate(as_completed(futures)): + logger.debug(f'{i+1} out of {len(self.data_handlers)} ' + 'means calculated.') + + self.means[feature] = self._get_feature_means(feature) + self.stds[feature] = self._get_feature_stdev(feature) def __len__(self): """Use user input of n_batches to specify length @@ -735,10 +714,15 @@ def check_cached_stats(self): Returns ------- - means : ndarray - Array of means for each feature - stds : ndarray - Array of stdevs for each feature + means : dict | none + Dictionary of means for all features with keys: feature names and + values: mean values. if None, this will be calculated. if norm is + true these will be used for data normalization + stds : dict | none + dictionary of standard deviation values for all features with keys: + feature names and values: standard deviations. if None, this will + be calculated. if norm is true these will be used for data + normalization """ stdevs_check = (self.stdevs_file is not None and not self.overwrite_stats) @@ -747,44 +731,37 @@ def check_cached_stats(self): means_check = means_check and os.path.exists(self.means_file) if stdevs_check and means_check: logger.info(f'Loading stdevs from {self.stdevs_file}') - with open(self.stdevs_file, 'rb') as fh: - self.stds = pickle.load(fh) + with open(self.stdevs_file, 'r') as fh: + self.stds = json.load(fh) logger.info(f'Loading means from {self.means_file}') - with open(self.means_file, 'rb') as fh: - self.means = pickle.load(fh) + with open(self.means_file, 'r') as fh: + self.means = json.load(fh) msg = ('The training features and cached statistics are ' 'incompatible. Number of training features is ' - f'{len(self.training_features)} and number of stats is' + f'{len(self.features)} and number of stats is' f' {len(self.stds)}') - check = len(self.means) == len(self.training_features) - check = check and (len(self.stds) == len(self.training_features)) + check = len(self.means) == len(self.features) + check = check and (len(self.stds) == len(self.features)) assert check, msg return self.means, self.stds def cache_stats(self): """Saved stdevs and means to cache files if files are not None""" - if self.stdevs_file is not None: - logger.info(f'Saving stdevs to {self.stdevs_file}') - basedir = os.path.dirname(self.stdevs_file) - if not os.path.exists(basedir): - os.makedirs(basedir) - with open(self.stdevs_file, 'wb') as fh: - pickle.dump(self.stds, fh) - if self.means_file is not None: - logger.info(f'Saving means to {self.means_file}') - basedir = os.path.dirname(self.means_file) - if not os.path.exists(basedir): - os.makedirs(basedir) - with open(self.means_file, 'wb') as fh: - pickle.dump(self.means, fh) + iter = ((self.means_file, self.means), (self.stdevs_file, self.stds)) + for fp, data in iter: + if fp is not None: + logger.info(f'Saving stats to {fp}') + os.makedirs(os.path.dirname(fp), exist_ok=True) + with open(fp, 'w') as fh: + json.dump(data, fh) def get_stats(self): """Get standard deviations and means for all data features""" - self.means = np.zeros((self.shape[-1]), dtype=np.float32) - self.stds = np.zeros((self.shape[-1]), dtype=np.float32) + self.means = {} + self.stds = {} now = dt.now() logger.info('Calculating stdevs/means.') @@ -792,98 +769,23 @@ def get_stats(self): logger.info(f'Finished calculating stats in {dt.now() - now}.') self.cache_stats() - def get_handler_mean(self, feature_idx, handler_idx): - """Get feature mean for a given handler - - Parameters - ---------- - feature_idx : int - Index of feature to get mean for - handler_idx : int - Index of data handler to get mean for - - Returns - ------- - float - Feature mean - """ - return np.nanmean(self.data_handlers[handler_idx].data[..., - feature_idx]) - - def get_handler_variance(self, feature_idx, handler_idx, mean): - """Get feature variance for a given handler - - Parameters - ---------- - feature_idx : int - Index of feature to get variance for - handler_idx : int - Index of data handler to get variance for - mean : float - Mean for the given handler and feature - - Returns - ------- - float - Feature variance - """ - istd = self.data_handlers[handler_idx].data[..., feature_idx] - mean - return np.nanmean(istd**2) - - def get_stats_for_feature(self, feature): - """Get standard deviation and mean for requested feature - - Parameters - ---------- - feature : str - Feature to get stats for - max_workers : int | None - Max number of workers to use for parallel stats calculations. If - None the max number of available workers will be used. - """ - idx = self.training_features.index(feature) - logger.debug(f'Calculating stdev/mean for {feature}') - max_workers = self.stats_workers - self.means[idx] = self.get_means_for_feature(feature, max_workers) - self.stds[idx] = self.get_stdevs_for_feature(feature, max_workers) - - def get_means_for_feature(self, feature, max_workers=None): + def _get_feature_means(self, feature): """Get mean for requested feature Parameters ---------- feature : str Feature to get mean for - max_workers : int | None - Max number of workers to use for parallel stats calculations. If - None the max number of available workers will be used. """ - idx = self.training_features.index(feature) - logger.debug(f'Calculating mean for {feature}') - if max_workers == 1: - for didx, _ in enumerate(self.data_handlers): - self.means[idx] += (self.handler_weights[didx] - * self.get_handler_mean(idx, didx)) - else: - with ThreadPoolExecutor(max_workers=max_workers) as exe: - futures = {} - now = dt.now() - for didx, _ in enumerate(self.data_handlers): - future = exe.submit(self.get_handler_mean, idx, didx) - futures[future] = didx - logger.info('Started calculating means for ' - f'{len(self.data_handlers)} data_handlers in ' - f'{dt.now() - now}.') + logger.debug(f'Calculating mean for {feature}') + for idh, dh in enumerate(self.data_handlers): + self.means[feature] += (self.handler_weights[idh] + * dh.means[feature]) - for i, future in enumerate(as_completed(futures)): - self.means[idx] += (self.handler_weights[futures[future]] - * future.result()) - logger.debug(f'{i+1} out of {len(self.data_handlers)} ' - 'means calculated.') - return self.means[idx] + return self.means[feature] - def get_stdevs_for_feature(self, feature, max_workers=None): + def _get_feature_stdev(self, feature): """Get stdev for requested feature NOTE: We compute the variance across all handlers as a pooled variance @@ -894,36 +796,16 @@ def get_stdevs_for_feature(self, feature, max_workers=None): ---------- feature : str Feature to get stdev for - max_workers : int | None - Max number of workers to use for parallel stats calculations. If - None the max number of available workers will be used. """ - idx = self.training_features.index(feature) + logger.debug(f'Calculating stdev for {feature}') - if max_workers == 1: - for didx, _ in enumerate(self.data_handlers): - hstd = self.get_handler_variance(idx, didx, self.means[idx]) - self.stds[idx] += (hstd * self.handler_weights[didx]) - else: - with ThreadPoolExecutor(max_workers=max_workers) as exe: - futures = {} - now = dt.now() - for didx, _ in enumerate(self.data_handlers): - future = exe.submit(self.get_handler_variance, idx, didx, - self.means[idx]) - futures[future] = didx + for idh, dh in enumerate(self.data_handlers): + variance = dh.stds[feature]**2 + self.stds[feature] += (variance * self.handler_weights[idh]) - logger.info('Started calculating stdevs for ' - f'{len(self.data_handlers)} data_handlers in ' - f'{dt.now() - now}.') + self.stds[feature] = np.sqrt(self.stds[feature]) - for i, future in enumerate(as_completed(futures)): - self.stds[idx] += (self.handler_weights[futures[future]] - * future.result()) - logger.debug(f'{i+1} out of {len(self.data_handlers)} ' - 'stdevs calculated.') - self.stds[idx] = np.sqrt(self.stds[idx]) - return self.stds[idx] + return self.stds[feature] def normalize(self, means=None, stds=None): """Compute means and stds for each feature across all datasets and @@ -932,30 +814,38 @@ def normalize(self, means=None, stds=None): Parameters ---------- - means : ndarray | None - Array of means for each feature. If None it will be calculated. - stds : ndarray | None - Array of stdevs for each feature. If None it will be calculated. + means : dict | none + Dictionary of means for all features with keys: feature names and + values: mean values. if None, this will be calculated. if norm is + true these will be used for data normalization + stds : dict | none + dictionary of standard deviation values for all features with keys: + feature names and values: standard deviations. if None, this will + be calculated. if norm is true these will be used for data + normalization """ if means is None or stds is None: self.get_stats() elif means is not None and stds is not None: - if not np.array_equal(means, self.means) or not np.array_equal( - stds, self.stds): - self.unnormalize() - self.means = means - self.stds = stds + means0, means1 = list(self.means.values()), list(means.values()) + stds0, stds1 = list(self.stds.values()), list(stds.values()) + if (not np.array_equal(means0, means1) + or not np.array_equal(stds0, stds1)): + msg = (f'Normalization requested with new means/stdevs ' + f'{means1}/{stds1} that ' + f'dont match previous values: {means0}/{stds0}') + logger.info(msg) + raise ValueError(msg) + else: + self.means = means + self.stds = stds + now = dt.now() logger.info('Normalizing data in each data handler.') - self.parallel_normalization() + self._parallel_normalization() logger.info('Finished normalizing data in all data handlers in ' f'{dt.now() - now}.') - def unnormalize(self): - """Remove normalization from stored means and stds""" - for d in self.data_handlers: - d.unnormalize(self.means, self.stds) - def __iter__(self): self._i = 0 return self @@ -986,9 +876,8 @@ def __next__(self): self.s_enhance, t_enhance=self.t_enhance, temporal_coarsening_method=self.temporal_coarsening_method, - output_features_ind=self.output_features_ind, - output_features=self.output_features, - training_features=self.training_features, + hr_features_ind=self.hr_features_ind, + features=self.features, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore) @@ -1046,8 +935,7 @@ def __next__(self): obs_hourly, obs_daily_avg = handler.get_next() self.current_batch_indices.append(handler.current_obs_index) - obs_hourly = self.BATCH_CLASS.reduce_features( - obs_hourly, self.output_features_ind) + obs_hourly = obs_hourly[..., self.hr_features_ind] if low_res is None: lr_shape = (self.batch_size, *obs_daily_avg.shape) @@ -1061,16 +949,16 @@ def __next__(self): high_res = self.reduce_high_res_sub_daily(high_res) low_res = spatial_coarsening(low_res, self.s_enhance) - if (self.output_features is not None - and 'clearsky_ratio' in self.output_features): - i_cs = self.output_features.index('clearsky_ratio') + if (self.hr_out_features is not None + and 'clearsky_ratio' in self.hr_out_features): + i_cs = self.hr_out_features.index('clearsky_ratio') if np.isnan(high_res[..., i_cs]).any(): high_res[..., i_cs] = nn_fill_array(high_res[..., i_cs]) if self.smoothing is not None: feat_iter = [ j for j in range(low_res.shape[-1]) - if self.training_features[j] not in self.smoothing_ignore + if self.features[j] not in self.smoothing_ignore ] for i in range(low_res.shape[0]): for j in feat_iter: @@ -1166,19 +1054,18 @@ def __next__(self): low_res = low_res[:, :, :, 0, :] high_res = high_res[:, :, :, 0, :] - high_res = self.BATCH_CLASS.reduce_features(high_res, - self.output_features_ind) + high_res = high_res[..., self.hr_features_ind] - if (self.output_features is not None - and 'clearsky_ratio' in self.output_features): - i_cs = self.output_features.index('clearsky_ratio') + if (self.hr_out_features is not None + and 'clearsky_ratio' in self.hr_out_features): + i_cs = self.hr_out_features.index('clearsky_ratio') if np.isnan(high_res[..., i_cs]).any(): high_res[..., i_cs] = nn_fill_array(high_res[..., i_cs]) if self.smoothing is not None: feat_iter = [ j for j in range(low_res.shape[-1]) - if self.training_features[j] not in self.smoothing_ignore + if self.features[j] not in self.smoothing_ignore ] for i in range(low_res.shape[0]): for j in feat_iter: @@ -1207,8 +1094,8 @@ def __next__(self): batch = self.BATCH_CLASS.get_coarse_batch( high_res, self.s_enhance, - output_features_ind=self.output_features_ind, - training_features=self.training_features, + hr_features_ind=self.hr_features_ind, + features=self.features, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore) @@ -1295,10 +1182,9 @@ def __next__(self): self.s_enhance, t_enhance=self.t_enhance, temporal_coarsening_method=self.temporal_coarsening_method, - output_features_ind=self.output_features_ind, + hr_features_ind=self.hr_features_ind, smoothing=self.smoothing, - smoothing_ignore=self.smoothing_ignore, - output_features=self.output_features) + smoothing_ignore=self.smoothing_ignore) self._i += 1 return batch else: @@ -1330,10 +1216,9 @@ def __next__(self): batch = self.BATCH_CLASS.get_coarse_batch( high_res, self.s_enhance, - output_features_ind=self.output_features_ind, + hr_features_ind=self.hr_features_ind, smoothing=self.smoothing, - smoothing_ignore=self.smoothing_ignore, - output_features=self.output_features) + smoothing_ignore=self.smoothing_ignore) self._i += 1 return batch else: @@ -1405,9 +1290,8 @@ def __next__(self): self.s_enhance, t_enhance=self.t_enhance, temporal_coarsening_method=self.temporal_coarsening_method, - output_features_ind=self.output_features_ind, - output_features=self.output_features, - training_features=self.training_features, + hr_features_ind=self.hr_features_ind, + features=self.features, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore) @@ -1492,8 +1376,8 @@ def __next__(self): batch = self.BATCH_CLASS.get_coarse_batch( high_res, self.s_enhance, - output_features_ind=self.output_features_ind, - training_features=self.training_features, + hr_features_ind=self.hr_features_ind, + features=self.features, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore, ) diff --git a/sup3r/preprocessing/conditional_moment_batch_handling.py b/sup3r/preprocessing/conditional_moment_batch_handling.py index 4d27a92da..96a700cbc 100644 --- a/sup3r/preprocessing/conditional_moment_batch_handling.py +++ b/sup3r/preprocessing/conditional_moment_batch_handling.py @@ -79,7 +79,7 @@ def make_output( s_enhance=None, t_enhance=None, model_mom1=None, - output_features_ind=None, + hr_features_ind=None, t_enhance_mode='constant', ): """Make custom batch output @@ -100,7 +100,7 @@ def make_output( Temporal enhancement factor model_mom1 : Sup3rCondMom | None Model used to modify the make the batch output - output_features_ind : list | np.ndarray | None + hr_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. t_enhance_mode : str @@ -165,7 +165,7 @@ def make_mask( None by default model_mom1 : Sup3rCondMom | None Model used to modify the make the batch output - output_features_ind : list | np.ndarray | None + hr_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. @@ -203,9 +203,8 @@ def get_coarse_batch( t_enhance=1, temporal_coarsening_method='subsample', temporal_enhancing_method='constant', - output_features_ind=None, - output_features=None, - training_features=None, + hr_features_ind=None, + features=None, smoothing=None, smoothing_ignore=None, model_mom1=None, @@ -240,13 +239,12 @@ def get_coarse_batch( between landmarks. linear will linearly interpolate between landmarks to generate the low-res data to remove from the high-res. - output_features_ind : list | np.ndarray | None + hr_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. - output_features : list - List of Generative model output feature names - training_features : list | None - Ordered list of training features input to the generative model + features : list | None + Ordered list of low-resolution training features input to the + generative model smoothing : float | None Standard deviation to use for gaussian filtering of the coarse data. This can be tuned by matching the kinetic energy of a low @@ -277,8 +275,11 @@ def get_coarse_batch( """ low_res = spatial_coarsening(high_res, s_enhance) - if training_features is None: - training_features = [None] * low_res.shape[-1] + if features is None: + features = [None] * low_res.shape[-1] + + if hr_features_ind is None: + hr_features_ind = np.arange(high_res.shape[-1]) if smoothing_ignore is None: smoothing_ignore = [] @@ -289,16 +290,16 @@ def get_coarse_batch( ) low_res = smooth_data( - low_res, training_features, smoothing_ignore, smoothing + low_res, features, smoothing_ignore, smoothing ) - high_res = cls.reduce_features(high_res, output_features_ind) + high_res = high_res[..., hr_features_ind] output = cls.make_output( low_res, high_res, s_enhance, t_enhance, model_mom1, - output_features_ind, + hr_features_ind, temporal_enhancing_method, ) mask = cls.make_mask( @@ -320,7 +321,7 @@ def make_output( s_enhance=None, t_enhance=None, model_mom1=None, - output_features_ind=None, + hr_features_ind=None, t_enhance_mode='constant', ): """Make custom batch output @@ -341,7 +342,7 @@ def make_output( Temporal enhancement factor model_mom1 : Sup3rCondMom | None Model used to modify the make the batch output - output_features_ind : list | np.ndarray | None + hr_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. t_enhance_mode : str @@ -362,7 +363,7 @@ def make_output( enhanced_lr = temporal_simple_enhancing( enhanced_lr, t_enhance=t_enhance, mode=t_enhance_mode ) - enhanced_lr = Batch.reduce_features(enhanced_lr, output_features_ind) + enhanced_lr = enhanced_lr[..., hr_features_ind] return high_res - enhanced_lr @@ -378,7 +379,7 @@ def make_output( s_enhance=None, t_enhance=None, model_mom1=None, - output_features_ind=None, + hr_features_ind=None, t_enhance_mode='constant', ): """Make custom batch output @@ -399,7 +400,7 @@ def make_output( Temporal enhancement factor model_mom1 : Sup3rCondMom | None Model used to modify the make the batch output - output_features_ind : list | np.ndarray | None + hr_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. t_enhance_mode : str @@ -430,7 +431,7 @@ def make_output( s_enhance=None, t_enhance=None, model_mom1=None, - output_features_ind=None, + hr_features_ind=None, t_enhance_mode='constant', ): """Make custom batch output @@ -451,7 +452,7 @@ def make_output( Temporal enhancement factor model_mom1 : Sup3rCondMom | None Model used to modify the make the batch output - output_features_ind : list | np.ndarray | None + hr_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. t_enhance_mode : str @@ -473,7 +474,7 @@ def make_output( s_enhance, t_enhance, model_mom1, - output_features_ind, + hr_features_ind, t_enhance_mode, ) ** 2 @@ -491,7 +492,7 @@ def make_output( s_enhance=None, t_enhance=None, model_mom1=None, - output_features_ind=None, + hr_features_ind=None, t_enhance_mode='constant', ): """Make custom batch output @@ -512,7 +513,7 @@ def make_output( Temporal enhancement factor model_mom1 : Sup3rCondMom | None Model used to modify the make the batch output - output_features_ind : list | np.ndarray | None + hr_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. t_enhance_mode : str @@ -534,7 +535,7 @@ def make_output( enhanced_lr = temporal_simple_enhancing( enhanced_lr, t_enhance=t_enhance, mode=t_enhance_mode ) - enhanced_lr = Batch.reduce_features(enhanced_lr, output_features_ind) + enhanced_lr = enhanced_lr[..., hr_features_ind] return (high_res - enhanced_lr - out) ** 2 @@ -549,7 +550,7 @@ def make_output( s_enhance=None, t_enhance=None, model_mom1=None, - output_features_ind=None, + hr_features_ind=None, t_enhance_mode='constant', ): """Make custom batch output @@ -570,7 +571,7 @@ def make_output( Temporal enhancement factor model_mom1 : Sup3rCondMom | None Model used to modify the make the batch output - output_features_ind : list | np.ndarray | None + hr_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. t_enhance_mode : str @@ -594,7 +595,7 @@ def make_output( s_enhance, t_enhance, model_mom1, - output_features_ind, + hr_features_ind, t_enhance_mode, ) ** 2 @@ -615,8 +616,7 @@ def __init__( t_enhance=1, temporal_coarsening_method='subsample', temporal_enhancing_method='constant', - output_features_ind=None, - output_features=None, + hr_features_ind=None, smoothing=None, smoothing_ignore=None, model_mom1=None, @@ -651,11 +651,9 @@ def __init__( between landmarks. linear will linearly interpolate between landmarks to generate the low-res data to remove from the high-res. - output_features_ind : list | np.ndarray | None + hr_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. - output_features : list - List of Generative model output feature names smoothing : float | None Standard deviation to use for gaussian filtering of the coarse data. This can be tuned by matching the kinetic energy of a low @@ -698,8 +696,7 @@ def __init__( self.temporal_coarsening_method = temporal_coarsening_method self.temporal_enhancing_method = temporal_enhancing_method self._i = 0 - self.output_features_ind = output_features_ind - self.output_features = output_features + self.hr_features_ind = hr_features_ind self.smoothing = smoothing self.smoothing_ignore = smoothing_ignore self.model_mom1 = model_mom1 @@ -724,10 +721,9 @@ def batch_next(self, high_res): t_enhance=self.t_enhance, temporal_coarsening_method=self.temporal_coarsening_method, temporal_enhancing_method=self.temporal_enhancing_method, - output_features_ind=self.output_features_ind, + hr_features_ind=self.hr_features_ind, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore, - output_features=self.output_features, model_mom1=self.model_mom1, s_padding=self.s_padding, t_padding=self.t_padding, @@ -887,7 +883,7 @@ def __init__( self.smoothing = smoothing self.smoothing_ignore = smoothing_ignore or [] self.smoothed_features = [ - f for f in self.training_features if f not in self.smoothing_ignore + f for f in self.lr_features if f not in self.smoothing_ignore ] self._stats_workers = stats_workers self._norm_workers = norm_workers @@ -921,8 +917,7 @@ def __init__( t_enhance=t_enhance, temporal_coarsening_method=temporal_coarsening_method, temporal_enhancing_method=temporal_enhancing_method, - output_features_ind=self.output_features_ind, - output_features=self.output_features, + hr_features_ind=self.hr_features_ind, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore, model_mom1=self.model_mom1, @@ -969,9 +964,8 @@ def __next__(self): t_enhance=self.t_enhance, temporal_coarsening_method=self.temporal_coarsening_method, temporal_enhancing_method=self.temporal_enhancing_method, - output_features_ind=self.output_features_ind, - output_features=self.output_features, - training_features=self.training_features, + hr_features_ind=self.hr_features_ind, + features=self.features, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore, model_mom1=self.model_mom1, @@ -1008,8 +1002,8 @@ def __next__(self): batch = self.BATCH_CLASS.get_coarse_batch( high_res, self.s_enhance, - output_features_ind=self.output_features_ind, - training_features=self.training_features, + hr_features_ind=self.hr_features_ind, + features=self.features, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore, model_mom1=self.model_mom1, diff --git a/sup3r/preprocessing/data_handling/base.py b/sup3r/preprocessing/data_handling/base.py index 1ffb48b2d..414be799d 100644 --- a/sup3r/preprocessing/data_handling/base.py +++ b/sup3r/preprocessing/data_handling/base.py @@ -1,6 +1,7 @@ """Base data handling classes. @author: bbenton """ +import copy import logging import os import pickle @@ -67,16 +68,6 @@ class DataHandler(FeatureHandler, InputMixIn, TrainingPrepMixIn): (spatial_1, spatial_2, temporal, features) """ - # list of features / feature name patterns that are input to the generative - # model but are not part of the synthetic output and are not sent to the - # discriminator. These are case-insensitive and follow the Unix shell-style - # wildcard format. - TRAIN_ONLY_FEATURES = ('BVF*', - 'inversemoninobukhovlength_*', - 'RMOL', - 'topography', - ) - def __init__(self, file_paths, features, @@ -96,7 +87,8 @@ def __init__(self, overwrite_cache=False, overwrite_ti_cache=False, load_cached=False, - train_only_features=None, + lr_only_features=tuple(), + hr_exo_features=tuple(), handle_features=None, single_ts_files=None, mask_nan=False, @@ -142,12 +134,12 @@ def __init__(self, Size of spatial and temporal domain used in a single high-res observation for batching raster_file : str | None - File for raster_index array for the corresponding target and shape. - If specified the raster_index will be loaded from the file if it - exists or written to the file if it does not yet exist. If None and - raster_index is not provided raster_index will be calculated - directly. Either need target+shape, raster_file, or raster_index - input. + .txt file for raster_index array for the corresponding target and + shape. If specified the raster_index will be loaded from the file + if it exists or written to the file if it does not yet exist. If + None and raster_index is not provided raster_index will be + calculated directly. Either need target+shape, raster_file, or + raster_index input. raster_index : list List of tuples or slices. Used as an alternative to computing the raster index from target+shape or loading the raster index from @@ -174,10 +166,14 @@ def __init__(self, Whether to overwrite saved time index cache files. load_cached : bool Whether to load data from cache files - train_only_features : list | tuple | None + lr_only_features : list | tuple List of feature names or patt*erns that should only be included in - the training set and not the output. If None (default), this will - default to the class TRAIN_ONLY_FEATURES attribute. + the low-res training set and not the high-res observations. + hr_exo_features : list | tuple + List of feature names or patt*erns that should be included in the + high-resolution observation but not expected to be output from the + generative model. An example is high-res topography that is to be + injected mid-network. handle_features : list | None Optional list of features which are available in the provided data. Providing this eliminates the need for an initial search of @@ -230,8 +226,9 @@ def __init__(self, temporal_slice=temporal_slice) self.file_paths = file_paths - self.features = (features if isinstance(features, - (list, tuple)) else [features]) + self.features = (features if isinstance(features, (list, tuple)) + else [features]) + self.features = copy.deepcopy(self.features) self.val_time_index = None self.max_delta = max_delta self.val_split = val_split @@ -248,7 +245,8 @@ def __init__(self, self.res_kwargs = res_kwargs or {} self._single_ts_files = single_ts_files self._cache_pattern = cache_pattern - self._train_only_features = train_only_features + self._lr_only_features = lr_only_features + self._hr_exo_features = hr_exo_features self._time_chunk_size = time_chunk_size self._handle_features = handle_features self._cache_files = None @@ -257,6 +255,9 @@ def __init__(self, self._raw_features = None self._raw_data = {} self._time_chunks = None + self._means = None + self._stds = None + self._is_normalized = False self.worker_kwargs = worker_kwargs or {} self.max_workers = self.worker_kwargs.get('max_workers', None) self._ti_workers = self.worker_kwargs.get('ti_workers', None) @@ -404,13 +405,6 @@ def attrs(self): desc = handle.attrs return desc - @property - def train_only_features(self): - """Features to use for training only and not output""" - if self._train_only_features is None: - self._train_only_features = self.TRAIN_ONLY_FEATURES - return self._train_only_features - @property def extract_workers(self): """Get upper bound for extract workers based on memory limits. Used to @@ -603,19 +597,86 @@ def raw_features(self): if self._raw_features is None: self._raw_features = self.get_raw_feature_list( self.noncached_features, self.handle_features) + return self._raw_features @property - def output_features(self): - """Get a list of features that should be output by the generative model - corresponding to the features in the high res batch array.""" + def lr_only_features(self): + """List of feature names or patt*erns that should only be included in + the low-res training set and not the high-res observations.""" + if isinstance(self._lr_only_features, str): + self._lr_only_features = [self._lr_only_features] + + elif isinstance(self._lr_only_features, tuple): + self._lr_only_features = list(self._lr_only_features) + + elif self._lr_only_features is None: + self._lr_only_features = [] + + return self._lr_only_features + + @property + def lr_features(self): + """Get a list of low-resolution features. It is assumed that all + features are used in the low-resolution observations. If you want to + use high-res-only features, use the DualDataHandler class.""" + return self.features + + @property + def hr_exo_features(self): + """Get a list of exogenous high-resolution features that are only used + for training e.g., mid-network high-res topo injection. These must come + at the end of the high-res feature set. These can also be input to the + model as low-res features.""" + + if isinstance(self._hr_exo_features, str): + self._hr_exo_features = [self._hr_exo_features] + + elif isinstance(self._hr_exo_features, tuple): + self._hr_exo_features = list(self._hr_exo_features) + + elif self._hr_exo_features is None: + self._hr_exo_features = [] + + if any('*' in fn for fn in self._hr_exo_features): + hr_exo_features = [] + for feature in self.features: + match = any(fnmatch(feature.lower(), pattern.lower()) + for pattern in self._hr_exo_features) + if match: + hr_exo_features.append(feature) + self._hr_exo_features = hr_exo_features + + if len(self._hr_exo_features) > 0: + msg = (f'High-res train-only features "{self._hr_exo_features}" ' + f'do not come at the end of the full high-res feature set: ' + f'{self.features}') + last_feat = self.features[-len(self._hr_exo_features):] + assert list(self._hr_exo_features) == list(last_feat), msg + + return self._hr_exo_features + + @property + def hr_out_features(self): + """Get a list of high-resolution features that are intended to be + output by the GAN. Does not include high-resolution exogenous + features""" + out = [] for feature in self.features: - ignore = any( - fnmatch(feature.lower(), pattern.lower()) - for pattern in self.train_only_features) + lr_only = any(fnmatch(feature.lower(), pattern.lower()) + for pattern in self.lr_only_features) + ignore = lr_only or feature in self.hr_exo_features if not ignore: out.append(feature) + + if len(out) == 0: + msg = (f'It appears that all handler features "{self.features}" ' + 'were specified as `hr_exo_features` or `lr_only_features` ' + 'and therefore there are no output features!') + logger.error(msg) + raise RuntimeError(msg) + return out @property @@ -860,28 +921,71 @@ def get_cache_file_names(self, target, features) - def unnormalize(self, means, stds): - """Remove normalization from stored means and stds""" - self._unnormalize(self.data, self.val_data, means, stds) + @property + def means(self): + """Get the mean values for each feature. - def normalize(self, means, stds): - """Normalize all data features + Returns + ------- + dict + """ + self._get_stats() + return self._means + + @property + def stds(self): + """Get the standard deviation values for each feature. + + Returns + ------- + dict + """ + self._get_stats() + return self._stds + + def _get_stats(self): + if self._means is None or self._stds is None: + msg = (f'DataHandler has {len(self.features)} features ' + f'and mismatched shape of {self.shape}') + assert len(self.features) == self.shape[-1], msg + self._stds = {} + self._means = {} + for idf, fname in enumerate(self.features): + self._means[fname] = np.nanmean(self.data[..., idf]) + self._stds[fname] = np.nanstd(self.data[..., idf]) + + def normalize(self, means=None, stds=None, max_workers=None): + """Normalize all data features. Parameters ---------- - means : np.ndarray - dimensions (features) - array of means for all features with same ordering as data features - stds : np.ndarray - dimensions (features) - array of means for all features with same ordering as data features + means : dict | none + Dictionary of means for all features with keys: feature names and + values: mean values. If this is None, the self.means attribute will + be used. If this is not None, this DataHandler object means + attribute will be updated. + stds : dict | none + dictionary of standard deviation values for all features with keys: + feature names and values: standard deviations. If this is None, the + self.stds attribute will be used. If this is not None, this + DataHandler object stds attribute will be updated. + max_workers : None | int + Max workers to perform normalization. if None, self.norm_workers + will be used """ - max_workers = self.norm_workers - self._normalize(self.data, - self.val_data, - means, - stds, - max_workers=max_workers) + if means is not None: + self._means = means + if stds is not None: + self._stds = stds + + max_workers = max_workers or self.norm_workers + if self._is_normalized: + logger.info('Skipping DataHandler, already normalized') + else: + self._normalize(self.data, + self.val_data, + max_workers=max_workers) + self._is_normalized = True def get_next(self): """Get data for observation using random observation index. Loops diff --git a/sup3r/preprocessing/data_handling/dual_data_handling.py b/sup3r/preprocessing/data_handling/dual_data_handling.py index db245229a..b040e78cb 100644 --- a/sup3r/preprocessing/data_handling/dual_data_handling.py +++ b/sup3r/preprocessing/data_handling/dual_data_handling.py @@ -1,4 +1,5 @@ """Dual data handler class for using separate low_res and high_res datasets""" +import copy import logging import pickle from warnings import warn @@ -37,7 +38,7 @@ def __init__(self, regrid_workers=1, load_cached=True, shuffle_time=False, - s_enhance=15, + s_enhance=1, t_enhance=1, val_split=0.0): """Initialize data handler using hr and lr data handlers for h5 data @@ -90,7 +91,9 @@ def __init__(self, self.hr_time_index = None self.lr_val_time_index = None self.hr_val_time_index = None - self.lr_data = np.zeros(self.shape, dtype=np.float32) + + lr_data_shape = (*self.lr_required_shape, len(self.lr_dh.features)) + self.lr_data = np.zeros(lr_data_shape, dtype=np.float32) if self.try_load and self.load_cached: self.load_cached_data() @@ -158,44 +161,104 @@ def _val_split_check(self): logger.warning(msg) warn(msg) - def normalize(self, means, stdevs): + def _get_stats(self): + """Get mean/stdev stats for HR and LR data handlers""" + self.lr_dh._get_stats() + self.hr_dh._get_stats() + + @property + def means(self): + """Get the mean values for each feature. Mean values from the low-res + data handler are prioritized because these are typically the "input" + features + + Returns + ------- + dict + """ + out = copy.deepcopy(self.hr_dh.means) + out.update(self.lr_dh.means) + return out + + @property + def stds(self): + """Get the standard deviation values for each feature. Mean values from + the low-res data handler are prioritized because these are typically + the "input" features + + Returns + ------- + dict + """ + out = copy.deepcopy(self.hr_dh.stds) + out.update(self.lr_dh.stds) + return out + + def normalize(self, means=None, stds=None, max_workers=None): """Normalize low_res and high_res data Parameters ---------- - means : np.ndarray - dimensions (features) - array of means for all features with same ordering as data features - stdevs : np.ndarray - dimensions (features) - array of means for all features with same ordering as data features + means : dict | none + Dictionary of means for all features with keys: feature names and + values: mean values. If this is None, the self.means attribute will + be used. If this is not None, this DataHandler object means + attribute will be updated. + stds : dict | none + dictionary of standard deviation values for all features with keys: + feature names and values: standard deviations. If this is None, the + self.stds attribute will be used. If this is not None, this + DataHandler object stds attribute will be updated. + max_workers : None | int + Max workers to perform normalization. if None, self.norm_workers + will be used """ + if means is None: + means = self.means + if stds is None: + stds = self.stds logger.info('Normalizing low resolution data features=' - f'{self.features}') - self._normalize(data=self.lr_data, - val_data=self.lr_val_data, - means=means, - stds=stdevs, - max_workers=self.lr_dh.norm_workers) + f'{self.lr_dh.features}') + self.lr_dh.normalize(means=means, stds=stds, max_workers=max_workers) logger.info('Normalizing high resolution data features=' - f'{self.output_features}') - indices = [self.features.index(f) for f in self.output_features] - self._normalize(data=self.hr_data, - val_data=self.hr_val_data, - means=means[indices], - stds=stdevs[indices], - max_workers=self.hr_dh.norm_workers) + f'{self.hr_dh.features}') + self.hr_dh.normalize(means=means, stds=stds, max_workers=max_workers) @property - def output_features(self): - """Get list of output features. e.g. those that are returned by a - GAN""" - return self.lr_dh.output_features + def features(self): + """Get a list of data features including features from both the lr and + hr data handlers""" + out = list(copy.deepcopy(self.lr_dh.features)) + out += [fn for fn in self.hr_dh.features if fn not in out] + return out @property - def train_only_features(self): + def lr_only_features(self): """Features to use for training only and not output""" - return self.lr_dh.train_only_features + tof = [fn for fn in self.lr_dh.features + if fn not in self.hr_out_features + and fn not in self.hr_exo_features] + return tof + + @property + def lr_features(self): + """Get a list of low-resolution features. All low-resolution features + are used for training.""" + return self.lr_dh.lr_features + + @property + def hr_exo_features(self): + """Get a list of high-resolution features that are only used for + training e.g., mid-network high-res topo injection. These must come at + the end of the high-res feature set.""" + return self.hr_dh.hr_exo_features + + @property + def hr_out_features(self): + """Get a list of high-resolution features that are intended to be + output by the GAN. Does not include high-resolution exogenous features + """ + return self.hr_dh.hr_out_features def _shape_check(self): """Check if hr_handler.shape is divisible by s_enhance. If not take @@ -212,14 +275,14 @@ def _shape_check(self): logger.warning(msg) warn(msg) - self.hr_data = self.hr_dh.data[:self.hr_required_shape[0], :self. - hr_required_shape[1], :self. - hr_required_shape[2]] + self.hr_data = self.hr_dh.data[:self.hr_required_shape[0], + :self.hr_required_shape[1], + :self.hr_required_shape[2]] self.hr_time_index = self.hr_dh.time_index[:self.hr_required_shape[2]] self.lr_time_index = self.lr_dh.time_index[:self.lr_required_shape[2]] assert np.array_equal(self.hr_time_index[::self.t_enhance].values, - self.lr_time_index) + self.lr_time_index.values) def _run_pair_checks(self, hr_handler, lr_handler): """Run sanity checks on high_res and low_res pairs. The handler data @@ -228,9 +291,6 @@ def _run_pair_checks(self, hr_handler, lr_handler): 'hr_handler.val_split and lr_handler.val_split should both be ' 'zero.') assert hr_handler.val_split == 0 and lr_handler.val_split == 0, msg - msg = ('Handlers have incompatible number of features. ' - f'({hr_handler.features} vs {lr_handler.features})') - assert hr_handler.features == self.output_features, msg hr_shape = hr_handler.sample_shape lr_shape = (hr_shape[0] // self.s_enhance, hr_shape[1] // self.s_enhance, @@ -301,11 +361,6 @@ def hr_sample_shape(self): """Get hr sample shape""" return self.hr_dh.sample_shape - @property - def features(self): - """Get list of features in each data handler""" - return self.lr_dh.features - @property def data(self): """Get low res data. Same as self.lr_data but used to match property @@ -332,7 +387,7 @@ def lr_required_shape(self): @property def shape(self): """Get low_res shape""" - return (*self.lr_required_shape, len(self.features)) + return (*self.lr_required_shape, len(self.lr_dh.features)) @property def size(self): @@ -393,6 +448,18 @@ def cache_files(self): features=self.lr_dh.features) return cache_files + @property + def noncached_features(self): + """Get list of features needing extraction or derivation""" + if self._noncached_features is None: + self._noncached_features = self.check_cached_features( + self.lr_dh.features, + cache_files=self.cache_files, + overwrite_cache=self.overwrite_cache, + load_cached=self.load_cached, + ) + return self._noncached_features + @property def try_load(self): """Check if we should try to load cached data""" @@ -437,7 +504,7 @@ def get_lr_data(self): logger.info('Caching low resolution data with ' f'shape={self.lr_data.shape}.') self._cache_data(self.lr_data, - features=self.features, + features=self.lr_dh.features, cache_file_paths=self.cache_files, overwrite=self.overwrite_cache) @@ -460,29 +527,33 @@ def get_lr_regridded_data(self): logger.info('Regridding low resolution feature data.') regridder = self.get_regridder() - for f in self.noncached_features: - fidx = self.features.index(f) + fnames = set(self.noncached_features) + fnames = fnames.intersection(set(self.lr_dh.features)) + for fname in fnames: + fidx = self.lr_dh.features.index(fname) tmp = regridder(self.lr_input_data[..., fidx]) tmp = tmp.reshape(self.lr_required_shape) self.lr_data[..., fidx] = tmp if self.load_cached: - for f in self.cached_features: - f_index = self.features.index(f) - logger.info(f'Loading {f} from {self.cache_files[f_index]}') - with open(self.cache_files[f_index], 'rb') as fh: - self.lr_data[..., f_index] = pickle.load(fh) + fnames = set(self.cached_features) + fnames = fnames.intersection(set(self.lr_dh.features)) + for fname in fnames: + fidx = self.lr_dh.features.index(fname) + logger.info(f'Loading {fname} from {self.cache_files[fidx]}') + with open(self.cache_files[fidx], 'rb') as fh: + self.lr_data[..., fidx] = pickle.load(fh) for fidx in range(self.lr_data.shape[-1]): nan_perc = (100 * np.isnan(self.lr_data[..., fidx]).sum() / self.lr_data[..., fidx].size) if nan_perc > 0: - msg = (f'{self.features[fidx]} data has {nan_perc:.3f}% NaN ' - 'values!') + msg = (f'{self.lr_dh.features[fidx]} data has ' + f'{nan_perc:.3f}% NaN values!') logger.warning(msg) warn(msg) - msg = (f'Doing nn nan fill on low res {self.features[fidx]} ' - 'data.') + msg = (f'Doing nn nan fill on low res ' + f'{self.lr_dh.features[fidx]} data.') logger.info(msg) self.lr_data[..., fidx] = nn_fill_array( self.lr_data[..., fidx]) @@ -508,7 +579,7 @@ def get_next(self): hr_obs_idx += [slice(s.start * self.t_enhance, s.stop * self.t_enhance) for s in lr_obs_idx[2:-1]] - hr_obs_idx.append(np.arange(len(self.output_features))) + hr_obs_idx.append(np.arange(len(self.hr_dh.features))) hr_obs_idx = tuple(hr_obs_idx) self.current_obs_index = { 'hr_index': hr_obs_idx, diff --git a/sup3r/preprocessing/data_handling/h5_data_handling.py b/sup3r/preprocessing/data_handling/h5_data_handling.py index fb47ccd0b..fca438c0c 100644 --- a/sup3r/preprocessing/data_handling/h5_data_handling.py +++ b/sup3r/preprocessing/data_handling/h5_data_handling.py @@ -165,7 +165,7 @@ def get_raster_index(self): msg = ('Must provide raster file or shape + target to get ' 'raster index') assert check, msg - logger.debug('Calculating raster index from WTK file ' + logger.debug('Calculating raster index from .h5 file ' f'for shape {self.grid_shape} and target ' f'{self.target}') handle = self.source_handler(self.file_paths[0]) @@ -196,15 +196,6 @@ class DataHandlerH5WindCC(DataHandlerH5): # the handler from rex to open h5 data. REX_HANDLER = MultiFileWindX - # list of features / feature name patterns that are input to the generative - # model but are not part of the synthetic output and are not sent to the - # discriminator. These are case-insensitive and follow the Unix shell-style - # wildcard format. - TRAIN_ONLY_FEATURES = ('temperature_max_*m', 'temperature_min_*m', - 'relativehumidity_max_*m', - 'relativehumidity_min_*m', - ) - def __init__(self, *args, **kwargs): """ Parameters @@ -411,12 +402,6 @@ class DataHandlerH5SolarCC(DataHandlerH5WindCC): # the handler from rex to open h5 data. REX_HANDLER = MultiFileNSRDBX - # list of features / feature name patterns that are input to the generative - # model but are not part of the synthetic output and are not sent to the - # discriminator. These are case-insensitive and follow the Unix shell-style - # wildcard format. - TRAIN_ONLY_FEATURES = ('U*', 'V*', 'topography') - def __init__(self, *args, **kwargs): """ Parameters diff --git a/sup3r/preprocessing/data_handling/mixin.py b/sup3r/preprocessing/data_handling/mixin.py index acc245173..0d071a082 100644 --- a/sup3r/preprocessing/data_handling/mixin.py +++ b/sup3r/preprocessing/data_handling/mixin.py @@ -239,8 +239,8 @@ def _cache_data(self, data, features, cache_file_paths, overwrite=False): overwrite : bool Whether to overwrite exisiting files. """ - os.makedirs(os.path.dirname(cache_file_paths[0]), exist_ok=True) for i, fp in enumerate(cache_file_paths): + os.makedirs(os.path.dirname(fp), exist_ok=True) if not os.path.exists(fp) or overwrite: if overwrite and os.path.exists(fp): logger.info(f'Overwriting {features[i]} with shape ' @@ -392,9 +392,8 @@ def _load_cached_data(self, data, cache_files, features, max_workers=None): features, max_workers=max_workers) - @classmethod - def check_cached_features(cls, - features, + @staticmethod + def check_cached_features(features, cache_files=None, overwrite_cache=False, load_cached=False): @@ -924,6 +923,8 @@ class TrainingPrepMixIn: def __init__(self): """Initialize common attributes""" self.features = None + self.means = None + self.stds = None @classmethod def _split_data_indices(cls, @@ -991,29 +992,6 @@ def _get_observation_index(self, data, sample_shape): temporal_slice = uniform_time_sampler(data, sample_shape[2]) return (*spatial_slice, temporal_slice, np.arange(data.shape[-1])) - @classmethod - def _unnormalize(cls, data, val_data, means, stds): - """Remove normalization from stored means and stds - - Parameters - ---------- - data : np.ndarray - Array of training data. - (spatial_1, spatial_2, temporal, n_features) - val_data : np.ndarray - Array of validation data. - (spatial_1, spatial_2, temporal, n_features) - means : np.ndarray - dimensions (features) - array of means for all features with same ordering as data features - stds : np.ndarray - dimensions (features) - array of means for all features with same ordering as data features - """ - val_data = (val_data * stds) + means - data = (data * stds) + means - return data, val_data - def _normalize_data(self, data, val_data, feature_index, mean, std): """Normalize data with initialized mean and standard deviation for a specific feature @@ -1036,6 +1014,7 @@ def _normalize_data(self, data, val_data, feature_index, mean, std): if val_data is not None: val_data[..., feature_index] -= mean + data[..., feature_index] -= mean if std > 0: @@ -1048,9 +1027,10 @@ def _normalize_data(self, data, val_data, feature_index, mean, std): logger.warning(msg) warnings.warn(msg) - logger.info(f'Finished normalizing {self.features[feature_index]}.') + logger.debug(f'Finished normalizing {self.features[feature_index]} ' + f'with mean {mean:.3e} and std {std:.3e}.') - def _normalize(self, data, val_data, means, stds, max_workers=None): + def _normalize(self, data, val_data, max_workers=None): """Normalize all data features Parameters @@ -1061,74 +1041,38 @@ def _normalize(self, data, val_data, means, stds, max_workers=None): val_data : np.ndarray Array of validation data. (spatial_1, spatial_2, temporal, n_features) - means : np.ndarray - dimensions (features) - array of means for all features with same ordering as data features - stds : np.ndarray - dimensions (features) - array of means for all features with same ordering as data features max_workers : int | None Number of workers to use in thread pool for nomalization. """ - msg = f'Received {len(means)} means for {data.shape[-1]} features' - assert len(means) == data.shape[-1], msg - msg = f'Received {len(stds)} stds for {data.shape[-1]} features' - assert len(stds) == data.shape[-1], msg - logger.info(f'Normalizing {data.shape[-1]} features.') - if max_workers == 1: - for i in range(data.shape[-1]): - self._normalize_data(data, val_data, i, means[i], stds[i]) - else: - self.parallel_normalization(data, - val_data, - means, - stds, - max_workers=max_workers) - - def parallel_normalization(self, - data, - val_data, - means, - stds, - max_workers=None): - """Run normalization of features in parallel - - Parameters - ---------- - data : np.ndarray - Array of training data. - (spatial_1, spatial_2, temporal, n_features) - val_data : np.ndarray - Array of validation data. - (spatial_1, spatial_2, temporal, n_features) - means : np.ndarray - dimensions (features) - array of means for all features with same ordering as data features - stds : np.ndarray - dimensions (features) - array of means for all features with same ordering as data features - max_workers : int | None - Max number of workers to use for normalizing features - """ - with ThreadPoolExecutor(max_workers=max_workers) as exe: - futures = {} - now = dt.now() - for i in range(data.shape[-1]): - future = exe.submit(self._normalize_data, data, val_data, i, - means[i], stds[i]) - futures[future] = i + msg1 = (f'Not all feature names {self.features} were found in ' + f'self.means: {list(self.means.keys())}') + msg2 = (f'Not all feature names {self.features} were found in ' + f'self.stds: {list(self.stds.keys())}') + assert all(fn in self.means for fn in self.features), msg1 + assert all(fn in self.stds for fn in self.features), msg2 - logger.info(f'Started normalizing {data.shape[-1]} features ' - f'in {dt.now() - now}.') + logger.info(f'Normalizing {data.shape[-1]} features: {self.features}') - for i, future in enumerate(as_completed(futures)): - try: - future.result() - except Exception as e: - msg = ('Error while normalizing future number ' - f'{futures[future]}.') - logger.exception(msg) - raise RuntimeError(msg) from e - logger.debug(f'{i+1} out of {data.shape[-1]} features ' - 'normalized.') + if max_workers == 1: + for idf, feature in enumerate(self.features): + self._normalize_data(data, val_data, idf, self.means[feature], + self.stds[feature]) + else: + with ThreadPoolExecutor(max_workers=max_workers) as exe: + futures = [] + for idf, feature in enumerate(self.features): + future = exe.submit(self._normalize_data, + data, val_data, idf, + self.means[feature], + self.stds[feature]) + futures.append(future) + + for future in as_completed(futures): + try: + future.result() + except Exception as e: + msg = ('Error while normalizing future number ' + f'{futures[future]}.') + logger.exception(msg) + raise RuntimeError(msg) from e diff --git a/sup3r/preprocessing/dual_batch_handling.py b/sup3r/preprocessing/dual_batch_handling.py index 3b0225176..60f014a57 100644 --- a/sup3r/preprocessing/dual_batch_handling.py +++ b/sup3r/preprocessing/dual_batch_handling.py @@ -125,8 +125,7 @@ def __next__(self): high_res = high_res[..., 0, :] low_res = low_res[..., 0, :] - high_res = self.BATCH_CLASS.reduce_features( - high_res, self.output_features_ind) + high_res = high_res[..., self.hr_features_ind] batch = self.BATCH_CLASS(low_res=low_res, high_res=high_res) self._i += 1 return batch @@ -148,7 +147,7 @@ def lr_features(self): @property def hr_features(self): """Features in high res batch.""" - return self.data_handlers[0].lr_dh.output_features + return self.data_handlers[0].hr_dh.features @property def hr_sample_shape(self): @@ -190,7 +189,8 @@ def __next__(self): dtype=np.float32) for i in range(self.batch_size): - high_res[i, ...], low_res[i, ...] = handler.get_next() + hr_sample, lr_sample = handler.get_next() + high_res[i, ...], low_res[i, ...] = hr_sample, lr_sample self.current_batch_indices.append(handler.current_obs_index) batch = self.BATCH_CLASS(low_res=low_res, high_res=high_res) diff --git a/sup3r/preprocessing/wind_conditional_moment_batch_handling.py b/sup3r/preprocessing/wind_conditional_moment_batch_handling.py index b55269d02..e947f0741 100644 --- a/sup3r/preprocessing/wind_conditional_moment_batch_handling.py +++ b/sup3r/preprocessing/wind_conditional_moment_batch_handling.py @@ -75,7 +75,7 @@ def make_output(low_res, high_res, enhanced_lr = temporal_simple_enhancing(enhanced_lr, t_enhance=t_enhance, mode=t_enhance_mode) - enhanced_lr = Batch.reduce_features(enhanced_lr, output_features_ind) + enhanced_lr = enhanced_lr[..., output_features_ind] enhanced_lr[..., -1] = high_res[..., -1] return high_res - enhanced_lr @@ -187,7 +187,7 @@ def make_output(low_res, high_res, enhanced_lr = temporal_simple_enhancing(enhanced_lr, t_enhance=t_enhance, mode=t_enhance_mode) - enhanced_lr = Batch.reduce_features(enhanced_lr, output_features_ind) + enhanced_lr = enhanced_lr[..., output_features_ind] enhanced_lr[..., -1] = 0.0 return (high_res - enhanced_lr - out)**2 diff --git a/sup3r/utilities/regridder.py b/sup3r/utilities/regridder.py index 3e015b3dc..454c16cbd 100644 --- a/sup3r/utilities/regridder.py +++ b/sup3r/utilities/regridder.py @@ -335,8 +335,10 @@ def __call__(self, data): Parameters ---------- data : ndarray - Spatiotemporal data to regrid to target_meta - (spatial_1, spatial_2, temporal) + Spatiotemporal data to regrid to target_meta. Data can be flattened + in the spatial dimension to match the target_meta or be in a 2D + spatial grid, e.g.: + (spatial, temporal) or (spatial_1, spatial_2, temporal) Returns ------- diff --git a/tests/bias/test_bias_correction.py b/tests/bias/test_bias_correction.py index 8f7d76fa9..0f6c022ad 100644 --- a/tests/bias/test_bias_correction.py +++ b/tests/bias/test_bias_correction.py @@ -314,8 +314,8 @@ def test_fwp_integration(): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 10, 10, 6, len(features)))) - model.meta['training_features'] = features - model.meta['output_features'] = features + model.meta['lr_features'] = features + model.meta['hr_out_features'] = features model.meta['s_enhance'] = 3 model.meta['t_enhance'] = 4 diff --git a/tests/data_handling/test_data_handling_h5.py b/tests/data_handling/test_data_handling_h5.py index 35df2a741..d2aeb5df9 100644 --- a/tests/data_handling/test_data_handling_h5.py +++ b/tests/data_handling/test_data_handling_h5.py @@ -1,8 +1,7 @@ # -*- coding: utf-8 -*- """pytests for data handling""" - +import json import os -import pickle import tempfile import matplotlib.pyplot as plt @@ -30,6 +29,7 @@ val_split = 0.2 dh_kwargs = {'target': target, 'shape': shape, 'max_delta': 20, 'sample_shape': sample_shape, + 'lr_only_features': ('BVF*m', 'topography',), 'temporal_slice': slice(None, None, 1), 'worker_kwargs': {'max_workers': 1}} bh_kwargs = {'batch_size': 8, 'n_batches': 20, @@ -87,7 +87,7 @@ def test_topography(): topo_idx = data_handler.features.index('topography') assert np.allclose(topo, data_handler.data[..., 0, topo_idx]) st_batch_handler = BatchHandler([data_handler], **bh_kwargs) - assert data_handler.output_features == features[:2] + assert data_handler.hr_out_features == features[:2] assert data_handler.data.shape[-1] == len(features) for batch in st_batch_handler: @@ -197,16 +197,16 @@ def test_raster_index_caching(): def test_normalization_input(): """Test correct normalization input""" - means = np.random.rand(len(features)) - stds = np.random.rand(len(features)) + means = {f: 10 for f in features} + stds = {f: 20 for f in features} data_handlers = [] for input_file in input_files: data_handler = DataHandler(input_file, features, **dh_kwargs) data_handlers.append(data_handler) batch_handler = BatchHandler(data_handlers, means=means, stds=stds, **bh_kwargs) - assert np.array_equal(batch_handler.stds, stds) - assert np.array_equal(batch_handler.means, means) + assert all(batch_handler.means[f] == means[f] for f in features) + assert all(batch_handler.stds[f] == stds[f] for f in features) def test_stats_caching(): @@ -218,20 +218,20 @@ def test_stats_caching(): data_handlers.append(data_handler) with tempfile.TemporaryDirectory() as td: - means_file = os.path.join(td, 'means.pkl') - stdevs_file = os.path.join(td, 'stdevs.pkl') + means_file = os.path.join(td, 'means.json') + stdevs_file = os.path.join(td, 'stds.json') batch_handler = BatchHandler(data_handlers, stdevs_file=stdevs_file, means_file=means_file, **bh_kwargs) assert os.path.exists(means_file) assert os.path.exists(stdevs_file) - with open(means_file, 'rb') as fh: - means = pickle.load(fh) - with open(stdevs_file, 'rb') as fh: - stdevs = pickle.load(fh) + with open(means_file, 'r') as fh: + means = json.load(fh) + with open(stdevs_file, 'r') as fh: + stds = json.load(fh) - assert np.array_equal(means, batch_handler.means) - assert np.array_equal(stdevs, batch_handler.stds) + assert all(batch_handler.means[f] == means[f] for f in features) + assert all(batch_handler.stds[f] == stds[f] for f in features) stacked_data = np.concatenate([d.data for d in batch_handler.data_handlers], axis=2) @@ -241,8 +241,8 @@ def test_stats_caching(): if std == 0: std = 1 mean = np.mean(stacked_data[..., i]) - assert np.allclose(std, 1, atol=1e-3) - assert np.allclose(mean, 0, atol=1e-3) + assert np.allclose(std, 1, atol=1e-2), str(std) + assert np.allclose(mean, 0, atol=1e-5), str(mean) def test_unequal_size_normalization(): @@ -264,8 +264,8 @@ def test_unequal_size_normalization(): if std == 0: std = 1 mean = np.mean(stacked_data[..., i]) - assert np.allclose(std, 1, atol=1e-3) - assert np.allclose(mean, 0, atol=1e-3) + assert np.allclose(std, 1, atol=2e-2), str(std) + assert np.allclose(mean, 0, atol=1e-5), str(mean) def test_normalization(): @@ -284,8 +284,8 @@ def test_normalization(): if std == 0: std = 1 mean = np.mean(stacked_data[..., i]) - assert np.allclose(std, 1, atol=1e-3) - assert np.allclose(mean, 0, atol=1e-3) + assert np.allclose(std, 1, atol=1e-2), str(std) + assert np.allclose(mean, 0, atol=1e-5), str(mean) def test_spatiotemporal_normalization(): @@ -304,8 +304,8 @@ def test_spatiotemporal_normalization(): if std == 0: std = 1 mean = np.mean(stacked_data[..., i]) - assert np.allclose(std, 1, atol=1e-3) - assert np.allclose(mean, 0, atol=1e-3) + assert np.allclose(std, 1, atol=1e-2), str(std) + assert np.allclose(mean, 0, atol=1e-5), str(mean) def test_data_extraction(): @@ -658,3 +658,78 @@ def test_solar_spatial_h5(): assert not np.isnan(batch.high_res).any() assert batch.low_res.shape == (8, 2, 2, 1) assert batch.high_res.shape == (8, 10, 10, 1) + + +def test_lr_only_features(): + """Test using BVF as a low-resolution only feature that should be dropped + from the high-res observations.""" + dh_kwargs_new = dh_kwargs.copy() + dh_kwargs_new["sample_shape"] = sample_shape + dh_kwargs_new["lr_only_features"] = 'BVF2*' + data_handler = DataHandler(input_files[0], features, **dh_kwargs_new) + + bh_kwargs_new = bh_kwargs.copy() + bh_kwargs_new['norm'] = False + batch_handler = BatchHandler(data_handler, **bh_kwargs_new) + + for batch in batch_handler: + assert batch.low_res.shape[-1] == 3 + assert batch.high_res.shape[-1] == 2 + + for iobs, data_ind in enumerate(batch_handler.current_batch_indices): + truth = data_handler.data[data_ind] + np.allclose(truth[..., 0:2], batch.high_res[iobs]) + truth = utilities.spatial_coarsening(truth, s_enhance=s_enhance, + obs_axis=False) + np.allclose(truth[..., ::t_enhance, :], batch.low_res[iobs]) + + +def test_hr_exo_features(): + """Test using BVF as a high-res exogenous feature. For the single data + handler, this isnt supposed to do anything because the feature is still + assumed to be in the low-res.""" + dh_kwargs_new = dh_kwargs.copy() + dh_kwargs_new["sample_shape"] = sample_shape + dh_kwargs_new["hr_exo_features"] = 'BVF2*' + data_handler = DataHandler(input_files[0], features, **dh_kwargs_new) + assert data_handler.hr_exo_features == ['BVF2_200m'] + + bh_kwargs_new = bh_kwargs.copy() + bh_kwargs_new['norm'] = False + batch_handler = BatchHandler(data_handler, **bh_kwargs_new) + + for batch in batch_handler: + assert batch.low_res.shape[-1] == 3 + assert batch.high_res.shape[-1] == 3 + + for iobs, data_ind in enumerate(batch_handler.current_batch_indices): + truth = data_handler.data[data_ind] + np.allclose(truth, batch.high_res[iobs]) + truth = utilities.spatial_coarsening(truth, s_enhance=s_enhance, + obs_axis=False) + np.allclose(truth[..., ::t_enhance, :], batch.low_res[iobs]) + + +@pytest.mark.parametrize(['features', 'lr_only_features', 'hr_exo_features'], + [(['V_100m'], ['V_100m'], []), + (['U_100m'], ['V_100m'], ['V_100m']), + (['U_100m'], [], ['U_100m']), + (['U_100m', 'V_100m'], [], ['U_100m']), + (['U_100m', 'V_100m'], [], ['V_100m', 'U_100m'])]) +def test_feature_errors(features, lr_only_features, hr_exo_features): + """Each of these feature combinations should raise an error due to no + features left in hr output or bad ordering""" + handler = DataHandler(input_files[0], + features, + lr_only_features=lr_only_features, + hr_exo_features=hr_exo_features, + target=target, + shape=(20, 20), + sample_shape=(5, 5, 4), + temporal_slice=slice(None, None, 1), + worker_kwargs=dict(max_workers=1), + ) + with pytest.raises(Exception): + _ = handler.lr_features + _ = handler.hr_out_features + _ = handler.hr_exo_features diff --git a/tests/data_handling/test_data_handling_h5_cc.py b/tests/data_handling/test_data_handling_h5_cc.py index e7488a82a..62eab08b1 100644 --- a/tests/data_handling/test_data_handling_h5_cc.py +++ b/tests/data_handling/test_data_handling_h5_cc.py @@ -259,25 +259,20 @@ def test_solar_batch_nan_stats(): handler = DataHandlerH5SolarCC(INPUT_FILE_S, FEATURES_S, **dh_kwargs) - true_means = [np.nanmean(handler.data[..., 0])] - true_stdevs = [np.nanstd(handler.data[..., 0])] + true_csr_mean = np.nanmean(handler.data[..., 0]) + true_csr_stdev = np.nanstd(handler.data[..., 0]) - orig_daily_means = [] - orig_daily_stdevs = [] - for f in range(handler.daily_data.shape[-1]): - orig_daily_means.append(handler.daily_data[..., f].mean()) - orig_daily_stdevs.append(handler.daily_data[..., f].std()) + orig_daily_mean = handler.daily_data[..., 0].mean() batcher = BatchHandlerCC([handler], batch_size=1, n_batches=10, s_enhance=1, sub_daily_shape=9) - assert np.allclose(true_means, batcher.means) - assert np.allclose(true_stdevs, batcher.stds) + assert batcher.means[FEATURES_S[0]] == true_csr_mean + assert batcher.stds[FEATURES_S[0]] == true_csr_stdev - # make sure the daily means were also normalized - for f in range(handler.daily_data.shape[-1]): - new = (orig_daily_means[f] - true_means[f]) / true_stdevs[f] - assert np.allclose(new, handler.daily_data[..., f].mean(), atol=1e-4) + # make sure the daily means were also normalized by same values + new = (orig_daily_mean - true_csr_mean) / true_csr_stdev + assert np.allclose(new, handler.daily_data[..., 0].mean(), atol=1e-4) handler1 = DataHandlerH5SolarCC(INPUT_FILE_S, FEATURES_S, **dh_kwargs) @@ -288,8 +283,8 @@ def test_solar_batch_nan_stats(): batcher = BatchHandlerCC([handler1, handler2], batch_size=1, n_batches=10, s_enhance=1, sub_daily_shape=9) - assert np.allclose(true_means, batcher.means) - assert np.allclose(true_stdevs, batcher.stds) + assert np.allclose(true_csr_mean, batcher.means[FEATURES_S[0]]) + assert np.allclose(true_csr_stdev, batcher.stds[FEATURES_S[0]]) def test_solar_val_data(): @@ -392,6 +387,7 @@ def test_solar_multi_day_coarse_data(): # run another test with u/v on low res side but not high res features = ['clearsky_ratio', 'u', 'v', 'ghi', 'clearsky_ghi'] + dh_kwargs_new['lr_only_features'] = ['u', 'v'] handler = DataHandlerH5SolarCC(INPUT_FILE_S, features, **dh_kwargs_new) @@ -508,6 +504,7 @@ def test_surf_min_max_vars(): dh_kwargs_new['sample_shape'] = (20, 20, 72) dh_kwargs_new['val_split'] = 0 dh_kwargs_new['temporal_slice'] = slice(None, None, 1) + dh_kwargs_new['lr_only_features'] = ['*_min_*', '*_max_*'] handler = DataHandlerH5WindCC(INPUT_FILE_SURF, surf_features, **dh_kwargs_new) diff --git a/tests/data_handling/test_data_handling_nc.py b/tests/data_handling/test_data_handling_nc.py index 110c3016e..5e90d4010 100644 --- a/tests/data_handling/test_data_handling_nc.py +++ b/tests/data_handling/test_data_handling_nc.py @@ -29,6 +29,7 @@ dh_kwargs = dict(target=target, shape=shape, max_delta=20, + lr_only_features=('BVF*m', 'topography',), sample_shape=sample_shape, temporal_slice=slice(None, None, 1), worker_kwargs=dict(max_workers=1), @@ -54,7 +55,7 @@ def test_topography(): topo_idx = data_handler.features.index('topography') assert np.allclose(topo, data_handler.data[..., :, topo_idx]) st_batch_handler = BatchHandler([data_handler], **bh_kwargs) - assert data_handler.output_features == features[:2] + assert data_handler.hr_out_features == features[:2] assert data_handler.data.shape[-1] == len(features) for batch in st_batch_handler: @@ -141,6 +142,7 @@ def test_spatiotemporal_batch_caching(sample_shape): cache_pattern = os.path.join(td, 'cache_') dh_kwargs_new = dh_kwargs.copy() dh_kwargs_new['sample_shape'] = sample_shape + dh_kwargs_new['lr_only_features'] = ['BVF*'] data_handler = DataHandler(input_files, features, cache_pattern=cache_pattern, **dh_kwargs_new) @@ -248,16 +250,16 @@ def test_raster_index_caching(): def test_normalization_input(): """Test correct normalization input""" - means = np.random.rand(len(features)) - stds = np.random.rand(len(features)) + means = {f: 10 for f in features} + stds = {f: 20 for f in features} with tempfile.TemporaryDirectory() as td: input_files = make_fake_nc_files(td, INPUT_FILE, 8) data_handler = DataHandler(input_files, features, **dh_kwargs) batch_handler = BatchHandler([data_handler], means=means, stds=stds, **bh_kwargs) - assert np.array_equal(batch_handler.stds, stds) - assert np.array_equal(batch_handler.means, means) + assert all(batch_handler.means[f] == means[f] for f in features) + assert all(batch_handler.stds[f] == stds[f] for f in features) def test_normalization(): @@ -397,6 +399,7 @@ def test_spatiotemporal_batch_observations(sample_shape): input_files = make_fake_nc_files(td, INPUT_FILE, 8) dh_kwargs_new = dh_kwargs.copy() dh_kwargs_new['sample_shape'] = sample_shape + dh_kwargs_new['lr_only_features'] = 'BVF*' data_handler = DataHandler(input_files, features, **dh_kwargs_new) batch_handler = BatchHandler([data_handler], **bh_kwargs) diff --git a/tests/data_handling/test_dual_data_handling.py b/tests/data_handling/test_dual_data_handling.py index 1a105c3c1..998d6aca5 100644 --- a/tests/data_handling/test_dual_data_handling.py +++ b/tests/data_handling/test_dual_data_handling.py @@ -1,11 +1,13 @@ # -*- coding: utf-8 -*- """Test the basic training of super resolution GAN""" +import copy import os import tempfile import matplotlib.pyplot as plt import numpy as np from rex import init_logger +import pytest from sup3r import TEST_DATA_DIR from sup3r.preprocessing.data_handling.dual_data_handling import ( @@ -28,7 +30,7 @@ def test_dual_data_handler(log=False, full_shape=(20, 20), sample_shape=(10, 10, 1), - plot=True): + plot=False): """Test basic spatial model training with only gen content loss.""" if log: init_logger('sup3r', log_level='DEBUG') @@ -258,7 +260,7 @@ def test_st_dual_batch_handler(log=False, def test_spatial_dual_batch_handler(log=False, full_shape=(20, 20), sample_shape=(10, 10, 1), - plot=True): + plot=False): """Test spatial dual batch handler.""" if log: init_logger('sup3r', log_level='DEBUG') @@ -425,29 +427,76 @@ def test_normalization(log=False, t_enhance=t_enhance, val_split=0.1) - means = [ - np.nanmean(dual_handler.lr_data[..., i]) - for i in range(dual_handler.lr_data.shape[-1]) - ] - stdevs = [ - np.nanstd(dual_handler.lr_data[..., i] - means[i]) - for i in range(dual_handler.lr_data.shape[-1]) - ] + means = copy.deepcopy(lr_handler.means) + stdevs = copy.deepcopy(lr_handler.stds) batch_handler = DualBatchHandler([dual_handler], batch_size=2, s_enhance=s_enhance, t_enhance=t_enhance, n_batches=10) - assert np.allclose(batch_handler.means, means) - assert np.allclose(batch_handler.stds, stdevs) - stacked_data = np.concatenate( - [d.data for d in batch_handler.data_handlers], axis=2) - - for i in range(len(FEATURES)): - std = np.std(stacked_data[..., i]) - if std == 0: - std = 1 - mean = np.mean(stacked_data[..., i]) - assert np.allclose(std, 1, atol=1e-3) - assert np.allclose(mean, 0, atol=1e-3) + + assert all(means[k] == v for k, v in batch_handler.means.items()) + assert all(stdevs[k] == v for k, v in batch_handler.stds.items()) + + # normalization stats retrieved from LR data before re-gridding + for idf in range(lr_handler.shape[-1]): + std = lr_handler.data[..., idf].std() + mean = lr_handler.data[..., idf].mean() + assert np.allclose(std, 1, atol=1e-3), str(std) + assert np.allclose(mean, 0, atol=1e-3), str(mean) + + +@pytest.mark.parametrize(['lr_features', 'hr_features', 'hr_exo_features'], + [(['U_100m'], ['U_100m', 'V_100m'], ['V_100m']), + (['U_100m'], ['U_100m', 'V_100m'], ('V_100m',)), + (['U_100m'], ['V_100m', 'BVF2_200m'], ['BVF2_200m']), + (['U_100m'], ('V_100m', 'BVF2_200m'), ['BVF2_200m']), + (['U_100m'], ['V_100m', 'BVF2_200m'], [])]) +def test_mixed_lr_hr_features(lr_features, hr_features, hr_exo_features): + """Test weird mixes of low-res and high-res features that should work with + the dual dh""" + lr_handler = DataHandlerNC(FP_ERA, + lr_features, + sample_shape=(5, 5, 4), + temporal_slice=slice(None, None, 1), + worker_kwargs=dict(max_workers=1), + ) + hr_handler = DataHandlerH5(FP_WTK, + hr_features, + hr_exo_features=hr_exo_features, + target=TARGET_COORD, + shape=(20, 20), + sample_shape=(5, 5, 4), + temporal_slice=slice(None, None, 1), + worker_kwargs=dict(max_workers=1), + ) + + dual_handler = DualDataHandler(hr_handler, + lr_handler, + s_enhance=1, + t_enhance=1, + val_split=0.0) + + batch_handler = DualBatchHandler(dual_handler, batch_size=2, + s_enhance=1, t_enhance=1, + n_batches=10, + worker_kwargs={'max_workers': 2}) + + n_hr_features = (len(batch_handler.hr_out_features) + + len(batch_handler.hr_exo_features)) + hr_only_features = [fn for fn in hr_features if fn not in lr_features] + hr_out_true = [fn for fn in hr_features if fn not in hr_exo_features] + assert batch_handler.features == lr_features + hr_only_features + assert batch_handler.lr_features == list(lr_features) + assert batch_handler.hr_exo_features == list(hr_exo_features) + assert batch_handler.hr_out_features == list(hr_out_true) + + for batch in batch_handler: + assert batch.high_res.shape[-1] == n_hr_features + assert batch.low_res.shape[-1] == len(batch_handler.lr_features) + + if batch_handler.lr_features == lr_features + hr_only_features: + assert np.allclose(batch.low_res, batch.high_res) + elif batch_handler.lr_features != lr_features + hr_only_features: + assert not np.allclose(batch.low_res, batch.high_res) diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index 04f99081c..9fda30f4a 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -53,8 +53,8 @@ def test_fwp_nc_cc(log=False): features = ['U_100m', 'V_100m'] target = (13.67, 125.0) _ = model.generate(np.ones((4, 10, 10, 6, len(features)))) - model.meta['training_features'] = features - model.meta['output_features'] = features + model.meta['lr_features'] = features + model.meta['hr_out_features'] = features model.meta['s_enhance'] = 3 model.meta['t_enhance'] = 4 with tempfile.TemporaryDirectory() as td: @@ -110,8 +110,8 @@ def test_fwp_single_ts_vs_multi_ts_input_files(): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 10, 10, len(FEATURES)))) - model.meta['training_features'] = FEATURES - model.meta['output_features'] = ['U_100m', 'V_100m'] + model.meta['lr_features'] = FEATURES + model.meta['hr_out_features'] = ['U_100m', 'V_100m'] model.meta['s_enhance'] = 2 model.meta['t_enhance'] = 1 with tempfile.TemporaryDirectory() as td: @@ -172,7 +172,7 @@ def test_fwp_single_ts_vs_multi_ts_input_files(): **kwargs) as single_ts: with xr.open_mfdataset(multi_ts_handler.out_files, **kwargs) as multi_ts: - for feat in model.meta['output_features']: + for feat in model.meta['hr_out_features']: assert np.array_equal(single_ts[feat].values, multi_ts[feat].values) @@ -186,8 +186,8 @@ def test_fwp_spatial_only(): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 10, 10, len(FEATURES)))) - model.meta['training_features'] = FEATURES - model.meta['output_features'] = ['U_100m', 'V_100m'] + model.meta['lr_features'] = FEATURES + model.meta['hr_out_features'] = ['U_100m', 'V_100m'] model.meta['s_enhance'] = 2 model.meta['t_enhance'] = 1 with tempfile.TemporaryDirectory() as td: @@ -241,8 +241,8 @@ def test_fwp_nc(): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 10, 10, 6, len(FEATURES)))) - model.meta['training_features'] = FEATURES - model.meta['output_features'] = ['U_100m', 'V_100m'] + model.meta['lr_features'] = FEATURES + model.meta['hr_out_features'] = ['U_100m', 'V_100m'] model.meta['s_enhance'] = 3 model.meta['t_enhance'] = 4 with tempfile.TemporaryDirectory() as td: @@ -299,8 +299,8 @@ def test_fwp_temporal_slice(): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 10, 10, 6, 2))) - model.meta['training_features'] = ['U_100m', 'V_100m'] - model.meta['output_features'] = ['U_100m', 'V_100m'] + model.meta['lr_features'] = ['U_100m', 'V_100m'] + model.meta['hr_out_features'] = ['U_100m', 'V_100m'] model.meta['s_enhance'] = 3 model.meta['t_enhance'] = 4 with tempfile.TemporaryDirectory() as td: @@ -353,7 +353,7 @@ def test_fwp_temporal_slice(): assert 'gan_meta' in fh.global_attrs gan_meta = json.loads(fh.global_attrs['gan_meta']) assert isinstance(gan_meta, dict) - assert gan_meta['training_features'] == ['U_100m', 'V_100m'] + assert gan_meta['lr_features'] == ['U_100m', 'V_100m'] def test_fwp_handler(): @@ -365,8 +365,8 @@ def test_fwp_handler(): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - model.meta['training_features'] = FEATURES - model.meta['output_features'] = FEATURES[:-1] + model.meta['lr_features'] = FEATURES + model.meta['hr_out_features'] = FEATURES[:-1] model.meta['s_enhance'] = s_enhance model.meta['t_enhance'] = t_enhance _ = model.generate(np.ones((4, 10, 10, 12, 3))) @@ -418,8 +418,8 @@ def test_fwp_chunking(log=False, plot=False): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - model.meta['training_features'] = FEATURES - model.meta['output_features'] = FEATURES[:-1] + model.meta['lr_features'] = FEATURES + model.meta['hr_out_features'] = FEATURES[:-1] model.meta['s_enhance'] = s_enhance model.meta['t_enhance'] = t_enhance _ = model.generate(np.ones((4, 10, 10, 12, 3))) @@ -447,7 +447,7 @@ def test_fwp_chunking(log=False, plot=False): worker_kwargs=dict(max_workers=1))) data_chunked = np.zeros( (shape[0] * s_enhance, shape[1] * s_enhance, - len(input_files) * t_enhance, len(model.output_features))) + len(input_files) * t_enhance, len(model.hr_out_features))) handlerNC = DataHandlerNC(input_files, FEATURES, target=target, @@ -495,11 +495,11 @@ def test_fwp_chunking(log=False, plot=False): fig.colorbar(nc, ax=ax1, shrink=0.6, - label=f'{model.output_features[ifeature]}') + label=f'{model.hr_out_features[ifeature]}') fig.colorbar(ch, ax=ax2, shrink=0.6, - label=f'{model.output_features[ifeature]}') + label=f'{model.hr_out_features[ifeature]}') fig.colorbar(diff, ax=ax3, shrink=0.6, label='Difference') plt.savefig(f'./chunk_vs_nochunk_{ifeature}.png') plt.close() @@ -517,8 +517,8 @@ def test_fwp_nochunking(): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - model.meta['training_features'] = FEATURES - model.meta['output_features'] = FEATURES[:-1] + model.meta['lr_features'] = FEATURES + model.meta['hr_out_features'] = FEATURES[:-1] model.meta['s_enhance'] = s_enhance model.meta['t_enhance'] = t_enhance _ = model.generate(np.ones((4, 10, 10, 12, 3))) @@ -569,8 +569,8 @@ def test_fwp_multi_step_model(): fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') s_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s_model.meta['training_features'] = ['U_100m', 'V_100m'] - s_model.meta['output_features'] = ['U_100m', 'V_100m'] + s_model.meta['lr_features'] = ['U_100m', 'V_100m'] + s_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] assert s_model.s_enhance == 2 assert s_model.t_enhance == 1 _ = s_model.generate(np.ones((4, 10, 10, 2))) @@ -578,8 +578,8 @@ def test_fwp_multi_step_model(): fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - st_model.meta['training_features'] = ['U_100m', 'V_100m'] - st_model.meta['output_features'] = ['U_100m', 'V_100m'] + st_model.meta['lr_features'] = ['U_100m', 'V_100m'] + st_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] assert st_model.s_enhance == 3 assert st_model.t_enhance == 4 _ = st_model.generate(np.ones((4, 10, 10, 6, 2))) @@ -649,7 +649,7 @@ def test_fwp_multi_step_model(): assert 'gan_meta' in fh.global_attrs gan_meta = json.loads(fh.global_attrs['gan_meta']) assert len(gan_meta) == 2 # two step model - assert gan_meta[0]['training_features'] == ['U_100m', 'V_100m'] + assert gan_meta[0]['lr_features'] == ['U_100m', 'V_100m'] def test_slicing_no_pad(log=False): @@ -667,8 +667,8 @@ def test_slicing_no_pad(log=False): fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) features = ['U_100m', 'V_100m'] - st_model.meta['training_features'] = features - st_model.meta['output_features'] = features + st_model.meta['lr_features'] = features + st_model.meta['hr_out_features'] = features st_model.meta['s_enhance'] = s_enhance st_model.meta['t_enhance'] = t_enhance _ = st_model.generate(np.ones((4, 10, 10, 6, 2))) @@ -729,8 +729,8 @@ def test_slicing_pad(log=False): fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) features = ['U_100m', 'V_100m'] - st_model.meta['training_features'] = features - st_model.meta['output_features'] = features + st_model.meta['lr_features'] = features + st_model.meta['hr_out_features'] = features st_model.meta['s_enhance'] = s_enhance st_model.meta['t_enhance'] = t_enhance _ = st_model.generate(np.ones((4, 10, 10, 6, 2))) diff --git a/tests/forward_pass/test_forward_pass_exo.py b/tests/forward_pass/test_forward_pass_exo.py index a282ccd28..c6bfd4d32 100644 --- a/tests/forward_pass/test_forward_pass_exo.py +++ b/tests/forward_pass/test_forward_pass_exo.py @@ -40,8 +40,8 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['lr_features'] = ['U_100m', 'V_100m', 'topography'] + s1_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] s1_model.meta['s_enhance'] = 2 s1_model.meta['t_enhance'] = 1 s1_model.meta['input_resolution'] = {'spatial': '48km', @@ -49,8 +49,8 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): _ = s1_model.generate(np.ones((4, 10, 10, 3))) s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['lr_features'] = ['U_100m', 'V_100m', 'topography'] + s2_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] s2_model.meta['s_enhance'] = 2 s2_model.meta['t_enhance'] = 1 s2_model.meta['input_resolution'] = {'spatial': '24km', @@ -60,8 +60,8 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - st_model.meta['training_features'] = ['U_100m', 'V_100m'] - st_model.meta['output_features'] = ['U_100m', 'V_100m'] + st_model.meta['lr_features'] = ['U_100m', 'V_100m'] + st_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] st_model.meta['s_enhance'] = 3 st_model.meta['t_enhance'] = 4 st_model.meta['input_resolution'] = {'spatial': '12km', @@ -149,7 +149,7 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): assert 'gan_meta' in fh.global_attrs gan_meta = json.loads(fh.global_attrs['gan_meta']) assert len(gan_meta) == 3 # three step model - assert gan_meta[0]['training_features'] == [ + assert gan_meta[0]['lr_features'] == [ 'U_100m', 'V_100m', 'topography' ] @@ -161,8 +161,8 @@ def test_fwp_multi_step_spatial_model_topo_noskip(): fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['lr_features'] = ['U_100m', 'V_100m', 'topography'] + s1_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] s1_model.meta['s_enhance'] = 2 s1_model.meta['t_enhance'] = 1 s1_model.meta['input_resolution'] = {'spatial': '16km', @@ -170,8 +170,8 @@ def test_fwp_multi_step_spatial_model_topo_noskip(): _ = s1_model.generate(np.ones((4, 10, 10, 3))) s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['lr_features'] = ['U_100m', 'V_100m', 'topography'] + s2_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] s2_model.meta['s_enhance'] = 2 s2_model.meta['t_enhance'] = 1 s2_model.meta['input_resolution'] = {'spatial': '8km', @@ -245,7 +245,7 @@ def test_fwp_multi_step_spatial_model_topo_noskip(): assert 'gan_meta' in fh.global_attrs gan_meta = json.loads(fh.global_attrs['gan_meta']) assert len(gan_meta) == 2 # two step model - assert gan_meta[0]['training_features'] == [ + assert gan_meta[0]['lr_features'] == [ 'U_100m', 'V_100m', 'topography' ] @@ -257,8 +257,8 @@ def test_fwp_multi_step_model_topo_noskip(): fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['lr_features'] = ['U_100m', 'V_100m', 'topography'] + s1_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] s1_model.meta['s_enhance'] = 2 s1_model.meta['t_enhance'] = 1 s1_model.meta['input_resolution'] = {'spatial': '48km', @@ -266,8 +266,8 @@ def test_fwp_multi_step_model_topo_noskip(): _ = s1_model.generate(np.ones((4, 10, 10, 3))) s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['lr_features'] = ['U_100m', 'V_100m', 'topography'] + s2_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] s2_model.meta['s_enhance'] = 2 s2_model.meta['t_enhance'] = 1 s2_model.meta['input_resolution'] = {'spatial': '24km', @@ -277,8 +277,8 @@ def test_fwp_multi_step_model_topo_noskip(): fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - st_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - st_model.meta['output_features'] = ['U_100m', 'V_100m'] + st_model.meta['lr_features'] = ['U_100m', 'V_100m', 'topography'] + st_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] st_model.meta['s_enhance'] = 3 st_model.meta['t_enhance'] = 4 st_model.meta['input_resolution'] = {'spatial': '12km', @@ -363,7 +363,7 @@ def test_fwp_multi_step_model_topo_noskip(): assert 'gan_meta' in fh.global_attrs gan_meta = json.loads(fh.global_attrs['gan_meta']) assert len(gan_meta) == 3 # three step model - assert gan_meta[0]['training_features'] == [ + assert gan_meta[0]['lr_features'] == [ 'U_100m', 'V_100m', 'topography' ] @@ -438,8 +438,8 @@ def test_fwp_single_step_wind_hi_res_topo(plot=False): fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) - model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - model.meta['output_features'] = ['U_100m', 'V_100m'] + model.meta['lr_features'] = ['U_100m', 'V_100m', 'topography'] + model.meta['hr_out_features'] = ['U_100m', 'V_100m'] model.meta['s_enhance'] = 2 model.meta['t_enhance'] = 2 model.meta['input_resolution'] = {'spatial': '8km', @@ -494,7 +494,7 @@ def test_fwp_single_step_wind_hi_res_topo(plot=False): forward_pass = ForwardPass(handler) if plot: - for ifeature, feature in enumerate(forward_pass.output_features): + for ifeature, feature in enumerate(forward_pass.hr_out_features): fig = plt.figure(figsize=(15, 5)) ax1 = fig.add_subplot(111) vmin = np.min(forward_pass.input_data[..., ifeature]) @@ -581,8 +581,8 @@ def test_fwp_multi_step_wind_hi_res_topo(): fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') s1_model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) - s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['lr_features'] = ['U_100m', 'V_100m', 'topography'] + s1_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] s1_model.meta['s_enhance'] = 2 s1_model.meta['t_enhance'] = 1 s1_model.meta['input_resolution'] = {'spatial': '48km', @@ -597,8 +597,8 @@ def test_fwp_multi_step_wind_hi_res_topo(): exogenous_data=exo_tmp) s2_model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) - s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['lr_features'] = ['U_100m', 'V_100m', 'topography'] + s2_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] s2_model.meta['s_enhance'] = 2 s2_model.meta['t_enhance'] = 1 s2_model.meta['input_resolution'] = {'spatial': '24km', @@ -609,8 +609,8 @@ def test_fwp_multi_step_wind_hi_res_topo(): fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - st_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - st_model.meta['output_features'] = ['U_100m', 'V_100m'] + st_model.meta['lr_features'] = ['U_100m', 'V_100m', 'topography'] + st_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] st_model.meta['s_enhance'] = 3 st_model.meta['t_enhance'] = 4 st_model.meta['input_resolution'] = {'spatial': '12km', @@ -763,8 +763,8 @@ def test_fwp_wind_hi_res_topo_plus_linear(): fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') s_model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) - s_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s_model.meta['output_features'] = ['U_100m', 'V_100m'] + s_model.meta['lr_features'] = ['U_100m', 'V_100m', 'topography'] + s_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] s_model.meta['s_enhance'] = 2 s_model.meta['t_enhance'] = 1 s_model.meta['input_resolution'] = {'spatial': '12km', @@ -776,7 +776,7 @@ def test_fwp_wind_hi_res_topo_plus_linear(): _ = s_model.generate(np.ones((4, 10, 10, 3)), exogenous_data=exo_tmp) - t_model = LinearInterp(features=['U_100m', 'V_100m'], + t_model = LinearInterp(lr_features=['U_100m', 'V_100m'], s_enhance=1, t_enhance=4) t_model.meta['input_resolution'] = {'spatial': '4km', @@ -839,10 +839,10 @@ def test_fwp_multi_step_model_multi_exo(): fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s1_model.meta['training_features'] = [ + s1_model.meta['lr_features'] = [ 'U_100m', 'V_100m', 'topography' ] - s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] s1_model.meta['s_enhance'] = 2 s1_model.meta['t_enhance'] = 1 s1_model.meta['input_resolution'] = {'spatial': '48km', @@ -850,10 +850,10 @@ def test_fwp_multi_step_model_multi_exo(): _ = s1_model.generate(np.ones((4, 10, 10, 3))) s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s2_model.meta['training_features'] = [ + s2_model.meta['lr_features'] = [ 'U_100m', 'V_100m', 'topography' ] - s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] s2_model.meta['s_enhance'] = 2 s2_model.meta['t_enhance'] = 1 s2_model.meta['input_resolution'] = {'spatial': '24km', @@ -865,10 +865,10 @@ def test_fwp_multi_step_model_multi_exo(): st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) st_model.meta['input_resolution'] = {'spatial': '12km', 'temporal': '60min'} - st_model.meta['training_features'] = [ + st_model.meta['lr_features'] = [ 'U_100m', 'V_100m', 'sza' ] - st_model.meta['output_features'] = ['U_100m', 'V_100m'] + st_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] st_model.meta['s_enhance'] = 3 st_model.meta['t_enhance'] = 4 _ = st_model.generate(np.ones((4, 10, 10, 6, 3))) @@ -959,7 +959,7 @@ def test_fwp_multi_step_model_multi_exo(): assert 'gan_meta' in fh.global_attrs gan_meta = json.loads(fh.global_attrs['gan_meta']) assert len(gan_meta) == 3 # three step model - assert gan_meta[0]['training_features'] == [ + assert gan_meta[0]['lr_features'] == [ 'U_100m', 'V_100m', 'topography' ] @@ -1082,10 +1082,10 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') s1_model = Sup3rGan(gen_s_model, fp_disc, learning_rate=1e-4) - s1_model.meta['training_features'] = [ + s1_model.meta['lr_features'] = [ 'U_100m', 'V_100m', 'topography', 'sza' ] - s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] s1_model.meta['s_enhance'] = 2 s1_model.meta['t_enhance'] = 1 s1_model.meta['input_resolution'] = {'spatial': '48km', @@ -1102,10 +1102,10 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): exogenous_data=exo_tmp) s2_model = Sup3rGan(gen_s_model, fp_disc, learning_rate=1e-4) - s2_model.meta['training_features'] = [ + s2_model.meta['lr_features'] = [ 'U_100m', 'V_100m', 'topography', 'sza' ] - s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] s2_model.meta['s_enhance'] = 2 s2_model.meta['t_enhance'] = 1 s2_model.meta['input_resolution'] = {'spatial': '24km', @@ -1115,10 +1115,10 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') st_model = Sup3rGan(gen_t_model, fp_disc, learning_rate=1e-4) - st_model.meta['training_features'] = [ + st_model.meta['lr_features'] = [ 'U_100m', 'V_100m', 'sza' ] - st_model.meta['output_features'] = ['U_100m', 'V_100m'] + st_model.meta['hr_out_features'] = ['U_100m', 'V_100m'] st_model.meta['s_enhance'] = 3 st_model.meta['t_enhance'] = 2 st_model.meta['input_resolution'] = {'spatial': '12km', diff --git a/tests/forward_pass/test_multi_step.py b/tests/forward_pass/test_multi_step.py index 3da773b44..8481e6c82 100644 --- a/tests/forward_pass/test_multi_step.py +++ b/tests/forward_pass/test_multi_step.py @@ -60,24 +60,30 @@ def test_multi_step_norm(norm_option): if norm_option == 'diff_stats': # models have different norm stats - model1.set_norm_stats([0.1, 0.2], [0.04, 0.02]) - model2.set_norm_stats([0.1, 0.2], [0.04, 0.02]) - model3.set_norm_stats([0.3, 0.9], [0.02, 0.07]) + model1.set_norm_stats({'U_100m': 0.1, 'V_100m': 0.2}, + {'U_100m': 0.04, 'V_100m': 0.02}) + model2.set_norm_stats({'U_100m': 0.1, 'V_100m': 0.2}, + {'U_100m': 0.04, 'V_100m': 0.02}) + model3.set_norm_stats({'U_100m': 0.3, 'V_100m': 0.9}, + {'U_100m': 0.02, 'V_100m': 0.07}) else: # all models have the same norm stats - model1.set_norm_stats([0.1, 0.8], [0.04, 0.02]) - model2.set_norm_stats([0.1, 0.8], [0.04, 0.02]) - model3.set_norm_stats([0.1, 0.8], [0.04, 0.02]) + model1.set_norm_stats({'U_100m': 0.1, 'V_100m': 0.8}, + {'U_100m': 0.04, 'V_100m': 0.02}) + model2.set_norm_stats({'U_100m': 0.1, 'V_100m': 0.8}, + {'U_100m': 0.04, 'V_100m': 0.02}) + model3.set_norm_stats({'U_100m': 0.1, 'V_100m': 0.8}, + {'U_100m': 0.04, 'V_100m': 0.02}) model1.meta['input_resolution'] = {'spatial': '27km', 'temporal': '64min'} model2.meta['input_resolution'] = {'spatial': '9km', 'temporal': '16min'} model3.meta['input_resolution'] = {'spatial': '3km', 'temporal': '4min'} - model1.set_model_params(training_features=FEATURES, - output_features=FEATURES) - model2.set_model_params(training_features=FEATURES, - output_features=FEATURES) - model3.set_model_params(training_features=FEATURES, - output_features=FEATURES) + model1.set_model_params(lr_features=FEATURES, + hr_out_features=FEATURES) + model2.set_model_params(lr_features=FEATURES, + hr_out_features=FEATURES) + model3.set_model_params(lr_features=FEATURES, + hr_out_features=FEATURES) _ = model1.generate(np.ones((4, 10, 10, 6, len(FEATURES)))) _ = model2.generate(np.ones((4, 10, 10, 6, len(FEATURES)))) @@ -118,14 +124,18 @@ def test_spatial_then_temporal_gan(): model2 = Sup3rGan(fp_gen, fp_disc) _ = model2.generate(np.ones((4, 10, 10, 6, len(FEATURES)))) - model1.set_norm_stats([0.1, 0.2], [0.04, 0.02]) - model2.set_norm_stats([0.3, 0.9], [0.02, 0.07]) + model1.set_norm_stats({'U_100m': 0.1, 'V_100m': 0.2}, + {'U_100m': 0.04, 'V_100m': 0.02}) + model2.set_norm_stats({'U_100m': 0.3, 'V_100m': 0.9}, + {'U_100m': 0.02, 'V_100m': 0.07}) + model1.meta['input_resolution'] = {'spatial': '12km', 'temporal': '40min'} model2.meta['input_resolution'] = {'spatial': '6km', 'temporal': '40min'} - model1.set_model_params(training_features=FEATURES, - output_features=FEATURES) - model2.set_model_params(training_features=FEATURES, - output_features=FEATURES) + + model1.set_model_params(lr_features=FEATURES, + hr_out_features=FEATURES) + model2.set_model_params(lr_features=FEATURES, + hr_out_features=FEATURES) with tempfile.TemporaryDirectory() as td: fp1 = os.path.join(td, 'model1') @@ -152,14 +162,18 @@ def test_temporal_then_spatial_gan(): model2 = Sup3rGan(fp_gen, fp_disc) _ = model2.generate(np.ones((4, 10, 10, 6, len(FEATURES)))) - model1.set_norm_stats([0.1, 0.2], [0.04, 0.02]) - model2.set_norm_stats([0.3, 0.9], [0.02, 0.07]) + model1.set_norm_stats({'U_100m': 0.1, 'V_100m': 0.2}, + {'U_100m': 0.04, 'V_100m': 0.02}) + model2.set_norm_stats({'U_100m': 0.3, 'V_100m': 0.9}, + {'U_100m': 0.02, 'V_100m': 0.07}) + model1.meta['input_resolution'] = {'spatial': '12km', 'temporal': '40min'} model2.meta['input_resolution'] = {'spatial': '6km', 'temporal': '40min'} - model1.set_model_params(training_features=FEATURES, - output_features=FEATURES) - model2.set_model_params(training_features=FEATURES, - output_features=FEATURES) + + model1.set_model_params(lr_features=FEATURES, + hr_out_features=FEATURES) + model2.set_model_params(lr_features=FEATURES, + hr_out_features=FEATURES) with tempfile.TemporaryDirectory() as td: fp1 = os.path.join(td, 'model1') @@ -181,12 +195,13 @@ def test_spatial_gan_then_linear_interp(): model1 = Sup3rGan(fp_gen, fp_disc) _ = model1.generate(np.ones((4, 10, 10, len(FEATURES)))) - model2 = LinearInterp(features=FEATURES, s_enhance=3, t_enhance=4) + model2 = LinearInterp(lr_features=FEATURES, s_enhance=3, t_enhance=4) - model1.set_norm_stats([0.1, 0.2], [0.04, 0.02]) + model1.set_norm_stats({'U_100m': 0.1, 'V_100m': 0.2}, + {'U_100m': 0.04, 'V_100m': 0.02}) model1.meta['input_resolution'] = {'spatial': '12km', 'temporal': '60min'} - model1.set_model_params(training_features=FEATURES, - output_features=FEATURES) + model1.set_model_params(lr_features=FEATURES, + hr_out_features=FEATURES) with tempfile.TemporaryDirectory() as td: fp1 = os.path.join(td, 'model1') @@ -210,20 +225,21 @@ def test_solar_multistep(): fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') model1 = Sup3rGan(fp_gen, fp_disc) _ = model1.generate(np.ones((4, 10, 10, len(features1)))) - model1.set_norm_stats([0.7], [0.04]) + model1.set_norm_stats({'clearsky_ratio': 0.7}, {'clearsky_ratio': 0.04}) model1.meta['input_resolution'] = {'spatial': '8km', 'temporal': '40min'} - model1.set_model_params(training_features=features1, - output_features=features1) + model1.set_model_params(lr_features=features1, + hr_out_features=features1) features2 = ['U_200m', 'V_200m'] fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') model2 = Sup3rGan(fp_gen, fp_disc) _ = model2.generate(np.ones((4, 10, 10, len(features2)))) - model2.set_norm_stats([4.2, 5.6], [1.1, 1.3]) + model2.set_norm_stats({'U_200m': 4.2, 'V_200m': 5.6}, + {'U_200m': 1.1, 'V_200m': 1.3}) model2.meta['input_resolution'] = {'spatial': '4km', 'temporal': '40min'} - model2.set_model_params(training_features=features2, - output_features=features2) + model2.set_model_params(lr_features=features2, + hr_out_features=features2) features_in_3 = ['clearsky_ratio', 'U_200m', 'V_200m'] features_out_3 = ['clearsky_ratio'] @@ -231,10 +247,13 @@ def test_solar_multistep(): fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') model3 = Sup3rGan(fp_gen, fp_disc) _ = model3.generate(np.ones((4, 10, 10, 3, len(features_in_3)))) - model3.set_norm_stats([0.7, 4.2, 5.6], [0.04, 1.1, 1.3]) + model3.set_norm_stats({'U_200m': 4.2, 'V_200m': 5.6, + 'clearsky_ratio': 0.7}, + {'U_200m': 1.1, 'V_200m': 1.3, + 'clearsky_ratio': 0.04}) model3.meta['input_resolution'] = {'spatial': '2km', 'temporal': '40min'} - model3.set_model_params(training_features=features_in_3, - output_features=features_out_3) + model3.set_model_params(lr_features=features_in_3, + hr_out_features=features_out_3) with tempfile.TemporaryDirectory() as td: fp1 = os.path.join(td, 'model1') diff --git a/tests/forward_pass/test_out_conditional_moments.py b/tests/forward_pass/test_out_conditional_moments.py index 229b2e450..5f6c62863 100644 --- a/tests/forward_pass/test_out_conditional_moments.py +++ b/tests/forward_pass/test_out_conditional_moments.py @@ -44,7 +44,7 @@ def test_out_s_mom1(FEATURES, TRAIN_FEATURES, s_enhance=2, model_dir=None): """Test basic spatial model outputing.""" handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, - train_only_features=TRAIN_FEATURES, + lr_only_features=TRAIN_FEATURES, shape=full_shape, sample_shape=sample_shape, temporal_slice=slice(None, None, 10), @@ -139,7 +139,7 @@ def test_out_s_mom1_sf(FEATURES, TRAIN_FEATURES, s_enhance=2, model_dir=None): """Test basic spatial model outputing.""" handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, - train_only_features=TRAIN_FEATURES, + lr_only_features=TRAIN_FEATURES, shape=full_shape, sample_shape=sample_shape, temporal_slice=slice(None, None, 10), @@ -234,7 +234,7 @@ def test_out_s_mom2(FEATURES, TRAIN_FEATURES, model_mom1_dir=None): """Test basic spatial model outputing.""" handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, - train_only_features=TRAIN_FEATURES, + lr_only_features=TRAIN_FEATURES, shape=full_shape, sample_shape=sample_shape, temporal_slice=slice(None, None, 10), @@ -324,7 +324,7 @@ def test_out_s_mom2_sf(FEATURES, TRAIN_FEATURES, model_mom1_dir=None): """Test basic spatial model outputing.""" handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, - train_only_features=TRAIN_FEATURES, + lr_only_features=TRAIN_FEATURES, shape=full_shape, sample_shape=sample_shape, temporal_slice=slice(None, None, 10), @@ -430,7 +430,7 @@ def test_out_s_mom2_sep(FEATURES, TRAIN_FEATURES, """Test basic spatial model outputing for second conditional, moment separate from the first moment""" handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, - train_only_features=TRAIN_FEATURES, + lr_only_features=TRAIN_FEATURES, shape=full_shape, sample_shape=sample_shape, temporal_slice=slice(None, None, 10), @@ -533,7 +533,7 @@ def test_out_s_mom2_sep_sf(FEATURES, TRAIN_FEATURES, model_mom1_dir=None): """Test basic spatial model outputing.""" handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, - train_only_features=TRAIN_FEATURES, + lr_only_features=TRAIN_FEATURES, shape=full_shape, sample_shape=sample_shape, temporal_slice=slice(None, None, 10), diff --git a/tests/forward_pass/test_surface_model.py b/tests/forward_pass/test_surface_model.py index d4358e51e..3cb326c47 100644 --- a/tests/forward_pass/test_surface_model.py +++ b/tests/forward_pass/test_surface_model.py @@ -54,13 +54,15 @@ def test_surface_model(s_enhance=5): low_res, true_hi_res, topo_lr, topo_hr = get_inputs(s_enhance) - kwargs = {'meta': {'features': FEATURES, 's_enhance': s_enhance}} + kwargs = {'meta': {'lr_features': FEATURES, 'hr_out_features': FEATURES, + 's_enhance': s_enhance}} with tempfile.TemporaryDirectory() as td: fp_params = os.path.join(td, 'model_params.json') with open(fp_params, 'w') as f: json.dump(kwargs, f) model = SurfaceSpatialMetModel.load(model_dir=td) + exo_tmp = {'topography': {'steps': [{'data': topo_lr}, {'data': topo_hr}]}} hi_res = model.generate(low_res, exogenous_data=exo_tmp) @@ -119,9 +121,12 @@ def test_multi_step_surface(s_enhance=2, t_enhance=2): model = Sup3rGan(config_gen, fp_disc) _ = model.generate(np.ones((4, 10, 10, 6, len(FEATURES)))) - model.set_norm_stats([0.3, 0.9, 0.1], [0.02, 0.07, 0.03]) - model.set_model_params(training_features=FEATURES, - output_features=FEATURES, + model.set_norm_stats({'temperature_2m': 0.3, 'relativehumidity_2m': 0.9, + 'pressure_0m': 0.1}, + {'temperature_2m': 0.02, 'relativehumidity_2m': 0.07, + 'pressure_0m': 0.03}) + model.set_model_params(lr_features=FEATURES, + hr_out_features=FEATURES, input_resolution={'spatial': '30km', 'temporal': '60min'}, s_enhance=1, @@ -131,7 +136,8 @@ def test_multi_step_surface(s_enhance=2, t_enhance=2): temporal_dir = os.path.join(td, 'model') model.save(temporal_dir) - surface_model_kwargs = {'meta': {'features': FEATURES, + surface_model_kwargs = {'meta': {'lr_features': FEATURES, + 'hr_out_features': FEATURES, 's_enhance': s_enhance}} surface_dir = os.path.join(td, 'surface/') diff --git a/tests/output/test_qa.py b/tests/output/test_qa.py index 06ebb0a04..8af43173a 100644 --- a/tests/output/test_qa.py +++ b/tests/output/test_qa.py @@ -40,8 +40,8 @@ def test_qa_nc(): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 10, 10, 6, len(TRAIN_FEATURES)))) - model.meta['training_features'] = TRAIN_FEATURES - model.meta['output_features'] = MODEL_OUT_FEATURES + model.meta['lr_features'] = TRAIN_FEATURES + model.meta['hr_out_features'] = MODEL_OUT_FEATURES model.meta['s_enhance'] = 3 model.meta['t_enhance'] = 4 with tempfile.TemporaryDirectory() as td: @@ -123,8 +123,8 @@ def test_qa_h5(): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 10, 10, 6, len(TRAIN_FEATURES)))) - model.meta['training_features'] = TRAIN_FEATURES - model.meta['output_features'] = MODEL_OUT_FEATURES + model.meta['lr_features'] = TRAIN_FEATURES + model.meta['hr_out_features'] = MODEL_OUT_FEATURES model.meta['s_enhance'] = 3 model.meta['t_enhance'] = 4 with tempfile.TemporaryDirectory() as td: @@ -213,8 +213,8 @@ def test_stats(log=False): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 10, 10, 6, len(TRAIN_FEATURES)))) - model.meta['training_features'] = TRAIN_FEATURES - model.meta['output_features'] = MODEL_OUT_FEATURES + model.meta['lr_features'] = TRAIN_FEATURES + model.meta['hr_out_features'] = MODEL_OUT_FEATURES model.meta['s_enhance'] = 3 model.meta['t_enhance'] = 4 with tempfile.TemporaryDirectory() as td: diff --git a/tests/pipeline/test_cli.py b/tests/pipeline/test_cli.py index c2cea8e49..dbb52a7cd 100644 --- a/tests/pipeline/test_cli.py +++ b/tests/pipeline/test_cli.py @@ -44,8 +44,8 @@ def test_pipeline_fwp_collect(runner, log=False): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['training_features'] = FEATURES - model.meta['output_features'] = FEATURES[:2] + model.meta['lr_features'] = FEATURES + model.meta['hr_out_features'] = FEATURES[:2] model.meta['s_enhance'] = 3 model.meta['t_enhance'] = 4 @@ -207,8 +207,8 @@ def test_fwd_pass_cli(runner, log=False): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['training_features'] = FEATURES - model.meta['output_features'] = FEATURES[:2] + model.meta['lr_features'] = FEATURES + model.meta['hr_out_features'] = FEATURES[:2] assert model.s_enhance == 3 assert model.t_enhance == 4 @@ -305,8 +305,8 @@ def test_pipeline_fwp_qa(runner, log=False): assert model.input_resolution == input_resolution assert model.output_resolution == {'spatial': '4km', 'temporal': '15min'} _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['training_features'] = FEATURES - model.meta['output_features'] = FEATURES[:2] + model.meta['lr_features'] = FEATURES + model.meta['hr_out_features'] = FEATURES[:2] assert model.s_enhance == 3 assert model.t_enhance == 4 diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 52b195a15..69379c6ac 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -33,8 +33,9 @@ def test_fwp_pipeline(): assert model.input_resolution == input_resolution assert model.output_resolution == {'spatial': '4km', 'temporal': '15min'} _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['training_features'] = FEATURES - model.meta['output_features'] = FEATURES[:2] + model.meta['lr_features'] = FEATURES + model.meta['hr_out_features'] = FEATURES[:2] + model.meta['hr_exo_features'] = FEATURES[2:] assert model.s_enhance == 3 assert model.t_enhance == 4 @@ -134,8 +135,9 @@ def test_multiple_fwp_pipeline(): assert model.input_resolution == input_resolution assert model.output_resolution == {'spatial': '4km', 'temporal': '15min'} _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES)))) - model.meta['training_features'] = FEATURES - model.meta['output_features'] = FEATURES[:2] + model.meta['lr_features'] = FEATURES + model.meta['hr_out_features'] = FEATURES[:2] + model.meta['hr_exo_features'] = FEATURES[2:] assert model.s_enhance == 3 assert model.t_enhance == 4 diff --git a/tests/training/test_train_conditional_moments.py b/tests/training/test_train_conditional_moments.py index 71fbfb281..56354e8f0 100644 --- a/tests/training/test_train_conditional_moments.py +++ b/tests/training/test_train_conditional_moments.py @@ -61,7 +61,7 @@ def test_train_s_mom1(FEATURES, TRAIN_FEATURES, model = Sup3rCondMom(fp_gen, learning_rate=1e-4) handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, - train_only_features=TRAIN_FEATURES, + lr_only_features=TRAIN_FEATURES, shape=full_shape, sample_shape=sample_shape, temporal_slice=slice(None, None, 10), @@ -151,7 +151,7 @@ def test_train_s_mom1_sf(FEATURES, TRAIN_FEATURES, model = Sup3rCondMom(fp_gen, learning_rate=1e-4) handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, - train_only_features=TRAIN_FEATURES, + lr_only_features=TRAIN_FEATURES, shape=full_shape, sample_shape=sample_shape, temporal_slice=slice(None, None, 10), @@ -214,7 +214,7 @@ def test_train_s_mom2(FEATURES, TRAIN_FEATURES, model_mom2 = Sup3rCondMom(fp_gen_mom2, learning_rate=1e-4) handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, - train_only_features=TRAIN_FEATURES, + lr_only_features=TRAIN_FEATURES, shape=full_shape, sample_shape=sample_shape, temporal_slice=slice(None, None, 10), @@ -276,7 +276,7 @@ def test_train_s_mom2_sf(FEATURES, TRAIN_FEATURES, model_mom2 = Sup3rCondMom(fp_gen_mom2, learning_rate=1e-4) handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, - train_only_features=TRAIN_FEATURES, + lr_only_features=TRAIN_FEATURES, shape=full_shape, sample_shape=sample_shape, temporal_slice=slice(None, None, 10), @@ -332,7 +332,7 @@ def test_train_s_mom2_sep(FEATURES, TRAIN_FEATURES, model_mom2 = Sup3rCondMom(fp_gen_mom2, learning_rate=1e-4) handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, - train_only_features=TRAIN_FEATURES, + lr_only_features=TRAIN_FEATURES, shape=full_shape, sample_shape=sample_shape, temporal_slice=slice(None, None, 10), @@ -387,7 +387,7 @@ def test_train_s_mom2_sep_sf(FEATURES, TRAIN_FEATURES, model_mom2 = Sup3rCondMom(fp_gen_mom2, learning_rate=1e-4) handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, - train_only_features=TRAIN_FEATURES, + lr_only_features=TRAIN_FEATURES, shape=full_shape, sample_shape=sample_shape, temporal_slice=slice(None, None, 10), diff --git a/tests/training/test_train_gan_exo.py b/tests/training/test_train_gan_exo.py index 708112e4f..ba8ca783e 100644 --- a/tests/training/test_train_gan_exo.py +++ b/tests/training/test_train_gan_exo.py @@ -35,8 +35,8 @@ TARGET_COORD = (39.01, -105.15) -@pytest.mark.parametrize('custom_layer', ['Sup3rAdder', 'Sup3rConcat']) -def test_wind_hi_res_topo_with_train_only(custom_layer, log=False): +@pytest.mark.parametrize('CustomLayer', ['Sup3rAdder', 'Sup3rConcat']) +def test_wind_hi_res_topo_with_train_only(CustomLayer, log=False): """Test a special wind cc model with the custom Sup3rAdder or Sup3rConcat layer that adds/concatenates hi-res topography in the middle of the network. This also includes a train only feature""" @@ -49,7 +49,8 @@ def test_wind_hi_res_topo_with_train_only(custom_layer, log=False): val_split=0.1, sample_shape=(20, 20), worker_kwargs=dict(max_workers=1), - train_only_features=['temperature_100m']) + lr_only_features=['temperature_100m'], + hr_exo_features=['topography']) batcher = SpatialBatchHandlerCC([handler], batch_size=2, n_batches=2, s_enhance=2) @@ -79,7 +80,7 @@ def test_wind_hi_res_topo_with_train_only(custom_layer, log=False): {"class": "SpatialExpansion", "spatial_mult": 2}, {"class": "Activation", "activation": "relu"}, - {"class": custom_layer, "name": "topography"}, + {"class": CustomLayer, "name": "topography"}, {"class": "FlexiblePadding", "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], @@ -103,13 +104,14 @@ def test_wind_hi_res_topo_with_train_only(custom_layer, log=False): checkpoint_int=None, out_dir=os.path.join(td, 'test_{epoch}')) - assert model.train_only_features == ['temperature_100m'] - assert model.hr_features == ['U_100m', 'V_100m', 'topography'] + assert model.lr_features == FEATURES_W + assert model.hr_out_features == ['U_100m', 'V_100m'] + assert model.hr_exo_features == ['topography'] assert 'test_0' in os.listdir(td) - assert model.meta['output_features'] == ['U_100m', 'V_100m'] + assert model.meta['hr_out_features'] == ['U_100m', 'V_100m'] assert model.meta['class'] == 'Sup3rGan' - assert 'topography' in batcher.output_features - assert 'topography' not in model.output_features + assert 'topography' in batcher.hr_exo_features + assert 'topography' not in model.hr_out_features x = np.random.uniform(0, 1, (4, 30, 30, 4)) hi_res_topo = np.random.uniform(0, 1, (4, 60, 60, 1)) @@ -129,8 +131,8 @@ def test_wind_hi_res_topo_with_train_only(custom_layer, log=False): assert y.shape[3] == x.shape[3] - 2 -@pytest.mark.parametrize('custom_layer', ['Sup3rAdder', 'Sup3rConcat']) -def test_wind_hi_res_topo(custom_layer, log=False): +@pytest.mark.parametrize('CustomLayer', ['Sup3rAdder', 'Sup3rConcat']) +def test_wind_hi_res_topo(CustomLayer, log=False): """Test a special wind cc model with the custom Sup3rAdder or Sup3rConcat layer that adds/concatenates hi-res topography in the middle of the network.""" @@ -143,7 +145,8 @@ def test_wind_hi_res_topo(custom_layer, log=False): val_split=0.1, sample_shape=(20, 20), worker_kwargs=dict(max_workers=1), - train_only_features=()) + lr_only_features=(), + hr_exo_features=('topography',)) batcher = SpatialBatchHandlerCC([handler], batch_size=2, n_batches=2, s_enhance=2) @@ -174,7 +177,7 @@ def test_wind_hi_res_topo(custom_layer, log=False): {"class": "SpatialExpansion", "spatial_mult": 2}, {"class": "Activation", "activation": "relu"}, - {"class": custom_layer, "name": "topography"}, + {"class": CustomLayer, "name": "topography"}, {"class": "FlexiblePadding", "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], @@ -199,10 +202,10 @@ def test_wind_hi_res_topo(custom_layer, log=False): out_dir=os.path.join(td, 'test_{epoch}')) assert 'test_0' in os.listdir(td) - assert model.meta['output_features'] == ['U_100m', 'V_100m'] + assert model.meta['hr_out_features'] == ['U_100m', 'V_100m'] assert model.meta['class'] == 'Sup3rGan' - assert 'topography' in batcher.output_features - assert 'topography' not in model.output_features + assert 'topography' in batcher.hr_exo_features + assert 'topography' not in model.hr_out_features x = np.random.uniform(0, 1, (4, 30, 30, 3)) hi_res_topo = np.random.uniform(0, 1, (4, 60, 60, 1)) @@ -222,8 +225,8 @@ def test_wind_hi_res_topo(custom_layer, log=False): assert y.shape[3] == x.shape[3] - 1 -@pytest.mark.parametrize('custom_layer', ['Sup3rAdder', 'Sup3rConcat']) -def test_wind_non_cc_hi_res_topo(custom_layer, log=False): +@pytest.mark.parametrize('CustomLayer', ['Sup3rAdder', 'Sup3rConcat']) +def test_wind_non_cc_hi_res_topo(CustomLayer, log=False): """Test a special wind model for non cc with the custom Sup3rAdder or Sup3rConcat layer that adds/concatenates hi-res topography in the middle of the network.""" @@ -235,7 +238,8 @@ def test_wind_non_cc_hi_res_topo(custom_layer, log=False): val_split=0.1, sample_shape=(20, 20), worker_kwargs=dict(max_workers=1), - train_only_features=tuple()) + lr_only_features=tuple(), + hr_exo_features=('topography',)) batcher = SpatialBatchHandler([handler], batch_size=2, n_batches=2, s_enhance=2) @@ -266,7 +270,7 @@ def test_wind_non_cc_hi_res_topo(custom_layer, log=False): {"class": "SpatialExpansion", "spatial_mult": 2}, {"class": "Activation", "activation": "relu"}, - {"class": custom_layer, "name": "topography"}, + {"class": CustomLayer, "name": "topography"}, {"class": "FlexiblePadding", "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], @@ -291,10 +295,10 @@ def test_wind_non_cc_hi_res_topo(custom_layer, log=False): out_dir=os.path.join(td, 'test_{epoch}')) assert 'test_0' in os.listdir(td) - assert model.meta['output_features'] == ['U_100m', 'V_100m'] + assert model.meta['hr_out_features'] == ['U_100m', 'V_100m'] assert model.meta['class'] == 'Sup3rGan' - assert 'topography' in batcher.output_features - assert 'topography' not in model.output_features + assert 'topography' in batcher.hr_exo_features + assert 'topography' not in model.hr_out_features x = np.random.uniform(0, 1, (4, 30, 30, 3)) hi_res_topo = np.random.uniform(0, 1, (4, 60, 60, 1)) @@ -314,8 +318,8 @@ def test_wind_non_cc_hi_res_topo(custom_layer, log=False): assert y.shape[3] == x.shape[3] - 1 -@pytest.mark.parametrize('custom_layer', ['Sup3rAdder', 'Sup3rConcat']) -def test_wind_dc_hi_res_topo(custom_layer, log=False): +@pytest.mark.parametrize('CustomLayer', ['Sup3rAdder', 'Sup3rConcat']) +def test_wind_dc_hi_res_topo(CustomLayer, log=False): """Test a special data centric wind model with the custom Sup3rAdder or Sup3rConcat layer that adds/concatenates hi-res topography in the middle of the network.""" @@ -327,7 +331,8 @@ def test_wind_dc_hi_res_topo(custom_layer, log=False): val_split=0.0, sample_shape=(20, 20, 8), worker_kwargs=dict(max_workers=1), - train_only_features=tuple()) + lr_only_features=tuple(), + hr_exo_features=('topography',)) batcher = BatchHandlerDC([handler], batch_size=2, n_batches=2, s_enhance=2) @@ -358,7 +363,7 @@ def test_wind_dc_hi_res_topo(custom_layer, log=False): {"class": "SpatioTemporalExpansion", "spatial_mult": 2}, {"class": "Activation", "activation": "relu"}, - {"class": custom_layer, "name": "topography"}, + {"class": CustomLayer, "name": "topography"}, {"class": "FlexiblePadding", "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], @@ -383,10 +388,10 @@ def test_wind_dc_hi_res_topo(custom_layer, log=False): out_dir=os.path.join(td, 'test_{epoch}')) assert 'test_0' in os.listdir(td) - assert model.meta['output_features'] == ['U_100m', 'V_100m'] + assert model.meta['hr_out_features'] == ['U_100m', 'V_100m'] assert model.meta['class'] == 'Sup3rGanDC' - assert 'topography' in batcher.output_features - assert 'topography' not in model.output_features + assert 'topography' in batcher.hr_exo_features + assert 'topography' not in model.hr_out_features x = np.random.uniform(0, 1, (1, 30, 30, 4, 3)) hi_res_topo = np.random.uniform(0, 1, (1, 60, 60, 4, 1)) diff --git a/tests/training/test_train_solar.py b/tests/training/test_train_solar.py index 9b3a14bca..56c94f7a5 100644 --- a/tests/training/test_train_solar.py +++ b/tests/training/test_train_solar.py @@ -63,7 +63,7 @@ def test_solar_cc_model(log=False): out_dir=os.path.join(td, 'test_{epoch}')) assert 'test_0' in os.listdir(td) - assert model.meta['output_features'] == ['clearsky_ratio'] + assert model.meta['hr_out_features'] == ['clearsky_ratio'] assert model.meta['class'] == 'Sup3rGan' out_dir = os.path.join(td, 'cc_gan') @@ -119,7 +119,7 @@ def test_solar_cc_model_spatial(log=False): out_dir=os.path.join(td, 'test_{epoch}')) assert 'test_0' in os.listdir(td) - assert model.meta['output_features'] == ['clearsky_ratio'] + assert model.meta['hr_out_features'] == ['clearsky_ratio'] assert model.meta['class'] == 'Sup3rGan' x = np.random.uniform(0, 1, (4, 10, 10, 1)) diff --git a/tests/training/test_train_wind_conditional_moments.py b/tests/training/test_train_wind_conditional_moments.py index 20cd092d2..468b030a0 100644 --- a/tests/training/test_train_wind_conditional_moments.py +++ b/tests/training/test_train_wind_conditional_moments.py @@ -98,7 +98,8 @@ def test_wind_non_cc_hi_res_topo_mom1(custom_layer, batch_class, val_split=0.1, sample_shape=(20, 20), worker_kwargs=dict(max_workers=1), - train_only_features=tuple()) + lr_only_features=tuple(), + hr_exo_features=('topography',)) batcher = batch_class([handler], batch_size=batch_size, @@ -119,11 +120,11 @@ def test_wind_non_cc_hi_res_topo_mom1(custom_layer, batch_class, checkpoint_int=None, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) assert f'test_{n_epoch-1}' in os.listdir(out_dir_root) - assert model.meta['output_features'] == ['U_100m', 'V_100m'] + assert model.meta['hr_out_features'] == ['U_100m', 'V_100m'] assert model.meta['class'] == 'Sup3rCondMom' assert model.meta['input_resolution'] == input_resolution - assert 'topography' in batcher.output_features - assert 'topography' not in model.output_features + assert 'topography' in batcher.hr_exo_features + assert 'topography' not in model.hr_out_features x = np.random.uniform(0, 1, (4, 30, 30, 3)) hi_res_topo = np.random.uniform(0, 1, (4, 60, 60, 1)) @@ -160,7 +161,8 @@ def test_wind_non_cc_hi_res_st_topo_mom1(batch_class, log=False, val_split=0.1, sample_shape=(12, 12, 24), worker_kwargs=dict(max_workers=1), - train_only_features=tuple()) + lr_only_features=tuple(), + hr_exo_features=('topography',)) fp_gen = os.path.join(CONFIG_DIR, 'sup3rcc', @@ -210,7 +212,8 @@ def test_wind_non_cc_hi_res_topo_mom2(custom_layer, batch_class, val_split=0.1, sample_shape=(20, 20), worker_kwargs=dict(max_workers=1), - train_only_features=tuple()) + lr_only_features=tuple(), + hr_exo_features=('topography',)) gen_model = make_s_gen_model(custom_layer) @@ -258,7 +261,8 @@ def test_wind_non_cc_hi_res_st_topo_mom2(batch_class, log=False, val_split=0.1, sample_shape=(12, 12, 24), worker_kwargs=dict(max_workers=1), - train_only_features=tuple()) + lr_only_features=tuple(), + hr_exo_features=('topography',)) fp_gen = os.path.join(CONFIG_DIR, 'sup3rcc',