Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
Merge pull request #165 from mys007/unet_nchannels_freedom
Browse files Browse the repository at this point in the history
UNet: Allow to freely define the number of channels per depth in subclasses
  • Loading branch information
DerThorsten authored Feb 20, 2019
2 parents 365dfd5 + 2d657e6 commit 50fa978
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 21 deletions.
29 changes: 11 additions & 18 deletions inferno/extensions/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,18 @@ def __init__(self, in_channels, dim, out_channels=None, depth=3,
assert len(self.n_channels_per_output) == self._store_conv_down.count(True) + \
self._store_conv_up.count(True) + int(self._store_conv_bottom)

def _get_num_channels(self, depth):
assert depth > 0
return self.in_channels * self.gain**depth

def _init__downstream(self):
conv_down_ops = []
self._store_conv_down = []

current_in_channels = self.in_channels

for i in range(self.depth):
out_channels = current_in_channels * self.gain
out_channels = self._get_num_channels(i + 1)
op, return_op_res = self.conv_op_factory(in_channels=current_in_channels,
out_channels=out_channels,
part='down', index=i)
Expand All @@ -138,7 +142,7 @@ def _init__downstream(self):
self._store_conv_down.append(False)

# increase the number of channels
current_in_channels *= self.gain
current_in_channels = out_channels

# store as proper torch ModuleList
self._conv_down_ops = nn.ModuleList(conv_down_ops)
Expand All @@ -147,9 +151,7 @@ def _init__downstream(self):

def _init__bottom(self):

conv_up_ops = []

current_in_channels = self.in_channels* self.gain**self.depth
current_in_channels = self._get_num_channels(self.depth)

factory_res = self.conv_op_factory(in_channels=current_in_channels,
out_channels=current_in_channels, part='bottom', index=0)
Expand All @@ -163,12 +165,12 @@ def _init__bottom(self):

def _init__upstream(self):
conv_up_ops = []
current_in_channels = self.in_channels * self.gain**self.depth
current_in_channels = self._get_num_channels(self.depth)

for i in range(self.depth):
# the number of out channels (set to self.out_channels for last decoder)
out_channels = self.out_channels if i +1 == self.depth else\
current_in_channels // self.gain
out_channels = self.out_channels if i + 1 == self.depth else \
self._get_num_channels(self.depth - i - 1)

# if not residual we concat which needs twice as many channels
fac = 1 if self.residual else 2
Expand All @@ -186,7 +188,7 @@ def _init__upstream(self):
self._store_conv_up.append(False)

# decrease the number of input_channels
current_in_channels //= self.gain
current_in_channels = out_channels

# store as proper torch ModuleLis
self._conv_up_ops = nn.ModuleList(conv_up_ops)
Expand Down Expand Up @@ -311,15 +313,6 @@ def upsample_op_factory(self, index):\
return InfernoUpsample(**self._upsample_kwargs)
#return nn.Upsample(**self._upsample_kwargs)

def pre_conv_op_regularizer_factory(self, in_channels, out_channels, part, index):
if self.use_dropout and in_channels > 2:
return self._channel_dropout_op(x)
else:
return Identity()

def post_conv_op_regularizer_factory(self, in_channels, out_channels, part, index):
return Identity()

def conv_op_factory(self, in_channels, out_channels, part, index):
raise NotImplementedError("conv_op_factory need to be implemented by deriving class")

Expand Down
40 changes: 37 additions & 3 deletions tests/test_extensions/test_models/test_unet.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,58 @@
import unittest
import torch.cuda as cuda
from inferno.utils.model_utils import ModelTester
from inferno.utils.model_utils import ModelTester, MultiscaleModelTester
from inferno.extensions.models import UNet

class _MultiscaleUNet(UNet):
def conv_op_factory(self, in_channels, out_channels, part, index):
return super(_MultiscaleUNet, self).conv_op_factory(in_channels, out_channels, part, index)[0], True

def forward(self, input):
x = self._initial_conv(input)
x = list(super(UNet, self).forward(x))
x[-1] = self._output(x[-1])
return tuple(x)


class UNetTest(unittest.TestCase):
def test_unet_2d(self):
from inferno.extensions.models import UNet
tester = ModelTester((1, 1, 256, 256), (1, 1, 256, 256))
if cuda.is_available():
tester.cuda()
tester(UNet(1, 1, dim=2, initial_features=32))

def test_unet_3d(self):
from inferno.extensions.models import UNet
tester = ModelTester((1, 1, 16, 64, 64), (1, 1, 16, 64, 64))
if cuda.is_available():
tester.cuda()
# test default unet 3d
tester(UNet(1, 1, dim=3, initial_features=8))

def test_monochannel_unet_3d(self):
nc = 2
class _UNetMonochannel(_MultiscaleUNet):
def _get_num_channels(self, depth):
return nc

shapes = [(1, nc, 16, 64, 64), (1, nc, 8, 32, 32), (1, nc, 4, 16, 16), (1, nc, 2, 8, 8), (1, nc, 1, 4, 4),
(1, nc, 2, 8, 8), (1, nc, 4, 16, 16), (1, nc, 8, 32, 32), (1, 1, 16, 64, 64)]
tester = MultiscaleModelTester((1, 1, 16, 64, 64), shapes)
if cuda.is_available():
tester.cuda()
tester(_UNetMonochannel(1, 1, dim=3, initial_features=8))

def test_inverse_pyramid_unet_2d(self):
class _UNetInversePyramid(_MultiscaleUNet):
def _get_num_channels(self, depth):
return [13, 12, 11][depth - 1]

shapes = [(1, 13, 16, 64), (1, 12, 8, 32), (1, 11, 4, 16), (1, 11, 2, 8),
(1, 12, 4, 16), (1, 13, 8, 32), (1, 1, 16, 64)]
tester = MultiscaleModelTester((1, 1, 16, 64), shapes)
if cuda.is_available():
tester.cuda()
tester(_UNetInversePyramid(1, 1, dim=2, depth=3, initial_features=8))


if __name__ == '__main__':
unittest.main()

0 comments on commit 50fa978

Please sign in to comment.