Skip to content

Commit

Permalink
Sup3rGanWithObs model subclass. Other misc model refactoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Jan 3, 2025
1 parent 7897998 commit 2b9c0d0
Show file tree
Hide file tree
Showing 6 changed files with 562 additions and 161 deletions.
1 change: 1 addition & 0 deletions sup3r/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .multi_step import MultiStepGan, MultiStepSurfaceMetGan, SolarMultiStepGan
from .solar_cc import SolarCC
from .surface import SurfaceSpatialMetModel
from .with_obs import Sup3rGanWithObs

SPATIAL_FIRST_MODELS = (MultiStepSurfaceMetGan,
SolarMultiStepGan)
145 changes: 66 additions & 79 deletions sup3r/models/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from phygnn.layers.custom_layers import Sup3rAdder, Sup3rConcat
from rex.utilities.utilities import safe_json_load
from tensorflow.keras import optimizers
from tensorflow.keras.losses import MeanAbsoluteError

import sup3r.utilities.loss_metrics
from sup3r.preprocessing.data_handlers import ExoData
Expand Down Expand Up @@ -420,6 +419,30 @@ def get_high_res_exo_input(self, high_res):
exo_data[feature] = exo_fdata
return exo_data

@tf.function
def _combine_loss_input(self, high_res_true, high_res_gen):
"""Combine exogenous feature data from high_res_true with high_res_gen
for loss calculation
Parameters
----------
high_res_true : tf.Tensor
Ground truth high resolution spatiotemporal data.
high_res_gen : tf.Tensor
Superresolved high resolution spatiotemporal data generated by the
generative model.
Returns
-------
high_res_gen : tf.Tensor
Same as input with exogenous data combined with high_res input
"""
if high_res_true.shape[-1] > high_res_gen.shape[-1]:
exo_dict = self.get_high_res_exo_input(high_res_true)
exo_data = [exo_dict[feat] for feat in self.hr_exo_features]
high_res_gen = tf.concat((high_res_gen, *exo_data), axis=-1)
return high_res_gen

@staticmethod
def get_loss_fun(loss):
"""Get the initialized loss function class from the sup3r loss library
Expand Down Expand Up @@ -717,25 +740,42 @@ def finish_epoch(

return stop

def _sum_parallel_grad(self, futures, start_time):
"""Sum gradient descent future results"""

# sum the gradients from each gpu to weight equally in
# optimizer momentum calculation
total_grad = None
for future in futures:
grad, loss_details = future.result()
if total_grad is None:
total_grad = grad
else:
for i, igrad in enumerate(grad):
total_grad[i] += igrad

msg = (
f'Finished {len(futures)} gradient descent steps on '
f'{len(self.gpu_list)} GPUs in {time.time() - start_time:.4f} '
'seconds'
)
logger.info(msg)
return total_grad, loss_details

def _get_parallel_grad(
self,
low_res,
hi_res_true,
training_weights,
obs_data=None,
**calc_loss_kwargs,
):
"""Compute gradient for one mini-batch of (low_res, hi_res_true)
across multiple GPUs"""

futures = []
start_time = time.time()
lr_chunks = np.array_split(low_res, len(self.gpu_list))
hr_true_chunks = np.array_split(hi_res_true, len(self.gpu_list))
obs_data_chunks = (
[None] * len(hr_true_chunks)
if obs_data is None
else np.array_split(obs_data, len(self.gpu_list))
)
split_mask = False
mask_chunks = None
if 'mask' in calc_loss_kwargs:
Expand All @@ -754,38 +794,17 @@ def _get_parallel_grad(
lr_chunks[i],
hr_true_chunks[i],
training_weights,
obs_data=obs_data_chunks[i],
device_name=f'/gpu:{i}',
**calc_loss_kwargs,
)
)

# sum the gradients from each gpu to weight equally in
# optimizer momentum calculation
total_grad = None
for future in futures:
grad, loss_details = future.result()
if total_grad is None:
total_grad = grad
else:
for i, igrad in enumerate(grad):
total_grad[i] += igrad

self.timer.stop()
logger.debug(
'Finished %s gradient descent steps on %s GPUs in %s',
len(futures),
len(self.gpu_list),
self.timer.elapsed_str,
)
return total_grad, loss_details
return self._sum_parallel_grad(futures, start_time=start_time)

def run_gradient_descent(
self,
low_res,
hi_res_true,
training_weights,
obs_data=None,
optimizer=None,
multi_gpu=False,
**calc_loss_kwargs,
Expand All @@ -806,10 +825,6 @@ def run_gradient_descent(
training_weights : list
A list of layer weights that are to-be-trained based on the
current loss weight values.
obs_data : tf.Tensor | None
Optional observation data to use in additional content loss term.
(n_observations, spatial_1, spatial_2, features)
(n_observations, spatial_1, spatial_2, temporal, features)
optimizer : tf.keras.optimizers.Optimizer
Optimizer class to use to update weights. This can be different if
you're training just the generator or one of the discriminator
Expand All @@ -829,32 +844,27 @@ def run_gradient_descent(
loss_details : dict
Namespace of the breakdown of loss components
"""

