Skip to content

Commit

Permalink
Add RDN. (open-mmlab#233)
Browse files Browse the repository at this point in the history
* add RDN

* Add docstring and test.

* Tiny fix.

* Tiny fix.

* Add license.

* Tiny Fix

* Tiny Fix

Co-authored-by: liyinshuo <[email protected]>
  • Loading branch information
Yshuo-Li and liyinshuo authored Apr 25, 2021
1 parent 4c20a4a commit 352ed6a
Show file tree
Hide file tree
Showing 4 changed files with 260 additions and 3 deletions.
4 changes: 2 additions & 2 deletions mmedit/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
SimpleEncoderDecoder)
# yapf: enable
from .generation_backbones import ResnetGenerator, UnetGenerator
from .sr_backbones import (EDSR, SRCNN, BasicVSRNet, EDVRNet, IconVSR,
from .sr_backbones import (EDSR, RDN, SRCNN, BasicVSRNet, EDVRNet, IconVSR,
MSRResNet, RRDBNet, TOFlow)

__all__ = [
Expand All @@ -21,7 +21,7 @@
'PConvEncoderDecoder', 'PConvEncoder', 'PConvDecoder', 'ResNetEnc',
'ResNetDec', 'ResShortcutEnc', 'ResShortcutDec', 'RRDBNet',
'DeepFillEncoder', 'HolisticIndexBlock', 'DepthwiseIndexBlock',
'ContextualAttentionNeck', 'DeepFillDecoder', 'EDSR',
'ContextualAttentionNeck', 'DeepFillDecoder', 'EDSR', 'RDN',
'DeepFillEncoderDecoder', 'EDVRNet', 'IndexedUpsample', 'IndexNetEncoder',
'IndexNetDecoder', 'TOFlow', 'ResGCAEncoder', 'ResGCADecoder', 'SRCNN',
'UnetGenerator', 'ResnetGenerator', 'FBAResnetDilated', 'FBADecoder',
Expand Down
3 changes: 2 additions & 1 deletion mmedit/models/backbones/sr_backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from .edsr import EDSR
from .edvr_net import EDVRNet
from .iconvsr import IconVSR
from .rdn import RDN
from .rrdb_net import RRDBNet
from .sr_resnet import MSRResNet
from .srcnn import SRCNN
from .tof import TOFlow

__all__ = [
'MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN',
'BasicVSRNet', 'IconVSR'
'BasicVSRNet', 'IconVSR', 'RDN'
]
200 changes: 200 additions & 0 deletions mmedit/models/backbones/sr_backbones/rdn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import torch
from mmcv.runner import load_checkpoint
from torch import nn

from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger


class DenseLayer(nn.Module):
"""Dense layer
Args:
in_channels (int): Channel number of inputs.
out_channels (int): Channel number of outputs.
"""

def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(
in_channels, out_channels, kernel_size=3, padding=3 // 2)
self.relu = nn.ReLU(inplace=True)

def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c_in, h, w).
Returns:
Tensor: Forward results, tensor with shape (n, c_in+c_out, h, w).
"""
return torch.cat([x, self.relu(self.conv(x))], 1)


class RDB(nn.Module):
"""Residual Dense Block of Residual Dense Network
Args:
in_channels (int): Channel number of inputs.
channel_growth (int): Channels growth in each layer.
num_layers (int): Layer number in the Residual Dense Block.
"""

def __init__(self, in_channels, channel_growth, num_layers):
super().__init__()
self.layers = nn.Sequential(*[
DenseLayer(in_channels + channel_growth * i, channel_growth)
for i in range(num_layers)
])

# local feature fusion
self.lff = nn.Conv2d(
in_channels + channel_growth * num_layers,
channel_growth,
kernel_size=1)

def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
return x + self.lff(self.layers(x)) # local residual learning


@BACKBONES.register_module()
class RDN(nn.Module):
"""RDN model for single image super-resolution.
Paper: Residual Dense Network for Image Super-Resolution
Adapted from 'https://github.com/yjn870/RDN-pytorch.git'
'RDN-pytorch/blob/master/models.py'
Copyright (c) 2021, JaeYun Yeo, under MIT License.
Args:
in_channels (int): Channel number of inputs.
out_channels (int): Channel number of outputs.
mid_channels (int): Channel number of intermediate features.
Default: 64.
num_blocks (int): Block number in the trunk network. Default: 16.
upscale_factor (int): Upsampling factor. Support 2^n and 3.
Default: 4.
num_layer (int): Layer number in the Residual Dense Block.
Default: 8.
channel_growth(int): Channels growth in each layer of RDB.
Default: 64.
"""

def __init__(self,
in_channels,
out_channels,
mid_channels=64,
num_blocks=16,
upscale_factor=4,
num_layers=8,
channel_growth=64):

super().__init__()
self.mid_channels = mid_channels
self.channel_growth = channel_growth
self.num_blocks = num_blocks
self.num_layers = num_layers

# shallow feature extraction
self.sfe1 = nn.Conv2d(
in_channels, mid_channels, kernel_size=3, padding=3 // 2)
self.sfe2 = nn.Conv2d(
mid_channels, mid_channels, kernel_size=3, padding=3 // 2)

# residual dense blocks
self.rdbs = nn.ModuleList(
[RDB(self.mid_channels, self.channel_growth, self.num_layers)])
for _ in range(self.num_blocks - 1):
self.rdbs.append(
RDB(self.channel_growth, self.channel_growth, self.num_layers))

# global feature fusion
self.gff = nn.Sequential(
nn.Conv2d(
self.channel_growth * self.num_blocks,
self.mid_channels,
kernel_size=1),
nn.Conv2d(
self.mid_channels,
self.mid_channels,
kernel_size=3,
padding=3 // 2))

# up-sampling
assert 2 <= upscale_factor <= 4
if upscale_factor == 2 or upscale_factor == 4:
self.upscale = []
for _ in range(upscale_factor // 2):
self.upscale.extend([
nn.Conv2d(
self.mid_channels,
self.mid_channels * (2**2),
kernel_size=3,
padding=3 // 2),
nn.PixelShuffle(2)
])
self.upscale = nn.Sequential(*self.upscale)
else:
self.upscale = nn.Sequential(
nn.Conv2d(
self.mid_channels,
self.mid_channels * (upscale_factor**2),
kernel_size=3,
padding=3 // 2), nn.PixelShuffle(upscale_factor))

self.output = nn.Conv2d(
self.mid_channels, out_channels, kernel_size=3, padding=3 // 2)

def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""

sfe1 = self.sfe1(x)
sfe2 = self.sfe2(sfe1)

x = sfe2
local_features = []
for i in range(self.num_blocks):
x = self.rdbs[i](x)
local_features.append(x)

x = self.gff(torch.cat(local_features, 1)) + sfe1
# global residual learning
x = self.upscale(x)
x = self.output(x)
return x

def init_weights(self, pretrained=None, strict=True):
"""Init weights for models.
Args:
pretrained (str, optional): Path for pretrained weights. If given
None, pretrained weights will not be loaded. Defaults to None.
strict (boo, optional): Whether strictly load the pretrained model.
Defaults to True.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=strict, logger=logger)
elif pretrained is None:
pass # use default initialization
else:
raise TypeError('"pretrained" must be a str or None. '
f'But received {type(pretrained)}.')
56 changes: 56 additions & 0 deletions tests/test_rdn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn

from mmedit.models import build_backbone


def test_rdn():

scale = 4

model_cfg = dict(
type='RDN',
in_channels=3,
out_channels=3,
mid_channels=64,
num_blocks=16,
upscale_factor=scale)

# build model
model = build_backbone(model_cfg)

# test attributes
assert model.__class__.__name__ == 'RDN'

# prepare data
inputs = torch.rand(1, 3, 32, 16)
targets = torch.rand(1, 3, 128, 64)

# prepare loss
loss_function = nn.L1Loss()

# prepare optimizer
optimizer = torch.optim.Adam(model.parameters())

# test on cpu
output = model(inputs)
optimizer.zero_grad()
loss = loss_function(output, targets)
loss.backward()
optimizer.step()
assert torch.is_tensor(output)
assert output.shape == targets.shape

# test on gpu
if torch.cuda.is_available():
model = model.cuda()
optimizer = torch.optim.Adam(model.parameters())
inputs = inputs.cuda()
targets = targets.cuda()
output = model(inputs)
optimizer.zero_grad()
loss = loss_function(output, targets)
loss.backward()
optimizer.step()
assert torch.is_tensor(output)
assert output.shape == targets.shape

0 comments on commit 352ed6a

Please sign in to comment.