diff --git a/models/networks.py b/models/networks.py index 75b17cfa36d..c1ded38cf5c 100644 --- a/models/networks.py +++ b/models/networks.py @@ -76,9 +76,13 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in norm_layer = get_norm_layer(norm_type=norm) if netG == 'resnet_9blocks': - net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9) + net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9,use_deconvolution=True) + elif netG == 'resnet_9blocks+': + net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9,use_deconvolution=False) elif netG == 'resnet_6blocks': - net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6) + net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6,use_deconvolution=True) + elif netG == 'resnet_6blocks+': + net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6,use_deconvolution=False) elif netG == 'unet_128': net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout) elif netG == 'unet_256': @@ -140,7 +144,7 @@ def __call__(self, input, target_is_real): # Code and idea originally from Justin Johnson's architecture. # https://github.com/jcjohnson/fast-neural-style/ class ResnetGenerator(nn.Module): - def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): + def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect',use_deconvolution=False): assert(n_blocks >= 0) super(ResnetGenerator, self).__init__() self.input_nc = input_nc @@ -171,12 +175,18 @@ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_d for i in range(n_downsampling): mult = 2**(n_downsampling - i) - model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), - kernel_size=3, stride=2, - padding=1, output_padding=1, - bias=use_bias), - norm_layer(int(ngf * mult / 2)), - nn.ReLU(True)] + if use_deconvolution: + model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), + kernel_size=3, stride=2, + padding=1, output_padding=1, + bias=use_bias)] + else: + model += [nn.Upsample(scale_factor = 2, mode='bilinear',align_corners=True), + nn.ReflectionPad2d(1), + nn.Conv2d(ngf * mult, int(ngf * mult / 2),kernel_size=3, stride=1, padding=0)] + + model += [norm_layer(int(ngf * mult / 2)),nn.ReLU(True)] + model += [nn.ReflectionPad2d(3)] model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] model += [nn.Tanh()]