self.timer.start()
if optimizer is None:
optimizer = self.optimizer

if not multi_gpu or len(self.gpu_list) < 2:
start_time = time.time()
grad, loss_details = self.get_single_grad(
low_res,
hi_res_true,
training_weights,
obs_data=obs_data,
device_name=self.default_device,
**calc_loss_kwargs,
)
optimizer.apply_gradients(zip(grad, training_weights))
self.timer.stop()
logger.debug(
'Finished single gradient descent step in %s',
self.timer.elapsed_str,
)
msg = ('Finished single gradient descent step in '
f'{time.time() - start_time:.4f} seconds')
logger.debug(msg)
else:
total_grad, loss_details = self._get_parallel_grad(
low_res,
hi_res_true,
training_weights,
obs_data,
**calc_loss_kwargs,
)
optimizer.apply_gradients(zip(total_grad, training_weights))
Expand Down Expand Up @@ -1050,13 +1060,24 @@ def _tf_generate(self, low_res, hi_res_exo=None):

return hi_res

def _get_hr_exo_and_loss(
self,
low_res,
hi_res_true,
**calc_loss_kwargs,
):
"""Get high-resolution exogenous data, generate synthetic output, and
compute loss."""
hi_res_exo = self.get_high_res_exo_input(hi_res_true)
hi_res_gen = self._tf_generate(low_res, hi_res_exo)
return self.calc_loss(hi_res_true, hi_res_gen, **calc_loss_kwargs)

@tf.function
def get_single_grad(
self,
low_res,
hi_res_true,
training_weights,
obs_data=None,
device_name=None,
**calc_loss_kwargs,
):
Expand All @@ -1076,10 +1097,6 @@ def get_single_grad(
training_weights : list
A list of layer weights that are to-be-trained based on the
current loss weight values.
obs_data : tf.Tensor | None
Optional observation data to use in additional content loss term.
(n_observations, spatial_1, spatial_2, features)
(n_observations, spatial_1, spatial_2, temporal, features)
device_name : None | str
Optional tensorflow device name for GPU placement. Note that if a
GPU is available, variables will be placed on that GPU even if
Expand All @@ -1100,16 +1117,10 @@ def get_single_grad(
watch_accessed_variables=False
) as tape:
tape.watch(training_weights)
hi_res_exo = self.get_high_res_exo_input(hi_res_true)
hi_res_gen = self._tf_generate(low_res, hi_res_exo)
loss_out = self.calc_loss(
hi_res_true, hi_res_gen, **calc_loss_kwargs
loss_out = self._get_hr_exo_and_loss(
low_res, hi_res_true, **calc_loss_kwargs
)
loss, loss_details = loss_out
if obs_data is not None:
loss_obs = self.calc_loss_obs(obs_data, hi_res_gen)
loss += loss_obs
loss_details['loss_obs'] = loss_obs
grad = tape.gradient(loss, training_weights)
return grad, loss_details

Expand All @@ -1124,27 +1135,3 @@ def calc_loss(
):
"""Calculate the GAN loss function using generated and true high
resolution data."""

@tf.function
def calc_loss_obs(self, obs_data, hi_res_gen):
"""Calculate loss term for the observation data vs generated
high-resolution data
Parameters
----------
obs_data : tf.Tensor | None
Optional observation data to use in additional content loss term.
hi_res_gen : tf.Tensor
Superresolved high resolution spatiotemporal data generated by the
generative model.
Returns
-------
loss : tf.Tensor
0D tensor of observation loss
"""
mask = tf.math.is_nan(obs_data)
return MeanAbsoluteError()(
obs_data[~mask],
hi_res_gen[..., : len(self.hr_out_features)][~mask],
)
Loading

0 comments on commit 2b9c0d0

Please sign in to comment.