This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit 8e055b0

Merge pull request #198 from inferno-pytorch/apex

Mixed precision training with apex

constantinpape authored Feb 15, 2020
2 parents a75888e + 4c732db

Showing 13 changed files with 141 additions and 25 deletions.
6 changes: 3 additions & 3 deletions inferno/extensions/containers/graph.py
@@ -130,7 +130,7 @@ def is_node_in_graph(self, name):
         -------
         bool
         """
-        return name in self.graph.node
+        return name in self.graph.nodes

     def is_source_node(self, name):
         """
@@ -187,7 +187,7 @@ def output_nodes(self):
         list
             A list of names (str) of the output nodes.
         """
-        return [name for name, node_attributes in self.graph.node.items()
+        return [name for name, node_attributes in self.graph.nodes.items()
                 if node_attributes.get('is_output_node', False)]

     @property
@@ -201,7 +201,7 @@ def input_nodes(self):
         list
             A list of names (str) of the input nodes.
         """
-        return [name for name, node_attributes in self.graph.node.items()
+        return [name for name, node_attributes in self.graph.nodes.items()
                 if node_attributes.get('is_input_node', False)]

     @property
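The `.node` to `.nodes` rename tracks networkx, which deprecated `Graph.node` in the 2.x series and removed it in networkx 2.4. A minimal sketch (not from this repository) of the replacement API as it is used above; the node names are made up:

import networkx as nx

g = nx.DiGraph()
g.add_node('in0', is_input_node=True)
g.add_node('out0', is_output_node=True)

# membership test, as in is_node_in_graph
assert 'in0' in g.nodes

# (name, attribute-dict) pairs, as in input_nodes / output_nodes;
# the nodes view implements the Mapping interface, so .items() works
input_nodes = [name for name, attrs in g.nodes.items()
               if attrs.get('is_input_node', False)]
assert input_nodes == ['in0']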
3 changes: 2 additions & 1 deletion inferno/extensions/optimizers/__init__.py
@@ -1,2 +1,3 @@
 from .adam import Adam
-from .annealed_adam import AnnealedAdam
+from .annealed_adam import AnnealedAdam
+from .ranger import Ranger, RangerQH, RangerVA
8 changes: 8 additions & 0 deletions inferno/extensions/optimizers/ranger.py
@@ -0,0 +1,8 @@
+# easy support for additional ranger optimizers from
+# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
+try:
+    from ranger import Ranger, RangerVA, RangerQH
+except ImportError:
+    Ranger = None
+    RangerVA = None
+    RangerQH = None
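With this guard, importing inferno.extensions.optimizers succeeds even when the optional ranger package is absent; the three names are simply bound to None. A hedged usage sketch, assuming the standard `Ranger(params, lr=...)` constructor from the linked repository:

import torch.nn as nn
from inferno.extensions.optimizers import Ranger

model = nn.Linear(10, 2)
if Ranger is None:
    # optional dependency not installed
    raise RuntimeError("install ranger from "
                       "https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer")
optimizer = Ranger(model.parameters(), lr=1e-3)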
14 changes: 12 additions & 2 deletions inferno/io/transform/base.py
@@ -58,7 +58,7 @@ def __call__(self, *tensors, **transform_function_kwargs):
             transformed = self.batch_function(tensors, **transform_function_kwargs)
             return pyu.from_iterable(transformed)
         elif hasattr(self, 'tensor_function'):
-            transformed = [self.tensor_function(tensor, **transform_function_kwargs)
+            transformed = [self._apply_tensor_function(tensor, **transform_function_kwargs)
                            if tensor_index in apply_to else tensor
                            for tensor_index, tensor in enumerate(tensors)]
             return pyu.from_iterable(transformed)
@@ -77,9 +77,17 @@ def __call__(self, *tensors, **transform_function_kwargs):
         else:
             raise NotImplementedError

+    # noinspection PyUnresolvedReferences
+    def _apply_tensor_function(self, tensor, **transform_function_kwargs):
+        if isinstance(tensor, list):
+            return [self._apply_tensor_function(tens) for tens in tensor]
+        return self.tensor_function(tensor)
+
     # noinspection PyUnresolvedReferences
     def _apply_image_function(self, tensor, **transform_function_kwargs):
         assert pyu.has_callable_attr(self, 'image_function')
+        if isinstance(tensor, list):
+            return [self._apply_image_function(tens) for tens in tensor]
         # 2D case
         if tensor.ndim == 4:
             return np.array([np.array([self.image_function(image, **transform_function_kwargs)
@@ -106,6 +114,8 @@ def _apply_image_function(self, tensor, **transform_function_kwargs):
     # noinspection PyUnresolvedReferences
     def _apply_volume_function(self, tensor, **transform_function_kwargs):
         assert pyu.has_callable_attr(self, 'volume_function')
+        if isinstance(tensor, list):
+            return [self._apply_volume_function(tens) for tens in tensor]
         # 3D case
         if tensor.ndim == 5:
             # tensor is bczyx
@@ -125,7 +135,7 @@ def _apply_volume_function(self, tensor, **transform_function_kwargs):
             # We're applying the volume function on the volume itself
             return self.volume_function(tensor, **transform_function_kwargs)
         else:
-            raise NotImplementedError
+            raise NotImplementedError("Volume function not implemented for ndim %i" % tensor.ndim)


 class Compose(object):
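The new isinstance(tensor, list) branches make all three apply helpers recurse into list-valued samples, so a transform written for single arrays now also handles a list of arrays per tensor slot. A minimal sketch of what this enables; the Negate transform is made up for illustration:

import numpy as np
from inferno.io.transform.base import Transform

class Negate(Transform):
    def tensor_function(self, tensor):
        return -tensor

transform = Negate()
single = transform(np.ones((4, 4)))                    # plain array, as before
pair = transform([np.ones((4, 4)), np.zeros((4, 4))])  # list input now recurses elementwise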
4 changes: 2 additions & 2 deletions inferno/io/transform/image.py
@@ -596,5 +596,5 @@ def batch_function(self, image):
         pad_r = image_shape - new_shape - pad_l
         padding = [(0,0)] + list(zip(pad_l, pad_r))
         img = np.pad(img, padding, 'constant', constant_values=self.pad_const)
-        seg = np.pad(seg, padding, 'constant', constant_values=self.pad_const)
-        return img, seg
+        seg = np.pad(seg, padding, 'constant', constant_values=self.pad_const)
+        return img, seg
3 changes: 2 additions & 1 deletion inferno/io/transform/volume.py
@@ -63,6 +63,7 @@ def volume_function(self, volume):
         return volume


+# TODO this is obsolete
 class AdditiveRandomNoise3D(Transform):
     """ Add gaussian noise to 3d volume
@@ -105,7 +106,7 @@ def __init__(self, sigma, mode='gaussian', **super_kwargs):
         self.sigma = sigma

     # TODO check if volume is tensor and use torch functions in that case
-    def volume_function(self, volume):
+    def tensor_function(self, volume):
         volume += np.random.normal(loc=0, scale=self.sigma, size=volume.shape)
         return volume
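Switching the noise transform from volume_function to tensor_function routes it through the dimension-agnostic tensor path, so it composes with the new list handling in base.py and no longer assumes 3D input. A hedged sketch, assuming this hunk belongs to the module's AdditiveNoise transform (the class name sits outside the visible diff context):

import numpy as np
from inferno.io.transform.volume import AdditiveNoise

noise = AdditiveNoise(sigma=0.1)
noisy = noise(np.zeros((8, 64, 64)))  # any ndim works via tensor_function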
52 changes: 49 additions & 3 deletions inferno/io/volumetric/lazy_volume_loader.py
@@ -1,5 +1,7 @@
 import numpy as np
 import os
+import pickle
+from concurrent import futures

 # try to load io libraries (h5py and z5py)
 try:
@@ -20,10 +22,39 @@
 from ...utils import python_utils as pyu


+# TODO support h5py as well
+def filter_base_sequence(input_path, input_key,
+                         window_size, stride,
+                         filter_function, n_threads):
+    with z5py.File(input_path, 'r') as f:
+        ds = f[input_key]
+        shape = list(ds.shape)
+        sequence = vu.slidingwindowslices(shape=shape,
+                                          window_size=window_size,
+                                          strides=stride,
+                                          shuffle=True,
+                                          add_overhanging=True)
+
+        def check_slice(slice_id, slice_):
+            print("Checking slice_id", slice_id)
+            data = ds[slice_]
+            if filter_function(data):
+                return None
+            else:
+                return slice_
+
+        with futures.ThreadPoolExecutor(n_threads) as tp:
+            tasks = [tp.submit(check_slice, slice_id, slice_) for slice_id, slice_ in enumerate(sequence)]
+            filtered_sequence = [t.result() for t in tasks]
+
+    filtered_sequence = [seq for seq in filtered_sequence if seq is not None]
+    return filtered_sequence
+
+
 class LazyVolumeLoaderBase(SyncableDataset):
     def __init__(self, dataset, window_size, stride, downsampling_ratio=None, padding=None,
                  padding_mode='reflect', transforms=None, return_index_spec=False, name=None,
-                 data_slice=None):
+                 data_slice=None, base_sequence=None):
         super(LazyVolumeLoaderBase, self).__init__()
         assert len(window_size) == dataset.ndim, "%i, %i" % (len(window_size), dataset.ndim)
         assert len(stride) == dataset.ndim
@@ -58,7 +89,22 @@ def __init__(self, dataset, window_size, stride, downsampling_ratio=None, paddin
         else:
             raise NotImplementedError

-        self.base_sequence = self.make_sliding_windows()
+        if base_sequence is None:
+            self.base_sequence = self.make_sliding_windows()
+        else:
+            self.base_sequence = self.load_base_sequence(base_sequence)
+
+    @staticmethod
+    def load_base_sequence(base_sequence):
+        if isinstance(base_sequence, (list, tuple)):
+            return base_sequence
+        elif isinstance(base_sequence, str):
+            assert os.path.exists(base_sequence)
+            with open(base_sequence, 'rb') as f:
+                base_sequence = pickle.load(f)
+            return base_sequence
+        else:
+            raise ValueError("Unsupported base_sequence format, must be either listlike or str")

     def normalize_slice(self, data_slice):
         if data_slice is None:
@@ -185,7 +231,7 @@ def __init__(self, file_impl, path,
             assert os.path.exists(path), path
             self.path = path
         else:
-            raise NotImplementedError
+            raise NotImplementedError("Not implemented for type %s" % type(path))

         if isinstance(path_in_file, dict):
             assert name is not None
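A hypothetical end-to-end use of the new pieces: run the filtering pass once, pickle the surviving window slices, and hand the pickle path to the loader through the new base_sequence argument. The container name, dataset key, and filter below are made up, and per the TODO only z5py-readable containers are supported so far:

import pickle
import numpy as np
from inferno.io.volumetric.lazy_volume_loader import filter_base_sequence

def reject_empty(data):
    # filter_function returns a truthy value for windows that should be dropped
    return np.count_nonzero(data) == 0

sequence = filter_base_sequence('volume.n5', 'raw',
                                window_size=[32, 256, 256],
                                stride=[16, 128, 128],
                                filter_function=reject_empty,
                                n_threads=8)
with open('base_sequence.pkl', 'wb') as f:
    pickle.dump(sequence, f)

# later, construct the loader without repeating the filtering pass:
# LazyVolumeLoaderBase(dataset, window_size, stride, base_sequence='base_sequence.pkl')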
3 changes: 2 additions & 1 deletion inferno/io/volumetric/volume.py
@@ -100,6 +100,7 @@ def pad_volume(self, padding=None):
         assert_(all(isinstance(pad, (int, tuple, list)) for pad in self.padding),\
                 "Expect int or iterable", TypeError)
         self.padding = [[pad, pad] if isinstance(pad, int) else pad for pad in self.padding]
+        print(self.volume.shape)
         self.volume = np.pad(self.volume,
                              pad_width=self.padding,
                              mode=self.padding_mode)
@@ -228,7 +229,7 @@ def __init__(self, path, path_in_h5_dataset=None, data_slice=None, transforms=No
         if self.data_slice is not None and slicing_config_for_name.get('is_multichannel', False):
             self.data_slice = (slice(None),) + self.data_slice

-        assert 'window_size' in slicing_config_for_name
+        assert 'window_size' in slicing_config_for_name, str(slicing_config_for_name)
         assert 'stride' in slicing_config_for_name

         # Read in volume from file (can be hdf5, n5 or zarr)
11 changes: 6 additions & 5 deletions inferno/io/volumetric/volumetric_utils.py
@@ -42,12 +42,13 @@ def dimension_window(start, stop, wsize, stride, dimsize, ds_dim):
     # otherwise predict the whole volume
     if dataslice is not None:
         assert len(dataslice) == dim, "Dataslice must be a tuple with len = data dimension."
-        starts = [sl.start for sl in dataslice]
-        stops = [sl.stop - wsize for sl, wsize in zip(dataslice, window_size)]
+        starts = [0 if sl.start is None else sl.start for sl in dataslice]
+        stops = [sh - wsize if sl.stop is None else sl.stop - wsize
+                 for sl, wsize, sh in zip(dataslice, window_size, shape)]
     else:
         starts = dim * [0]
-        stops = [dimsize - wsize if wsize != dimsize else dimsize
-                 for dimsize, wsize in zip(shape, window_size)]
+        stops = [dimsize - wsize if wsize != dimsize else dimsize
+                 for dimsize, wsize in zip(shape, window_size)]

     assert all(stp > strt for strt, stp in zip(starts, stops)),\
         "%s, %s" % (str(starts), str(stops))
@@ -128,7 +129,7 @@ def _to_list(x):
     nslices = [_1Dwindow(startmin, startmax, nhoodsiz, st, dsample, datalen, shuffle) if windowspec == 'x'
                else [slice(ws, ws + 1) for ws in _to_list(windowspec)]
                for startmin, startmax, datalen, nhoodsiz, st, windowspec, dsample in zip(startmins, startmaxs, shape,
-                                                                                         nhoodsize, stride, window, ds)]
+                                                                                         nhoodsize, stride, window, ds)]

     return it.product(*nslices)
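The None handling covers open-ended slices in dataslice: slice(None) has start and stop of None, which previously flowed into the start/stop arithmetic and failed; they now default to 0 and shape - window_size. A sketch of a call that exercises this, assuming dataslice is an optional keyword of slidingwindowslices as the code above suggests:

from inferno.io.volumetric.volumetric_utils import slidingwindowslices

# open-ended along the first axis, bounded along the second
slices = list(slidingwindowslices(shape=[100, 100],
                                  window_size=[10, 10],
                                  strides=[10, 10],
                                  dataslice=(slice(None), slice(20, 60))))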
54 changes: 50 additions & 4 deletions inferno/trainers/basic.py
@@ -27,6 +27,14 @@
 from .callbacks import Console
 from ..utils.exceptions import assert_, NotSetError, NotTorchModuleError, DeviceError

+# NOTE for distributed training, we might also need
+# from apex.parallel import DistributedDataParallel as DDP
+# but I don't know where exactly to put it.
+try:
+    from apex import amp
+except ImportError:
+    amp = None
+

 class Trainer(object):
     """A basic trainer.
@@ -126,10 +134,44 @@ def __init__(self, model=None):
         # Print console
         self._console = Console()

+        # Train with mixed precision, only works
+        # if we have apex
+        self._mixed_precision = False
+        self._apex_opt_level = 'O1'
+
         # Public
         if model is not None:
             self.model = model

+    @property
+    def mixed_precision(self):
+        return self._mixed_precision
+
+    # this needs to be called after model and optimizer are set
+    @mixed_precision.setter
+    def mixed_precision(self, mp):
+        if mp:
+            assert_(amp is not None, "Cannot use mixed precision training without apex library", RuntimeError)
+            assert_(self.model is not None and self._optimizer is not None,
+                    "Model and optimizer need to be set before activating mixed precision", RuntimeError)
+            # in order to support BCE loss
+            amp.register_float_function(torch, 'sigmoid')
+            # For now, we don't allow to set 'keep_batchnorm' and 'loss_scale'
+            self.model, self._optimizer = amp.initialize(self.model, self._optimizer,
+                                                         opt_level=self._apex_opt_level,
+                                                         keep_batchnorm_fp32=None)
+        self._mixed_precision = mp
+
+    @property
+    def apex_opt_level(self):
+        return self._apex_opt_level
+
+    @apex_opt_level.setter
+    def apex_opt_level(self, opt_level):
+        assert_(opt_level in ('O0', 'O1', 'O2', 'O3'),
+                "Invalid optimization level", ValueError)
+        self._apex_opt_level = opt_level
+
     @property
     def console(self):
         """Get the current console."""
@@ -1368,17 +1410,21 @@ def apply_model_and_loss(self, inputs, target, backward=True, mode=None):
             kwargs['trainer'] = self
         if mode == 'train':
             loss = self.criterion(prediction, target, **kwargs) \
-                if len(target) != 0 else self.criterion(prediction, **kwargs)
+                if len(target) != 0 else self.criterion(prediction, **kwargs)
         elif mode == 'eval':
             loss = self.validation_criterion(prediction, target, **kwargs) \
-                if len(target) != 0 else self.validation_criterion(prediction, **kwargs)
+                if len(target) != 0 else self.validation_criterion(prediction, **kwargs)
         else:
             raise ValueError
         if backward:
             # Backprop if required
             # retain_graph option is needed for some custom
             # loss functions like malis, False per default
-            loss.backward(retain_graph=self.retain_graph)
+            if self.mixed_precision:
+                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                    scaled_loss.backward(retain_graph=self.retain_graph)
+            else:
+                loss.backward(retain_graph=self.retain_graph)
         return prediction, loss

     def train_for(self, num_iterations=None, break_callback=None):
@@ -1676,7 +1722,7 @@ def load(self, from_directory=None, best=False, filename=None, map_location=None
             'best_checkpoint.pytorch'.
         filename : str
             Overrides the default filename.
-        device : function, torch.device, string or a dict
+        map_location : function, torch.device, string or a dict
             Specify how to remap storage locations.

         Returns
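A hedged usage sketch for the new switch, using only the attributes introduced above; NVIDIA apex must be installed, and the model and build_* calls merely stand in for a real setup. As the setter comment notes, order matters: mixed_precision may only be enabled after model and optimizer are set.

import torch.nn as nn
from inferno.trainers.basic import Trainer

model = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 1, 3, padding=1))
trainer = Trainer(model)
trainer.build_criterion('BCEWithLogitsLoss')
trainer.build_optimizer('Adam', lr=1e-4)

trainer.apex_opt_level = 'O2'   # one of 'O0'..'O3'; default is 'O1'
trainer.mixed_precision = True  # wraps model and optimizer via amp.initialize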
3 changes: 2 additions & 1 deletion inferno/trainers/callbacks/essentials.py
@@ -277,9 +277,10 @@ def norm_or_value(self):
     def after_model_and_loss_is_applied(self, **_):
         tu.clip_gradients_(self.trainer.model.parameters(), self.mode, self.norm_or_value)

+
 class GarbageCollection(Callback):
     """
-    Callback that triggers garbage collection at the end of every
+    Callback that triggers garbage collection at the end of every
     training iteration in order to reduce the memory footprint of training
     """
3 changes: 2 additions & 1 deletion inferno/trainers/callbacks/scheduling.py
@@ -301,9 +301,10 @@ def end_of_validation_run(self, **_):

     @staticmethod
     def is_significantly_less_than(x, y, min_relative_delta):
+        eps = 1.e-6
         if x > y:
             return False
-        relative_delta = abs(y - x) / abs(y)
+        relative_delta = abs(y - x) / (abs(y) + eps)
         return relative_delta > min_relative_delta
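A worked instance of the guard: with y == 0 the old expression abs(y - x) / abs(y) raised ZeroDivisionError, while ordinary magnitudes are left effectively unchanged by the epsilon.

x, y, eps = -0.5, 0.0, 1.e-6
relative_delta = abs(y - x) / (abs(y) + eps)  # 500000.0, finite instead of an exception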
2 changes: 1 addition & 1 deletion inferno/utils/io_utils.py
@@ -49,7 +49,7 @@ def yaml2dict(path):
         # Forgivable mistake that path is a dict already
         return path
     with open(path, 'r') as f:
-        readict = yaml.load(f)
+        readict = yaml.load(f, Loader=yaml.FullLoader)
     return readict
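Background for the loader change: since PyYAML 5.1, calling yaml.load without an explicit Loader emits a warning because the old default could construct arbitrary Python objects from a malicious file. FullLoader keeps full YAML semantics while refusing arbitrary object construction. A small sketch:

import yaml

config = yaml.load("lr: 0.0001\ndevices: [0, 1]", Loader=yaml.FullLoader)
assert config == {'lr': 0.0001, 'devices': [0, 1]}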
