From 2b9c0d03e5757b4121153c57006087a996aba992 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 3 Jan 2025 15:42:25 -0700 Subject: [PATCH] ``Sup3rGanWithObs`` model subclass. Other misc model refactoring. --- sup3r/models/__init__.py | 1 + sup3r/models/abstract.py | 145 ++++----- sup3r/models/base.py | 186 ++++++++--- sup3r/models/interface.py | 27 -- sup3r/models/with_obs.py | 354 +++++++++++++++++++++ tests/training/test_train_dual_with_obs.py | 10 +- 6 files changed, 562 insertions(+), 161 deletions(-) create mode 100644 sup3r/models/with_obs.py diff --git a/sup3r/models/__init__.py b/sup3r/models/__init__.py index 5d6b51344..20fff799f 100644 --- a/sup3r/models/__init__.py +++ b/sup3r/models/__init__.py @@ -6,6 +6,7 @@ from .multi_step import MultiStepGan, MultiStepSurfaceMetGan, SolarMultiStepGan from .solar_cc import SolarCC from .surface import SurfaceSpatialMetModel +from .with_obs import Sup3rGanWithObs SPATIAL_FIRST_MODELS = (MultiStepSurfaceMetGan, SolarMultiStepGan) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 579dab5e9..dba014ea3 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -16,7 +16,6 @@ from phygnn.layers.custom_layers import Sup3rAdder, Sup3rConcat from rex.utilities.utilities import safe_json_load from tensorflow.keras import optimizers -from tensorflow.keras.losses import MeanAbsoluteError import sup3r.utilities.loss_metrics from sup3r.preprocessing.data_handlers import ExoData @@ -420,6 +419,30 @@ def get_high_res_exo_input(self, high_res): exo_data[feature] = exo_fdata return exo_data + @tf.function + def _combine_loss_input(self, high_res_true, high_res_gen): + """Combine exogenous feature data from high_res_true with high_res_gen + for loss calculation + + Parameters + ---------- + high_res_true : tf.Tensor + Ground truth high resolution spatiotemporal data. + high_res_gen : tf.Tensor + Superresolved high resolution spatiotemporal data generated by the + generative model. 
+ + Returns + ------- + high_res_gen : tf.Tensor + Same as input with exogenous data combined with high_res input + """ + if high_res_true.shape[-1] > high_res_gen.shape[-1]: + exo_dict = self.get_high_res_exo_input(high_res_true) + exo_data = [exo_dict[feat] for feat in self.hr_exo_features] + high_res_gen = tf.concat((high_res_gen, *exo_data), axis=-1) + return high_res_gen + @staticmethod def get_loss_fun(loss): """Get the initialized loss function class from the sup3r loss library @@ -717,25 +740,42 @@ def finish_epoch( return stop + def _sum_parallel_grad(self, futures, start_time): + """Sum gradient descent future results""" + + # sum the gradients from each gpu to weight equally in + # optimizer momentum calculation + total_grad = None + for future in futures: + grad, loss_details = future.result() + if total_grad is None: + total_grad = grad + else: + for i, igrad in enumerate(grad): + total_grad[i] += igrad + + msg = ( + f'Finished {len(futures)} gradient descent steps on ' + f'{len(self.gpu_list)} GPUs in {time.time() - start_time:.4f} ' + 'seconds' + ) + logger.info(msg) + return total_grad, loss_details + def _get_parallel_grad( self, low_res, hi_res_true, training_weights, - obs_data=None, **calc_loss_kwargs, ): """Compute gradient for one mini-batch of (low_res, hi_res_true) across multiple GPUs""" futures = [] + start_time = time.time() lr_chunks = np.array_split(low_res, len(self.gpu_list)) hr_true_chunks = np.array_split(hi_res_true, len(self.gpu_list)) - obs_data_chunks = ( - [None] * len(hr_true_chunks) - if obs_data is None - else np.array_split(obs_data, len(self.gpu_list)) - ) split_mask = False mask_chunks = None if 'mask' in calc_loss_kwargs: @@ -754,38 +794,17 @@ def _get_parallel_grad( lr_chunks[i], hr_true_chunks[i], training_weights, - obs_data=obs_data_chunks[i], device_name=f'/gpu:{i}', **calc_loss_kwargs, ) ) - - # sum the gradients from each gpu to weight equally in - # optimizer momentum calculation - total_grad = None - for future 
in futures: - grad, loss_details = future.result() - if total_grad is None: - total_grad = grad - else: - for i, igrad in enumerate(grad): - total_grad[i] += igrad - - self.timer.stop() - logger.debug( - 'Finished %s gradient descent steps on %s GPUs in %s', - len(futures), - len(self.gpu_list), - self.timer.elapsed_str, - ) - return total_grad, loss_details + return self._sum_parallel_grad(futures, start_time=start_time) def run_gradient_descent( self, low_res, hi_res_true, training_weights, - obs_data=None, optimizer=None, multi_gpu=False, **calc_loss_kwargs, @@ -806,10 +825,6 @@ def run_gradient_descent( training_weights : list A list of layer weights that are to-be-trained based on the current loss weight values. - obs_data : tf.Tensor | None - Optional observation data to use in additional content loss term. - (n_observations, spatial_1, spatial_2, features) - (n_observations, spatial_1, spatial_2, temporal, features) optimizer : tf.keras.optimizers.Optimizer Optimizer class to use to update weights. 
This can be different if you're training just the generator or one of the discriminator @@ -829,32 +844,27 @@ def run_gradient_descent( loss_details : dict Namespace of the breakdown of loss components """ - - self.timer.start() if optimizer is None: optimizer = self.optimizer if not multi_gpu or len(self.gpu_list) < 2: + start_time = time.time() grad, loss_details = self.get_single_grad( low_res, hi_res_true, training_weights, - obs_data=obs_data, device_name=self.default_device, **calc_loss_kwargs, ) optimizer.apply_gradients(zip(grad, training_weights)) - self.timer.stop() - logger.debug( - 'Finished single gradient descent step in %s', - self.timer.elapsed_str, - ) + msg = ('Finished single gradient descent step in ' + f'{time.time() - start_time:.4f} seconds') + logger.debug(msg) else: total_grad, loss_details = self._get_parallel_grad( low_res, hi_res_true, training_weights, - obs_data, **calc_loss_kwargs, ) optimizer.apply_gradients(zip(total_grad, training_weights)) @@ -1050,13 +1060,24 @@ def _tf_generate(self, low_res, hi_res_exo=None): return hi_res + def _get_hr_exo_and_loss( + self, + low_res, + hi_res_true, + **calc_loss_kwargs, + ): + """Get high-resolution exogenous data, generate synthetic output, and + compute loss.""" + hi_res_exo = self.get_high_res_exo_input(hi_res_true) + hi_res_gen = self._tf_generate(low_res, hi_res_exo) + return self.calc_loss(hi_res_true, hi_res_gen, **calc_loss_kwargs) + @tf.function def get_single_grad( self, low_res, hi_res_true, training_weights, - obs_data=None, device_name=None, **calc_loss_kwargs, ): @@ -1076,10 +1097,6 @@ def get_single_grad( training_weights : list A list of layer weights that are to-be-trained based on the current loss weight values. - obs_data : tf.Tensor | None - Optional observation data to use in additional content loss term. 
- (n_observations, spatial_1, spatial_2, features) - (n_observations, spatial_1, spatial_2, temporal, features) device_name : None | str Optional tensorflow device name for GPU placement. Note that if a GPU is available, variables will be placed on that GPU even if @@ -1100,16 +1117,10 @@ def get_single_grad( watch_accessed_variables=False ) as tape: tape.watch(training_weights) - hi_res_exo = self.get_high_res_exo_input(hi_res_true) - hi_res_gen = self._tf_generate(low_res, hi_res_exo) - loss_out = self.calc_loss( - hi_res_true, hi_res_gen, **calc_loss_kwargs + loss_out = self._get_hr_exo_and_loss( + low_res, hi_res_true, **calc_loss_kwargs ) loss, loss_details = loss_out - if obs_data is not None: - loss_obs = self.calc_loss_obs(obs_data, hi_res_gen) - loss += loss_obs - loss_details['loss_obs'] = loss_obs grad = tape.gradient(loss, training_weights) return grad, loss_details @@ -1124,27 +1135,3 @@ def calc_loss( ): """Calculate the GAN loss function using generated and true high resolution data.""" - - @tf.function - def calc_loss_obs(self, obs_data, hi_res_gen): - """Calculate loss term for the observation data vs generated - high-resolution data - - Parameters - ---------- - obs_data : tf.Tensor | None - Optional observation data to use in additional content loss term. - hi_res_gen : tf.Tensor - Superresolved high resolution spatiotemporal data generated by the - generative model. 
- - Returns - ------- - loss : tf.Tensor - 0D tensor of observation loss - """ - mask = tf.math.is_nan(obs_data) - return MeanAbsoluteError()( - obs_data[~mask], - hi_res_gen[..., : len(self.hr_out_features)][~mask], - ) diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 53b159464..fdf15f031 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -750,6 +750,7 @@ def train( ) for epoch in epochs: + t_epoch = time.time() loss_details = self.train_epoch( batch_handler, weight_gen_advers, @@ -813,12 +814,18 @@ def train( early_stop_n_epoch, extras=extras, ) + logger.info( + 'Finished training epoch in {:.4f} seconds'.format( + time.time() - t_epoch + ) + ) if stop: break logger.info( - 'Finished training %s epochs in %s seconds', - n_epoch, - time.time() - t0, + 'Finished training {} epochs in {:.4f} seconds'.format( + n_epoch, + time.time() - t0, + ) ) batch_handler.stop() @@ -842,8 +849,6 @@ def calc_loss( hi_res_gen : tf.Tensor Superresolved high resolution spatiotemporal data generated by the generative model. - obs_data : tf.Tensor | None - Optional observation data to use in additional content loss term. weight_gen_advers : float Weight factor for the adversarial loss component of the generator vs. the discriminator. @@ -897,6 +902,37 @@ def calc_loss( return loss, loss_details + def _calc_val_loss(self, batch, weight_gen_advers, loss_details): + """Calculate the validation loss at the current state of model training + for a given batch + + Parameters + ---------- + batch : DsetTuple + Object with ``.high_res`` and ``.low_res`` arrays + weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. 
+ loss_details : dict + Namespace of the breakdown of loss components + + Returns + ------- + loss_details : dict + Same as input but now includes val_* loss info + """ + _, v_loss_details = self._get_hr_exo_and_loss( + batch.low_res, + batch.high_res, + weight_gen_advers=weight_gen_advers, + train_gen=False, + train_disc=False, + ) + loss_details = self.update_loss_details( + loss_details, v_loss_details, len(batch), prefix='val_' + ) + return loss_details + def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): """Calculate the validation loss at the current state of model training @@ -918,25 +954,93 @@ def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): logger.debug('Starting end-of-epoch validation loss calculation...') loss_details['n_obs'] = 0 for val_batch in batch_handler.val_data: - val_exo_data = self.get_high_res_exo_input(val_batch.high_res) - high_res_gen = self._tf_generate(val_batch.low_res, val_exo_data) - _, v_loss_details = self.calc_loss( - val_batch.high_res, - high_res_gen, + loss_details = self._calc_val_loss( + val_batch, weight_gen_advers, loss_details + ) + return loss_details + + def _get_batch_loss_details( + self, + batch, + train_gen, + only_gen, + gen_too_good, + train_disc, + only_disc, + disc_too_good, + weight_gen_advers, + multi_gpu=False, + ): + """Get loss details for a given batch for the current epoch. + + Parameters + ---------- + batch : sup3r.preprocessing.base.DsetTuple + Object with ``.low_res`` and ``.high_res`` arrays + train_gen : bool + Flag whether to train the generator for this set of epochs + only_gen : bool + Flag whether to only train the generator for this set of epochs + gen_too_good : bool + Flag whether to skip training the generator and only train the + discriminator, due to superior performance, for this batch. 
+ train_disc : bool + Flag whether to train the discriminator for this set of epochs + only_disc : bool + Flag whether to only train the discriminator for this set of epochs + disc_too_good : bool + Flag whether to skip training the discriminator and only train the + generator, due to superior performance, for this batch. + weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. + multi_gpu : bool + Flag to break up the batch for parallel gradient descent + calculations on multiple gpus. If True and multiple GPUs are + present, each batch from the batch_handler will be divided up + between the GPUs and resulting gradients from each GPU will be + summed and then applied once per batch at the nominal learning + rate that the model and optimizer were initialized with. + If true and multiple gpus are found, ``default_device`` device + should be set to /gpu:0 + + Returns + ------- + loss_details : dict + Namespace of the breakdown of loss components for the given batch + """ + + trained_gen = False + trained_disc = False + if only_gen or (train_gen and not gen_too_good): + trained_gen = True + b_loss_details = self.timer(self.run_gradient_descent)( + batch.low_res, + batch.high_res, + self.generator_weights, weight_gen_advers=weight_gen_advers, - train_gen=False, + optimizer=self.optimizer, + train_gen=True, train_disc=False, + multi_gpu=multi_gpu, ) - obs_data = getattr(val_batch, 'obs', None) - if obs_data is not None: - v_loss_details['loss_obs'] = self.cal_loss_obs( - obs_data, high_res_gen - ) - loss_details = self.update_loss_details( - loss_details, v_loss_details, len(val_batch), prefix='val_' + if only_disc or (train_disc and not disc_too_good): + trained_disc = True + b_loss_details = self.timer(self.run_gradient_descent)( + batch.low_res, + batch.high_res, + self.discriminator_weights, + weight_gen_advers=weight_gen_advers, + optimizer=self.optimizer_disc, + train_gen=False, + train_disc=True, + 
multi_gpu=multi_gpu, ) - return loss_details + + b_loss_details['gen_trained_frac'] = float(trained_gen) + b_loss_details['disc_trained_frac'] = float(trained_disc) + return b_loss_details def train_epoch( self, @@ -991,8 +1095,6 @@ def train_epoch( tf.summary.trace_on(graph=True, profiler=True) for ib, batch in enumerate(batch_handler): - trained_gen = False - trained_disc = False b_loss_details = {} loss_disc = loss_details['train_loss_disc'] disc_too_good = loss_disc <= disc_th_low @@ -1002,35 +1104,19 @@ def train_epoch( if not self.generator_weights: self.init_weights(batch.low_res.shape, batch.high_res.shape) - if only_gen or (train_gen and not gen_too_good): - trained_gen = True - b_loss_details = self.timer(self.run_gradient_descent)( - batch.low_res, - batch.high_res, - self.generator_weights, - obs_data=getattr(batch, 'obs', None), - weight_gen_advers=weight_gen_advers, - optimizer=self.optimizer, - train_gen=True, - train_disc=False, - multi_gpu=multi_gpu, - ) - - if only_disc or (train_disc and not disc_too_good): - trained_disc = True - b_loss_details = self.timer(self.run_gradient_descent)( - batch.low_res, - batch.high_res, - self.discriminator_weights, - weight_gen_advers=weight_gen_advers, - optimizer=self.optimizer_disc, - train_gen=False, - train_disc=True, - multi_gpu=multi_gpu, - ) - - b_loss_details['gen_trained_frac'] = float(trained_gen) - b_loss_details['disc_trained_frac'] = float(trained_disc) + b_loss_details = self._get_batch_loss_details( + batch, + train_gen, + only_gen, + gen_too_good, + train_disc, + only_disc, + disc_too_good, + weight_gen_advers, + multi_gpu, + ) + trained_gen = bool(b_loss_details.get('gen_trained_frac', False)) + trained_disc = bool(b_loss_details.get('disc_trained_frac', False)) self.dict_to_tensorboard(b_loss_details) self.dict_to_tensorboard(self.timer.log) diff --git a/sup3r/models/interface.py b/sup3r/models/interface.py index c81065f86..e284cc21a 100644 --- a/sup3r/models/interface.py +++ 
b/sup3r/models/interface.py @@ -9,7 +9,6 @@ from warnings import warn import numpy as np -import tensorflow as tf from phygnn import CustomNetwork from phygnn.layers.custom_layers import Sup3rAdder, Sup3rConcat @@ -355,32 +354,6 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None): hi_res = np.concatenate((hi_res, exo_output), axis=-1) return hi_res - @tf.function - def _combine_loss_input(self, high_res_true, high_res_gen): - """Combine exogenous feature data from high_res_true with high_res_gen - for loss calculation - - Parameters - ---------- - high_res_true : tf.Tensor - Ground truth high resolution spatiotemporal data. - high_res_gen : tf.Tensor - Superresolved high resolution spatiotemporal data generated by the - generative model. - - Returns - ------- - high_res_gen : tf.Tensor - Same as input with exogenous data combined with high_res input - """ - if high_res_true.shape[-1] > high_res_gen.shape[-1]: - for feature in self.hr_exo_features: - f_idx = self.hr_exo_features.index(feature) - f_idx += len(self.hr_out_features) - exo_data = high_res_true[..., f_idx : f_idx + 1] - high_res_gen = tf.concat((high_res_gen, exo_data), axis=-1) - return high_res_gen - @property @abstractmethod def meta(self): diff --git a/sup3r/models/with_obs.py b/sup3r/models/with_obs.py new file mode 100644 index 000000000..985cc22d7 --- /dev/null +++ b/sup3r/models/with_obs.py @@ -0,0 +1,354 @@ +"""Sup3r model with training on observation data.""" + +import logging +import time +from concurrent.futures import ThreadPoolExecutor + +import numpy as np +import tensorflow as tf +from tensorflow.keras.losses import MeanAbsoluteError + +from .base import Sup3rGan + +logger = logging.getLogger(__name__) + + +class Sup3rGanWithObs(Sup3rGan): + """Sup3r GAN model which incorporates observation data into content loss. 
+ """ + + def _calc_val_loss(self, batch, weight_gen_advers, loss_details): + """Calculate the validation loss at the current state of model training + for a given batch + + Parameters + ---------- + batch : DsetTuple + Object with ``.high_res``, ``.low_res``, and ``.obs`` arrays + weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. + loss_details : dict + Namespace of the breakdown of loss components + + Returns + ------- + loss_details : dict + Same as input with updated val_* loss info + """ + val_exo_data = self.get_high_res_exo_input(batch.high_res) + high_res_gen = self._tf_generate(batch.low_res, val_exo_data) + _, v_loss_details = self.calc_loss( + batch.high_res, + high_res_gen, + weight_gen_advers=weight_gen_advers, + train_gen=False, + train_disc=False, + ) + v_loss_details['loss_obs'] = self.cal_loss_obs(batch.obs, high_res_gen) + + loss_details = self.update_loss_details( + loss_details, v_loss_details, len(batch), prefix='val_' + ) + return loss_details + + def _get_batch_loss_details( + self, + batch, + train_gen, + only_gen, + gen_too_good, + train_disc, + only_disc, + disc_too_good, + weight_gen_advers, + multi_gpu=False, + ): + """Get loss details for a given batch for the current epoch. + + Parameters + ---------- + batch : sup3r.preprocessing.base.DsetTuple + Object with ``.low_res``, ``.high_res``, and ``.obs`` arrays + train_gen : bool + Flag whether to train the generator for this set of epochs + only_gen : bool + Flag whether to only train the generator for this set of epochs + gen_too_good : bool + Flag whether to skip training the generator and only train the + discriminator, due to superior performance, for this batch. 
+ train_disc : bool + Flag whether to train the discriminator for this set of epochs + only_disc : bool + Flag whether to only train the discriminator for this set of epochs + disc_too_good : bool + Flag whether to skip training the discriminator and only train the + generator, due to superior performance, for this batch. + weight_gen_advers : float + Weight factor for the adversarial loss component of the generator + vs. the discriminator. + multi_gpu : bool + Flag to break up the batch for parallel gradient descent + calculations on multiple gpus. If True and multiple GPUs are + present, each batch from the batch_handler will be divided up + between the GPUs and resulting gradients from each GPU will be + summed and then applied once per batch at the nominal learning + rate that the model and optimizer were initialized with. + If true and multiple gpus are found, ``default_device`` device + should be set to /gpu:0 + + Returns + ------- + loss_details : dict + Namespace of the breakdown of loss components for the given batch + """ + trained_gen = False + trained_disc = False + if only_gen or (train_gen and not gen_too_good): + trained_gen = True + b_loss_details = self.timer(self.run_gradient_descent)( + batch.low_res, + batch.high_res, + self.generator_weights, + obs_data=getattr(batch, 'obs', None), + weight_gen_advers=weight_gen_advers, + optimizer=self.optimizer, + train_gen=True, + train_disc=False, + multi_gpu=multi_gpu, + ) + + if only_disc or (train_disc and not disc_too_good): + trained_disc = True + b_loss_details = self.timer(self.run_gradient_descent)( + batch.low_res, + batch.high_res, + self.discriminator_weights, + weight_gen_advers=weight_gen_advers, + optimizer=self.optimizer_disc, + train_gen=False, + train_disc=True, + multi_gpu=multi_gpu, + ) + + b_loss_details['gen_trained_frac'] = float(trained_gen) + b_loss_details['disc_trained_frac'] = float(trained_disc) + return b_loss_details + + def _get_parallel_grad( + self, + low_res, + hi_res_true, 
+ training_weights, + obs_data=None, + **calc_loss_kwargs, + ): + """Compute gradient for one mini-batch of (low_res, hi_res_true, + obs_data) across multiple GPUs. Can include observation data as well. + """ + + futures = [] + start_time = time.time() + lr_chunks = np.array_split(low_res, len(self.gpu_list)) + hr_true_chunks = np.array_split(hi_res_true, len(self.gpu_list)) + obs_data_chunks = ( + [None] * len(hr_true_chunks) + if obs_data is None + else np.array_split(obs_data, len(self.gpu_list)) + ) + split_mask = False + mask_chunks = None + if 'mask' in calc_loss_kwargs: + split_mask = True + mask_chunks = np.array_split( + calc_loss_kwargs['mask'], len(self.gpu_list) + ) + + with ThreadPoolExecutor(max_workers=len(self.gpu_list)) as exe: + for i in range(len(self.gpu_list)): + if split_mask: + calc_loss_kwargs['mask'] = mask_chunks[i] + futures.append( + exe.submit( + self.get_single_grad, + lr_chunks[i], + hr_true_chunks[i], + training_weights, + obs_data=obs_data_chunks[i], + device_name=f'/gpu:{i}', + **calc_loss_kwargs, + ) + ) + + return self._sum_parallel_grad(futures, start_time=start_time) + + def run_gradient_descent( + self, + low_res, + hi_res_true, + training_weights, + obs_data=None, + optimizer=None, + multi_gpu=False, + **calc_loss_kwargs, + ): + """Run gradient descent for one mini-batch of (low_res, hi_res_true) + and update weights + + Parameters + ---------- + low_res : np.ndarray + Real low-resolution data in a 4D or 5D array: + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) + hi_res_true : np.ndarray + Real high-resolution data in a 4D or 5D array: + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) + training_weights : list + A list of layer weights that are to-be-trained based on the + current loss weight values. + obs_data : tf.Tensor | None + Optional observation data to use in additional content loss term. 
+ This needs to have NaNs where there is no observation data. + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) + optimizer : tf.keras.optimizers.Optimizer + Optimizer class to use to update weights. This can be different if + you're training just the generator or one of the discriminator + models. Defaults to the generator optimizer. + multi_gpu : bool + Flag to break up the batch for parallel gradient descent + calculations on multiple gpus. If True and multiple GPUs are + present, each batch from the batch_handler will be divided up + between the GPUs and resulting gradients from each GPU will be + summed and then applied once per batch at the nominal learning + rate that the model and optimizer were initialized with. + calc_loss_kwargs : dict + Kwargs to pass to the self.calc_loss() method + + Returns + ------- + loss_details : dict + Namespace of the breakdown of loss components + """ + + self.timer.start() + if optimizer is None: + optimizer = self.optimizer + + if not multi_gpu or len(self.gpu_list) < 2: + grad, loss_details = self.get_single_grad( + low_res, + hi_res_true, + training_weights, + obs_data=obs_data, + device_name=self.default_device, + **calc_loss_kwargs, + ) + optimizer.apply_gradients(zip(grad, training_weights)) + self.timer.stop() + logger.debug( + 'Finished single gradient descent step in %s', + self.timer.elapsed_str, + ) + else: + total_grad, loss_details = self._get_parallel_grad( + low_res, + hi_res_true, + training_weights, + obs_data, + **calc_loss_kwargs, + ) + optimizer.apply_gradients(zip(total_grad, training_weights)) + + return loss_details + + @tf.function + def get_single_grad( + self, + low_res, + hi_res_true, + training_weights, + obs_data=None, + device_name=None, + **calc_loss_kwargs, + ): + """Run gradient descent for one mini-batch of (low_res, hi_res_true), + do not update weights, just return gradient details. 
+ + Parameters + ---------- + low_res : np.ndarray + Real low-resolution data in a 4D or 5D array: + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) + hi_res_true : np.ndarray + Real high-resolution data in a 4D or 5D array: + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) + training_weights : list + A list of layer weights that are to-be-trained based on the + current loss weight values. + obs_data : tf.Tensor | None + Optional observation data to use in additional content loss term. + This needs to have NaNs where there is no observation data. + (n_observations, spatial_1, spatial_2, features) + (n_observations, spatial_1, spatial_2, temporal, features) + device_name : None | str + Optional tensorflow device name for GPU placement. Note that if a + GPU is available, variables will be placed on that GPU even if + device_name=None. + calc_loss_kwargs : dict + Kwargs to pass to the self.calc_loss() method + + Returns + ------- + grad : list + a list or nested structure of Tensors (or IndexedSlices, or None, + or CompositeTensor) representing the gradients for the + training_weights + loss_details : dict + Namespace of the breakdown of loss components + """ + with tf.device(device_name), tf.GradientTape( + watch_accessed_variables=False + ) as tape: + tape.watch(training_weights) + hi_res_exo = self.get_high_res_exo_input(hi_res_true) + hi_res_gen = self._tf_generate(low_res, hi_res_exo) + loss_out = self.calc_loss( + hi_res_true, hi_res_gen, **calc_loss_kwargs + ) + loss, loss_details = loss_out + if obs_data is not None: + loss_obs = self.calc_loss_obs(obs_data, hi_res_gen) + loss += loss_obs + loss_details['loss_obs'] = loss_obs + grad = tape.gradient(loss, training_weights) + return grad, loss_details + + @tf.function + def calc_loss_obs(self, obs_data, hi_res_gen): + """Calculate loss term for the observation data vs generated + 
high-resolution data + + Parameters + ---------- + obs_data : tf.Tensor | None + Observation data to use in additional content loss term. + This needs to have NaNs where there is no observation data. + hi_res_gen : tf.Tensor + Superresolved high resolution spatiotemporal data generated by the + generative model. + + Returns + ------- + loss : tf.Tensor + 0D tensor of observation loss + """ + mask = tf.math.is_nan(obs_data) + return MeanAbsoluteError()( + obs_data[~mask], + hi_res_gen[..., : len(self.hr_out_features)][~mask], + ) diff --git a/tests/training/test_train_dual_with_obs.py b/tests/training/test_train_dual_with_obs.py index 0bf8244ee..2399b21ba 100644 --- a/tests/training/test_train_dual_with_obs.py +++ b/tests/training/test_train_dual_with_obs.py @@ -7,7 +7,7 @@ import numpy as np import pytest -from sup3r.models import Sup3rGan +from sup3r.models import Sup3rGanWithObs from sup3r.preprocessing import ( Container, DataHandler, @@ -104,8 +104,8 @@ def test_train_h5_nc( assert not np.isnan(batch.obs).all() assert np.isnan(batch.obs).any() - Sup3rGan.seed() - model = Sup3rGan( + Sup3rGanWithObs.seed() + model = Sup3rGanWithObs( fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError' ) @@ -208,8 +208,8 @@ def test_train_coarse_h5( assert not np.isnan(batch.obs).all() assert np.isnan(batch.obs).any() - Sup3rGan.seed() - model = Sup3rGan( + Sup3rGanWithObs.seed() + model = Sup3rGanWithObs( fp_gen, fp_disc, learning_rate=lr, loss='MeanAbsoluteError' )