diff --git a/mmedit/models/backbones/__init__.py b/mmedit/models/backbones/__init__.py index 72d5ea179..444bb01c8 100644 --- a/mmedit/models/backbones/__init__.py +++ b/mmedit/models/backbones/__init__.py @@ -12,7 +12,7 @@ SimpleEncoderDecoder) # yapf: enable from .generation_backbones import ResnetGenerator, UnetGenerator -from .sr_backbones import (EDSR, SRCNN, BasicVSRNet, EDVRNet, IconVSR, +from .sr_backbones import (EDSR, RDN, SRCNN, BasicVSRNet, EDVRNet, IconVSR, MSRResNet, RRDBNet, TOFlow) __all__ = [ @@ -21,7 +21,7 @@ 'PConvEncoderDecoder', 'PConvEncoder', 'PConvDecoder', 'ResNetEnc', 'ResNetDec', 'ResShortcutEnc', 'ResShortcutDec', 'RRDBNet', 'DeepFillEncoder', 'HolisticIndexBlock', 'DepthwiseIndexBlock', - 'ContextualAttentionNeck', 'DeepFillDecoder', 'EDSR', + 'ContextualAttentionNeck', 'DeepFillDecoder', 'EDSR', 'RDN', 'DeepFillEncoderDecoder', 'EDVRNet', 'IndexedUpsample', 'IndexNetEncoder', 'IndexNetDecoder', 'TOFlow', 'ResGCAEncoder', 'ResGCADecoder', 'SRCNN', 'UnetGenerator', 'ResnetGenerator', 'FBAResnetDilated', 'FBADecoder', diff --git a/mmedit/models/backbones/sr_backbones/__init__.py b/mmedit/models/backbones/sr_backbones/__init__.py index 22b15d41b..d96a81a9c 100644 --- a/mmedit/models/backbones/sr_backbones/__init__.py +++ b/mmedit/models/backbones/sr_backbones/__init__.py @@ -2,6 +2,7 @@ from .edsr import EDSR from .edvr_net import EDVRNet from .iconvsr import IconVSR +from .rdn import RDN from .rrdb_net import RRDBNet from .sr_resnet import MSRResNet from .srcnn import SRCNN @@ -9,5 +10,5 @@ __all__ = [ 'MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN', - 'BasicVSRNet', 'IconVSR' + 'BasicVSRNet', 'IconVSR', 'RDN' ] diff --git a/mmedit/models/backbones/sr_backbones/rdn.py b/mmedit/models/backbones/sr_backbones/rdn.py new file mode 100644 index 000000000..b0d643fe2 --- /dev/null +++ b/mmedit/models/backbones/sr_backbones/rdn.py @@ -0,0 +1,200 @@ +import torch +from mmcv.runner import load_checkpoint +from torch import nn + +from mmedit.models.registry import BACKBONES +from mmedit.utils import get_root_logger + + +class DenseLayer(nn.Module): + """Dense layer + + Args: + in_channels (int): Channel number of inputs. + out_channels (int): Channel number of outputs. + + """ + + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size=3, padding=3 // 2) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c_in, h, w). + + Returns: + Tensor: Forward results, tensor with shape (n, c_in+c_out, h, w). + """ + return torch.cat([x, self.relu(self.conv(x))], 1) + + +class RDB(nn.Module): + """Residual Dense Block of Residual Dense Network + + Args: + in_channels (int): Channel number of inputs. + channel_growth (int): Channels growth in each layer. + num_layers (int): Layer number in the Residual Dense Block. + """ + + def __init__(self, in_channels, channel_growth, num_layers): + super().__init__() + self.layers = nn.Sequential(*[ + DenseLayer(in_channels + channel_growth * i, channel_growth) + for i in range(num_layers) + ]) + + # local feature fusion + self.lff = nn.Conv2d( + in_channels + channel_growth * num_layers, + channel_growth, + kernel_size=1) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + return x + self.lff(self.layers(x)) # local residual learning + + +@BACKBONES.register_module() +class RDN(nn.Module): + """RDN model for single image super-resolution. + + Paper: Residual Dense Network for Image Super-Resolution + + Adapted from 'https://github.com/yjn870/RDN-pytorch.git' + 'RDN-pytorch/blob/master/models.py' + Copyright (c) 2021, JaeYun Yeo, under MIT License. + + Args: + in_channels (int): Channel number of inputs. + out_channels (int): Channel number of outputs. + mid_channels (int): Channel number of intermediate features. + Default: 64. + num_blocks (int): Block number in the trunk network. Default: 16. + upscale_factor (int): Upsampling factor. Support 2^n and 3. + Default: 4. + num_layer (int): Layer number in the Residual Dense Block. + Default: 8. + channel_growth(int): Channels growth in each layer of RDB. + Default: 64. + """ + + def __init__(self, + in_channels, + out_channels, + mid_channels=64, + num_blocks=16, + upscale_factor=4, + num_layers=8, + channel_growth=64): + + super().__init__() + self.mid_channels = mid_channels + self.channel_growth = channel_growth + self.num_blocks = num_blocks + self.num_layers = num_layers + + # shallow feature extraction + self.sfe1 = nn.Conv2d( + in_channels, mid_channels, kernel_size=3, padding=3 // 2) + self.sfe2 = nn.Conv2d( + mid_channels, mid_channels, kernel_size=3, padding=3 // 2) + + # residual dense blocks + self.rdbs = nn.ModuleList( + [RDB(self.mid_channels, self.channel_growth, self.num_layers)]) + for _ in range(self.num_blocks - 1): + self.rdbs.append( + RDB(self.channel_growth, self.channel_growth, self.num_layers)) + + # global feature fusion + self.gff = nn.Sequential( + nn.Conv2d( + self.channel_growth * self.num_blocks, + self.mid_channels, + kernel_size=1), + nn.Conv2d( + self.mid_channels, + self.mid_channels, + kernel_size=3, + padding=3 // 2)) + + # up-sampling + assert 2 <= upscale_factor <= 4 + if upscale_factor == 2 or upscale_factor == 4: + self.upscale = [] + for _ in range(upscale_factor // 2): + self.upscale.extend([ + nn.Conv2d( + self.mid_channels, + self.mid_channels * (2**2), + kernel_size=3, + padding=3 // 2), + nn.PixelShuffle(2) + ]) + self.upscale = nn.Sequential(*self.upscale) + else: + self.upscale = nn.Sequential( + nn.Conv2d( + self.mid_channels, + self.mid_channels * (upscale_factor**2), + kernel_size=3, + padding=3 // 2), nn.PixelShuffle(upscale_factor)) + + self.output = nn.Conv2d( + self.mid_channels, out_channels, kernel_size=3, padding=3 // 2) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + + sfe1 = self.sfe1(x) + sfe2 = self.sfe2(sfe1) + + x = sfe2 + local_features = [] + for i in range(self.num_blocks): + x = self.rdbs[i](x) + local_features.append(x) + + x = self.gff(torch.cat(local_features, 1)) + sfe1 + # global residual learning + x = self.upscale(x) + x = self.output(x) + return x + + def init_weights(self, pretrained=None, strict=True): + """Init weights for models. + + Args: + pretrained (str, optional): Path for pretrained weights. If given + None, pretrained weights will not be loaded. Defaults to None. + strict (boo, optional): Whether strictly load the pretrained model. + Defaults to True. + """ + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=strict, logger=logger) + elif pretrained is None: + pass # use default initialization + else: + raise TypeError('"pretrained" must be a str or None. ' + f'But received {type(pretrained)}.') diff --git a/tests/test_rdn.py b/tests/test_rdn.py new file mode 100644 index 000000000..ccf07fa9c --- /dev/null +++ b/tests/test_rdn.py @@ -0,0 +1,56 @@ +import torch +import torch.nn as nn + +from mmedit.models import build_backbone + + +def test_rdn(): + + scale = 4 + + model_cfg = dict( + type='RDN', + in_channels=3, + out_channels=3, + mid_channels=64, + num_blocks=16, + upscale_factor=scale) + + # build model + model = build_backbone(model_cfg) + + # test attributes + assert model.__class__.__name__ == 'RDN' + + # prepare data + inputs = torch.rand(1, 3, 32, 16) + targets = torch.rand(1, 3, 128, 64) + + # prepare loss + loss_function = nn.L1Loss() + + # prepare optimizer + optimizer = torch.optim.Adam(model.parameters()) + + # test on cpu + output = model(inputs) + optimizer.zero_grad() + loss = loss_function(output, targets) + loss.backward() + optimizer.step() + assert torch.is_tensor(output) + assert output.shape == targets.shape + + # test on gpu + if torch.cuda.is_available(): + model = model.cuda() + optimizer = torch.optim.Adam(model.parameters()) + inputs = inputs.cuda() + targets = targets.cuda() + output = model(inputs) + optimizer.zero_grad() + loss = loss_function(output, targets) + loss.backward() + optimizer.step() + assert torch.is_tensor(output) + assert output.shape == targets.shape