diff --git a/inferno/extensions/layers/convolutional.py b/inferno/extensions/layers/convolutional.py index 01cbf0c4..1a1ff493 100755 --- a/inferno/extensions/layers/convolutional.py +++ b/inferno/extensions/layers/convolutional.py @@ -1,77 +1,109 @@ import torch.nn as nn -from ..initializers import OrthogonalWeightsZeroBias, KaimingNormalWeightsZeroBias, \ - SELUWeightsZeroBias +import sys +import functools +from ..initializers import ( + OrthogonalWeightsZeroBias, + KaimingNormalWeightsZeroBias, + SELUWeightsZeroBias, +) from ..initializers import Initializer +from .normalization import BatchNormND from .activations import SELU from ...utils.exceptions import assert_, ShapeError +from ...utils.partial_cls import register_partial_cls - -__all__ = ['ConvActivation', - 'ConvELU2D', 'ConvELU3D', - 'ConvSigmoid2D', 'ConvSigmoid3D', - 'DeconvELU2D', 'DeconvELU3D', - 'StridedConvELU2D', 'StridedConvELU3D', - 'DilatedConvELU2D', 'DilatedConvELU3D', - 'Conv2D', 'Conv3D', - 'BNReLUConv2D', 'BNReLUConv3D', - 'BNReLUDepthwiseConv2D', - 'ConvSELU2D', 'ConvSELU3D', - 'ConvReLU2D', 'ConvReLU3D', - 'BNReLUDilatedConv2D', 'DilatedConv2D', - 'GlobalConv2D'] +# we append to this later on +__all__ = [ + "GlobalConv2D", +] _all = __all__ +register_partial_cls_here = functools.partial(register_partial_cls, module=__name__) + class ConvActivation(nn.Module): """Convolutional layer with 'SAME' padding by default followed by an activation.""" - def __init__(self, in_channels, out_channels, kernel_size, dim, activation, - stride=1, dilation=1, groups=None, depthwise=False, bias=True, - deconv=False, initialization=None, valid_conv=False): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dim, + activation, + stride=1, + dilation=1, + groups=None, + depthwise=False, + bias=True, + deconv=False, + initialization=None, + valid_conv=False, + ): super(ConvActivation, self).__init__() # Validate dim - assert_(dim in [2, 3], "`dim` must be one of [2, 3], got {}.".format(dim), ShapeError) + assert_( + dim in [1, 2, 3], + "`dim` must be one of [1, 2, 3], got {}.".format(dim), + ShapeError, + ) self.dim = dim # Check if depthwise if depthwise: - assert_(in_channels == out_channels, - "For depthwise convolutions, number of input channels (given: {}) " - "must equal the number of output channels (given {})." - .format(in_channels, out_channels), - ValueError) - assert_(groups is None or groups == in_channels, - "For depthwise convolutions, groups (given: {}) must " - "equal the number of channels (given: {}).".format(groups, in_channels)) + + # We know that in_channels == out_channels, but we also want a consistent API. + # As a compromise, we allow that out_channels be None or 'auto'. + out_channels = in_channels if out_channels in [None, "auto"] else out_channel + assert_( + in_channels == out_channels, + "For depthwise convolutions, number of input channels (given: {}) " + "must equal the number of output channels (given {}).".format( + in_channels, out_channels + ), + ValueError, + ) + assert_( + groups is None or groups == in_channels, + "For depthwise convolutions, groups (given: {}) must " + "equal the number of channels (given: {}).".format(groups, in_channels), + ) groups = in_channels else: groups = 1 if groups is None else groups self.depthwise = depthwise if valid_conv: - self.conv = getattr(nn, 'Conv{}d'.format(self.dim))(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - groups=groups, - bias=bias) + self.conv = getattr(nn, "Conv{}d".format(self.dim))( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + bias=bias, + ) elif not deconv: # Get padding padding = self.get_padding(kernel_size, dilation) - self.conv = getattr(nn, 'Conv{}d'.format(self.dim))(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - padding=padding, - stride=stride, - dilation=dilation, - groups=groups, - bias=bias) + self.conv = getattr(nn, "Conv{}d".format(self.dim))( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + stride=stride, + dilation=dilation, + groups=groups, + bias=bias, + ) else: - self.conv = getattr(nn, 'ConvTranspose{}d'.format(self.dim))(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - groups=groups, - bias=bias) + self.conv = getattr(nn, "ConvTranspose{}d".format(self.dim))( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + bias=bias, + ) if initialization is None: pass elif isinstance(initialization, Initializer): @@ -114,253 +146,14 @@ def _get_padding(self, _kernel_size, _dilation): def get_padding(self, kernel_size, dilation): kernel_size = self._pair_or_triplet(kernel_size) dilation = self._pair_or_triplet(dilation) - padding = [self._get_padding(_kernel_size, _dilation) - for _kernel_size, _dilation in zip(kernel_size, dilation)] + padding = [ + self._get_padding(_kernel_size, _dilation) + for _kernel_size, _dilation in zip(kernel_size, dilation) + ] return tuple(padding) - -class ConvELU2D(ConvActivation): - """2D Convolutional layer with 'SAME' padding, ELU and orthogonal weight initialization.""" - def __init__(self, in_channels, out_channels, kernel_size): - super(ConvELU2D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=2, - activation='ELU', - initialization=OrthogonalWeightsZeroBias()) - - -class ConvELU3D(ConvActivation): - """3D Convolutional layer with 'SAME' padding, ELU and orthogonal weight initialization.""" - def __init__(self, in_channels, out_channels, kernel_size): - super(ConvELU3D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=3, - activation='ELU', - initialization=OrthogonalWeightsZeroBias()) - -class ValidConvELU2D(ConvActivation): - """2D Convolutional layer with 'VALID' padding, ELU and orthogonal weight initialization.""" - def __init__(self, in_channels, out_channels, kernel_size): - super(ValidConvELU2D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=2, - activation='ELU', - valid_conv=True, - initialization=OrthogonalWeightsZeroBias()) - -class ValidConvELU3D(ConvActivation): - """3D Convolutional layer with 'VALID' padding, ELU and orthogonal weight initialization.""" - def __init__(self, in_channels, out_channels, kernel_size): - super(ValidConvELU3D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=3, - activation='ELU', - valid_conv=True, - initialization=OrthogonalWeightsZeroBias()) - -class ConvSigmoid2D(ConvActivation): - """2D Convolutional layer with 'SAME' padding, Sigmoid and orthogonal weight initialization.""" - def __init__(self, in_channels, out_channels, kernel_size): - super(ConvSigmoid2D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=2, - activation='Sigmoid', - initialization=OrthogonalWeightsZeroBias()) - - -class ConvSigmoid3D(ConvActivation): - """3D Convolutional layer with 'SAME' padding, Sigmoid and orthogonal weight initialization.""" - def __init__(self, in_channels, out_channels, kernel_size): - super(ConvSigmoid3D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=3, - activation='Sigmoid', - initialization=OrthogonalWeightsZeroBias()) - - -class DeconvELU2D(ConvActivation): - """2D deconvolutional layer with ELU and orthogonal weight initialization.""" - def __init__(self, in_channels, out_channels, kernel_size=2, stride=2): - super(DeconvELU2D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=2, - activation='ELU', - deconv=True, - stride=stride, - initialization=OrthogonalWeightsZeroBias()) - - -class DeconvELU3D(ConvActivation): - """3D deconvolutional layer with ELU and orthogonal weight initialization.""" - def __init__(self, in_channels, out_channels, kernel_size=2, stride=2): - super(DeconvELU3D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=3, - activation='ELU', - deconv=True, - stride=2, - initialization=OrthogonalWeightsZeroBias()) - - -class StridedConvELU2D(ConvActivation): - """ - 2D strided convolutional layer with 'SAME' padding, ELU and orthogonal - weight initialization. - """ - def __init__(self, in_channels, out_channels, kernel_size, stride=2): - super(StridedConvELU2D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - dim=2, - activation='ELU', - initialization=OrthogonalWeightsZeroBias()) - - -class StridedConvELU3D(ConvActivation): - """ - 2D strided convolutional layer with 'SAME' padding, ELU and orthogonal - weight initialization. - """ - def __init__(self, in_channels, out_channels, kernel_size, stride=2): - super(StridedConvELU3D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - dim=3, - activation='ELU', - initialization=OrthogonalWeightsZeroBias()) - - -class DilatedConvELU2D(ConvActivation): - """ - 2D dilated convolutional layer with 'SAME' padding, ELU and orthogonal - weight initialization. - """ - def __init__(self, in_channels, out_channels, kernel_size, dilation=2): - super(DilatedConvELU2D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dilation=dilation, - dim=2, - activation='ELU', - initialization=OrthogonalWeightsZeroBias()) - - -class DilatedConvELU3D(ConvActivation): - """ - 3D dilated convolutional layer with 'SAME' padding, ELU and orthogonal - weight initialization. - """ - def __init__(self, in_channels, out_channels, kernel_size, dilation=2): - super(DilatedConvELU3D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dilation=dilation, - dim=3, - activation='ELU', - initialization=OrthogonalWeightsZeroBias()) - -class DilatedConv2D(ConvActivation): - """2D dilated convolutional layer with 'SAME' padding, no activation and orthogonal weight initialization.""" - def __init__(self, in_channels, out_channels, kernel_size, dilation=2): - super(DilatedConv2D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dilation=dilation, - dim=2, - activation=None, - initialization=OrthogonalWeightsZeroBias()) - - -class ConvReLU2D(ConvActivation): - """2D Convolutional layer with 'SAME' padding, ReLU and Kaiming normal weight initialization.""" - def __init__(self, in_channels, out_channels, kernel_size): - super(ConvReLU2D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=2, - activation='ReLU', - initialization=KaimingNormalWeightsZeroBias()) - - -class ConvReLU3D(ConvActivation): - """3D Convolutional layer with 'SAME' padding, ReLU and Kaiming normal weight initialization.""" - def __init__(self, in_channels, out_channels, kernel_size): - super(ConvReLU3D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=3, - activation='ReLU', - initialization=KaimingNormalWeightsZeroBias()) - - -class Conv2D(ConvActivation): - """ - 2D convolutional layer with same padding and orthogonal weight initialization. - By default, this layer does not apply an activation function. - """ - def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1, - activation=None): - super(Conv2D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dilation=dilation, - stride=stride, - dim=2, - activation=activation, - initialization=OrthogonalWeightsZeroBias()) - - -class Conv3D(ConvActivation): - """ - 3D convolutional layer with same padding and orthogonal weight initialization. - By default, this layer does not apply an activation function. - """ - def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1, - activation=None): - super(Conv3D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dilation=dilation, - stride=stride, - dim=3, - activation=activation, - initialization=OrthogonalWeightsZeroBias()) - - -class Deconv2D(ConvActivation): - """2D deconvolutional layer with orthogonal weight initialization.""" - def __init__(self, in_channels, out_channels, kernel_size=2, stride=2): - super(Deconv2D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=2, - stride=stride, - activation=None, - deconv=True, - initialization=OrthogonalWeightsZeroBias()) - - -class Deconv3D(ConvActivation): - """2D deconvolutional layer with orthogonal weight initialization.""" - def __init__(self, in_channels, out_channels, kernel_size=2, stride=2): - super(Deconv3D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=3, - stride=stride, - activation=None, - deconv=True, - initialization=OrthogonalWeightsZeroBias()) +# for consistency +ConvActivationND = ConvActivation # noinspection PyUnresolvedReferences @@ -371,164 +164,150 @@ def forward(self, input): conved = self.conv(activated) return conved - -class BNReLUConv2D(_BNReLUSomeConv, ConvActivation): - """ - 2D BN-ReLU-Conv layer with 'SAME' padding and He weight initialization. - """ - def __init__(self, in_channels, out_channels, kernel_size, stride=1): - super(BNReLUConv2D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=2, - stride=stride, - activation=nn.ReLU(inplace=True), - initialization=KaimingNormalWeightsZeroBias(0)) - self.batchnorm = nn.BatchNorm2d(in_channels) - - -class BNReLUDilatedConv2D(_BNReLUSomeConv,ConvActivation): - """ - 2D dilated convolutional layer with 'SAME' padding, Batch norm, Relu and He - weight initialization. - """ - def __init__(self, in_channels, out_channels, kernel_size, dilation=2): - super(BNReLUDilatedConv2D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dilation=dilation, - dim=2, - activation=nn.ReLU(inplace=True), - initialization=KaimingNormalWeightsZeroBias(0)) - self.batchnorm = nn.BatchNorm2d(in_channels) - - -class BNReLUConv3D(_BNReLUSomeConv, ConvActivation): - """ - 3D BN-ReLU-Conv layer with 'SAME' padding and He weight initialization. - """ - def __init__(self, in_channels, out_channels, kernel_size, stride=1): - super(BNReLUConv3D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=3, - stride=stride, - activation=nn.ReLU(inplace=True), - initialization=KaimingNormalWeightsZeroBias(0)) - self.batchnorm = nn.BatchNorm3d(in_channels) - - -class BNReLUDeconv2D(_BNReLUSomeConv, ConvActivation): - """ - 2D BN-ReLU-Deconv layer with He weight initialization and (default) stride 2. - """ - def __init__(self, in_channels, out_channels, kernel_size, stride=2): - super(BNReLUDeconv2D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=2, - stride=stride, - deconv=True, - activation=nn.ReLU(inplace=True), - initialization=KaimingNormalWeightsZeroBias(0)) - self.batchnorm = nn.BatchNorm2d(in_channels) - - -class BNReLUDeconv3D(_BNReLUSomeConv, ConvActivation): - """ - 3D BN-ReLU-Deconv layer with He weight initialization and (default) stride 2. - """ - def __init__(self, in_channels, out_channels, kernel_size, stride=2): - super(BNReLUDeconv3D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=3, - stride=stride, - deconv=True, - activation=nn.ReLU(inplace=True), - initialization=KaimingNormalWeightsZeroBias(0)) - self.batchnorm = nn.BatchNorm2d(in_channels) - - -class BNReLUDepthwiseConv2D(_BNReLUSomeConv, ConvActivation): - """ - 2D BN-ReLU-Conv layer with 'SAME' padding, He weight initialization and depthwise convolution. - Note that depthwise convolutions require `in_channels == out_channels`. - """ - def __init__(self, in_channels, out_channels, kernel_size): - # We know that in_channels == out_channels, but we also want a consistent API. - # As a compromise, we allow that out_channels be None or 'auto'. - out_channels = in_channels if out_channels in [None, 'auto'] else out_channels - super(BNReLUDepthwiseConv2D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=2, - depthwise=True, - activation=nn.ReLU(inplace=True), - initialization=KaimingNormalWeightsZeroBias(0)) - self.batchnorm = nn.BatchNorm2d(in_channels) - - -class ConvSELU2D(ConvActivation): - """2D Convolutional layer with SELU activation and the appropriate weight initialization.""" - def __init__(self, in_channels, out_channels, kernel_size): - if hasattr(nn, 'SELU'): - # Pytorch 0.2: Use built in SELU +class BNReLUConvBaseND(_BNReLUSomeConv, ConvActivation): + def __init__(self, in_channels, out_channels, kernel_size, dim, stride=1, dilation=1, deconv=False): + + super(BNReLUConvBaseND, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dim=dim, + stride=stride, + activation=nn.ReLU(inplace=True), + dilation=dilation, + deconv=deconv, + initialization=KaimingNormalWeightsZeroBias(0), + ) + self.batchnorm = BatchNormND(dim, in_channels) + + +def _register_conv_cls(conv_name, fix=None, default=None): + if fix is None: + fix = {} + if default is None: + default = {} + + # simple conv activation + activations = ["ReLU", "ELU", "Sigmoid", "SELU", ""] + init_map = { + "ReLU": KaimingNormalWeightsZeroBias, + "SELU": SELUWeightsZeroBias + } + for activation_str in activations: + cls_name = cls_name = "{}{}ND".format(conv_name,activation_str) + __all__.append(cls_name) + initialization_cls = init_map.get(activation_str, OrthogonalWeightsZeroBias) + if activation_str == "": + activation = None + _fix = {**fix} + _default = {'activation':None} + elif activation_str == "SELU": activation = nn.SELU(inplace=True) + _fix={**fix, 'activation':activation} + _default = {**default} else: - # Pytorch < 0.1.12: Use handmade SELU - activation = SELU() - super(ConvSELU2D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=2, - activation=activation, - initialization=SELUWeightsZeroBias()) - - -class ConvSELU3D(ConvActivation): - """3D Convolutional layer with SELU activation and the appropriate weight initialization.""" - def __init__(self, in_channels, out_channels, kernel_size): - if hasattr(nn, 'SELU'): - # Pytorch 0.2: Use built in SELU - activation = nn.SELU(inplace=True) - else: - # Pytorch < 0.1.12: Use handmade SELU - activation = SELU() - super(ConvSELU3D, self).__init__(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - dim=3, - activation=activation, - initialization=SELUWeightsZeroBias()) + activation = activation_str + _fix={**fix, 'activation':activation} + _default = {**default} + + register_partial_cls_here(ConvActivation, cls_name, + fix=_fix, + default={**_default, 'initialization':initialization_cls()} + ) + for dim in [1, 2, 3]: + cls_name = "{}{}{}D".format(conv_name,activation_str, dim) + __all__.append(cls_name) + register_partial_cls_here(ConvActivation, cls_name, + fix={**_fix, 'dim':dim}, + default={**_default, 'initialization':initialization_cls()} + ) + +def _register_bnr_conv_cls(conv_name, fix=None, default=None): + if fix is None: + fix = {} + if default is None: + default = {} + for dim in [1, 2, 3]: + + cls_name = "BNReLU{}ND".format(conv_name) + __all__.append(cls_name) + register_partial_cls_here(BNReLUConvBaseND, cls_name,fix=fix,default=default) + + for dim in [1, 2, 3]: + cls_name = "BNReLU{}{}D".format(conv_name, dim) + __all__.append(cls_name) + register_partial_cls_here(BNReLUConvBaseND, cls_name, + fix={**fix, 'dim':dim}, + default=default) + +# conv classes +_register_conv_cls("Conv") +_register_conv_cls("ValidConv", fix=dict(valid_conv=True)) +_register_conv_cls("Deconv", fix=dict(deconv=True), default=dict(kernel_size=2, stride=2)) +_register_conv_cls("StridedConv", default=dict(stride=2)) +_register_conv_cls("DilatedConv", fix=dict(dilation=2)) +_register_conv_cls("DepthwiseConv", fix=dict(deconv=False, depthwise=True), default=dict(out_channels='auto')) + +# BatchNormRelu classes +_register_bnr_conv_cls("Conv", fix=dict(deconv=False)) +_register_bnr_conv_cls("Deconv", fix=dict(deconv=True)) +_register_bnr_conv_cls("StridedConv", default=dict(stride=2)) +_register_bnr_conv_cls("DilatedConv", default=dict(dilation=2)) +_register_bnr_conv_cls("DepthwiseConv", fix=dict(deconv=False, depthwise=True), default=dict(out_channels='auto')) + +del _register_conv_cls +del _register_bnr_conv_cls + + class GlobalConv2D(nn.Module): """From https://arxiv.org/pdf/1703.02719.pdf Main idea: we can have a bigger kernel size computationally acceptable if we separate 2D-conv in 2 1D-convs """ - def __init__(self, in_channels, out_channels, kernel_size, local_conv_type, - activation=None, use_BN=False, **kwargs): + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + local_conv_type, + activation=None, + use_BN=False, + **kwargs + ): super(GlobalConv2D, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size assert isinstance(kernel_size, (int, list, tuple)) if isinstance(kernel_size, int): - kernel_size = (kernel_size,)*2 - self.kwargs=kwargs - self.conv1a = local_conv_type(in_channels=self.in_channels, - out_channels=self.out_channels, - kernel_size=(kernel_size[0], 1), **kwargs) - self.conv1b = local_conv_type(in_channels=self.out_channels, - out_channels=self.out_channels, - kernel_size=(1, kernel_size[1]), **kwargs) - self.conv2a = local_conv_type(in_channels=self.in_channels, - out_channels=self.out_channels, - kernel_size=(1, kernel_size[1]), **kwargs) - self.conv2b = local_conv_type(in_channels=self.out_channels, - out_channels=self.out_channels, - kernel_size=(kernel_size[0], 1), **kwargs) + kernel_size = (kernel_size,) * 2 + self.kwargs = kwargs + self.conv1a = local_conv_type( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=(kernel_size[0], 1), + **kwargs + ) + self.conv1b = local_conv_type( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=(1, kernel_size[1]), + **kwargs + ) + self.conv2a = local_conv_type( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=(1, kernel_size[1]), + **kwargs + ) + self.conv2b = local_conv_type( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=(kernel_size[0], 1), + **kwargs + ) if use_BN: self.batchnorm = nn.BatchNorm2d(self.out_channels) else: @@ -540,7 +319,7 @@ def forward(self, input_): out1 = self.conv1b(out1) out2 = self.conv2a(input_) out2 = self.conv2b(out2) - out = out1.add(1,out2) + out = out1.add(1, out2) if self.activation is not None: out = self.activation(out) if self.batchnorm is not None: diff --git a/inferno/extensions/layers/normalization.py b/inferno/extensions/layers/normalization.py new file mode 100644 index 00000000..fd00476b --- /dev/null +++ b/inferno/extensions/layers/normalization.py @@ -0,0 +1,14 @@ +import torch.nn as nn + + +class BatchNormND(nn.Module): + def __init__(self, dim, num_features, + eps=1e-5, momentum=0.1, + affine=True,track_running_stats=True): + super(BatchNormND, self).__init__() + assert dim in [1, 2, 3] + self.bn = getattr(nn, 'BatchNorm{}d'.format(dim))(num_features=num_features, + eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) + + def forward(self, x): + return self.bn(x) \ No newline at end of file diff --git a/inferno/utils/partial_cls.py b/inferno/utils/partial_cls.py new file mode 100644 index 00000000..68dbbc08 --- /dev/null +++ b/inferno/utils/partial_cls.py @@ -0,0 +1,141 @@ +import functools +import sys +import types +import inspect + + +__all__ = [ + 'partial_cls', + 'register_partial_cls' +] + + +def partial_cls(base_cls, name, module, fix=None, default=None): + + # helper function + def insert_if_not_present(dict_a, dict_b): + for kw,val in dict_b.items(): + if kw not in dict_a: + dict_a[kw] = val + return dict_a + + # helper function + def insert_call_if_present(dict_a, dict_b, callback): + for kw,val in dict_b.items(): + if kw not in dict_a: + dict_a[kw] = val + else: + callback(kw) + return dict_a + + # helper class + class PartialCls(object): + def __init__(self, base_cls, name, module, fix=None, default=None): + + self.base_cls = base_cls + self.name = name + self.module = module + self.fix = [fix, {}][fix is None] + self.default = [default, {}][default is None] + + if self.fix.keys() & self.default.keys(): + raise TypeError('fix and default share keys') + + # remove binded kw + self._allowed_kw = self._get_allowed_kw() + + def _get_allowed_kw(self): + + + argspec = inspect.getfullargspec(base_cls.__init__) + args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations = argspec + + if varargs is not None: + raise TypeError('partial_cls can only be used if __init__ has no varargs') + + if varkw is not None: + raise TypeError('partial_cls can only be used if __init__ has no varkw') + + if kwonlyargs is not None and kwonlyargs != []: + raise TypeError('partial_cls can only be used without kwonlyargs') + + if args is None or len(args) < 1: + raise TypeError('seems like self is missing') + + + return [kw for kw in args[1:] if kw not in self.fix] + + + def _build_kw(self, args, kwargs): + # handle *args + if len(args) > len(self._allowed_kw): + raise TypeError("to many arguments") + + all_args = {} + for arg, akw in zip(args, self._allowed_kw): + all_args[akw] = arg + + # handle **kwargs + intersection = self.fix.keys() & kwargs.keys() + if len(intersection) >= 1: + kw = intersection.pop() + raise TypeError("`{}.__init__` got unexpected keyword argument '{}'".format(name, kw)) + + def raise_cb(kw): + raise TypeError("{}.__init__ got multiple values for argument '{}'".format(name, kw)) + all_args = insert_call_if_present(all_args, kwargs, raise_cb) + + # handle fixed arguments + def raise_cb(kw): + raise TypeError() + all_args = insert_call_if_present(all_args, self.fix, raise_cb) + + # handle defaults + all_args = insert_if_not_present(all_args, self.default) + + # handle fixed + all_args.update(self.fix) + + return all_args + + def build_cls(self): + + def new_init(self_of_new_cls, *args, **kwargs): + combined_args = self._build_kw(args=args, kwargs=kwargs) + + #call base cls init + super(self_of_new_cls.__class__, self_of_new_cls).__init__(**combined_args) + + return type(name, (self.base_cls,), { + '__module__': self.module, + '__init__' : new_init + }) + return cls + + + return PartialCls(base_cls=base_cls, name=name, module=module, + fix=fix, default=default).build_cls() + + +def register_partial_cls(base_cls, name, module, fix=None, default=None): + module_dict = sys.modules[module].__dict__ + generatedClass = partial_cls(base_cls=base_cls,name=name, module=module, + fix=fix, default=default) + module_dict[generatedClass.__name__] = generatedClass + del generatedClass + + +if __name__ == "__main__": + + class Conv(object): + def __init__(self, dim, activation, stride=1): + print(f"dim {dim} act {activation} stride {stride}") + + + Conv2D = partial_cls(Conv,'Conv2D',__name__, fix=dict(dim=2), default=dict(stride=2)) + + + #obj = Conv2D(activation='a') + #obj = Conv2D('a',activation='a', stride=3) + obj = Conv2D('fu','bar') + diff --git a/tests/test_inferno.py b/tests/test_inferno.py index d32da4f1..677d2ca9 100755 --- a/tests/test_inferno.py +++ b/tests/test_inferno.py @@ -132,8 +132,10 @@ def test_training_cpu(self): .bind_loader('train', self.train_loader)\ .bind_loader('validate', self.validate_loader) # Go + trainer.pickle_module = 'dill' trainer.fit() + if __name__ == '__main__': unittest.main() diff --git a/tests/test_utils/test_partial_cls.py b/tests/test_utils/test_partial_cls.py new file mode 100644 index 00000000..a6ff7dc0 --- /dev/null +++ b/tests/test_utils/test_partial_cls.py @@ -0,0 +1,137 @@ +import unittest +import inferno.utils.model_utils as mu +from inferno.utils.partial_cls import register_partial_cls +import torch +import torch.nn as nn + + +class TestCls(object): + def __init__(self, a, b, c=1, d=2): + self.a = a + self.b = b + self.c = c + self.d = d + +class PartialClsTester(unittest.TestCase): + + def test_partial_cls(self): + register_partial_cls(TestCls, 'TestA', + fix=dict(a='a'), + default=dict(b='b'), + module=__name__ + ) + assert 'TestA' in globals() + + inst = TestA() + assert inst.a == 'a' + assert inst.b == 'b' + assert inst.c == 1 + assert inst.d == 2 + + inst = TestA('fu','bar','fubar') + assert inst.a == 'a' + assert inst.b == 'fu' + assert inst.c == 'bar' + assert inst.d == 'fubar' + + with self.assertRaises(TypeError): + inst = TestA(a=2) + + def test_update_existing_default_cls(self): + register_partial_cls(TestCls, 'TestA', + fix=dict(a='a'), + default=dict(d=3), + module=__name__ + ) + assert 'TestA' in globals() + + inst = TestA(42) + assert inst.a == 'a' + assert inst.b == 42 + assert inst.c == 1 + assert inst.d == 3 + + with self.assertRaises(TypeError): + inst = TestA() + + def test_fix_nothing(self): + register_partial_cls(TestCls, 'TestA', + module=__name__ + ) + assert 'TestA' in globals() + + inst = TestA(1,2,3,4) + assert inst.a == 1 + assert inst.b == 2 + assert inst.c == 3 + assert inst.d == 4 + + with self.assertRaises(TypeError): + inst = TestA() + + def test_fix_all(self): + register_partial_cls(TestCls, 'TestA', + module=__name__, + fix=dict(a=4, b=3, c=2, d=1) + ) + assert 'TestA' in globals() + + inst = TestA() + assert inst.a == 4 + assert inst.b == 3 + assert inst.c == 2 + assert inst.d == 1 + + with self.assertRaises(TypeError): + inst = TestA('a') + + with self.assertRaises(TypeError): + inst = TestA(a=1) + with self.assertRaises(TypeError): + inst = TestA(b=1) + with self.assertRaises(TypeError): + inst = TestA(c=1) + with self.assertRaises(TypeError): + inst = TestA(d=1) + + + def test_default_all(self): + register_partial_cls(TestCls, 'TestA', + module=__name__, + default=dict(a=4, b=3, c=2, d=1) + ) + assert 'TestA' in globals() + + inst = TestA() + assert inst.a == 4 + assert inst.b == 3 + assert inst.c == 2 + assert inst.d == 1 + + + inst = TestA(2) + assert inst.a == 2 + assert inst.b == 3 + assert inst.c == 2 + assert inst.d == 1 + + inst = TestA(2,3,4,5) + assert inst.a == 2 + assert inst.b == 3 + assert inst.c == 4 + assert inst.d == 5 + + with self.assertRaises(TypeError): + inst = TestA(3,4,5,a=2) + + inst = TestA(3,4,5,d=2) + assert inst.a == 3 + assert inst.b == 4 + assert inst.c == 5 + assert inst.d == 2 + + + + +if __name__ == '__main__': + unittest.main()