|
@@ -0,0 +1,845 @@
|
|
|
+import torch
|
|
|
+import torch.nn as nn
|
|
|
+import functools
|
|
|
+from torch.autograd import Variable
|
|
|
+import numpy as np
|
|
|
+from torchvision import transforms
|
|
|
+import torch.nn.functional as F
|
|
|
+
|
|
|
+###############################################################################
|
|
|
+# Functions
|
|
|
+###############################################################################
|
|
|
+def weights_init(m):
|
|
|
+ classname = m.__class__.__name__
|
|
|
+ if classname.find('Conv') != -1:
|
|
|
+ m.weight.data.normal_(0.0, 0.02)
|
|
|
+ elif classname.find('BatchNorm2d') != -1:
|
|
|
+ m.weight.data.normal_(1.0, 0.02)
|
|
|
+ m.bias.data.fill_(0)
|
|
|
+
|
|
|
+def get_norm_layer(norm_type='instance'):
|
|
|
+ if norm_type == 'batch':
|
|
|
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
|
|
|
+ elif norm_type == 'instance':
|
|
|
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
|
|
|
+ else:
|
|
|
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
|
|
+ return norm_layer
|
|
|
+
|
|
|
+def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
|
|
|
+ n_blocks_local=3, norm='instance', gpu_ids=[]):
|
|
|
+ norm_layer = get_norm_layer(norm_type=norm)
|
|
|
+ if netG == 'global':
|
|
|
+ netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
|
|
|
+ elif netG == 'local':
|
|
|
+ netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
|
|
|
+ n_local_enhancers, n_blocks_local, norm_layer)
|
|
|
+ elif netG == 'encoder':
|
|
|
+ netG = Encoder(input_nc, output_nc, ngf, n_downsample_global, norm_layer)
|
|
|
+ else:
|
|
|
+ raise('generator not implemented!')
|
|
|
+ print(netG)
|
|
|
+ if len(gpu_ids) > 0:
|
|
|
+ assert(torch.cuda.is_available())
|
|
|
+ netG.cuda(gpu_ids[0])
|
|
|
+ netG.apply(weights_init)
|
|
|
+ return netG
|
|
|
+
|
|
|
+def define_G_Adain(input_nc, output_nc, latent_size, ngf, netG, n_downsample_global=2, n_blocks_global=4, norm='instance', gpu_ids=[]):
|
|
|
+ norm_layer = get_norm_layer(norm_type=norm)
|
|
|
+ netG = Generator_Adain(input_nc, output_nc, latent_size, ngf, n_downsample_global, n_blocks_global, norm_layer)
|
|
|
+ print(netG)
|
|
|
+ if len(gpu_ids) > 0:
|
|
|
+ assert(torch.cuda.is_available())
|
|
|
+ netG.cuda(gpu_ids[0])
|
|
|
+ netG.apply(weights_init)
|
|
|
+ return netG
|
|
|
+
|
|
|
+def define_G_Adain_Mask(input_nc, output_nc, latent_size, ngf, netG, n_downsample_global=2, n_blocks_global=4, norm='instance', gpu_ids=[]):
|
|
|
+ norm_layer = get_norm_layer(norm_type=norm)
|
|
|
+ netG = Generator_Adain_Mask(input_nc, output_nc, latent_size, ngf, n_downsample_global, n_blocks_global, norm_layer)
|
|
|
+ print(netG)
|
|
|
+ if len(gpu_ids) > 0:
|
|
|
+ assert(torch.cuda.is_available())
|
|
|
+ netG.cuda(gpu_ids[0])
|
|
|
+ netG.apply(weights_init)
|
|
|
+ return netG
|
|
|
+
|
|
|
+def define_G_Adain_Upsample(input_nc, output_nc, latent_size, ngf, netG, n_downsample_global=2, n_blocks_global=4, norm='instance', gpu_ids=[]):
|
|
|
+ norm_layer = get_norm_layer(norm_type=norm)
|
|
|
+ netG = Generator_Adain_Upsample(input_nc, output_nc, latent_size, ngf, n_downsample_global, n_blocks_global, norm_layer)
|
|
|
+ print(netG)
|
|
|
+ if len(gpu_ids) > 0:
|
|
|
+ assert(torch.cuda.is_available())
|
|
|
+ netG.cuda(gpu_ids[0])
|
|
|
+ netG.apply(weights_init)
|
|
|
+ return netG
|
|
|
+
|
|
|
+def define_G_Adain_2(input_nc, output_nc, latent_size, ngf, netG, n_downsample_global=2, n_blocks_global=4, norm='instance', gpu_ids=[]):
|
|
|
+ norm_layer = get_norm_layer(norm_type=norm)
|
|
|
+ netG = Generator_Adain_2(input_nc, output_nc, latent_size, ngf, n_downsample_global, n_blocks_global, norm_layer)
|
|
|
+ print(netG)
|
|
|
+ if len(gpu_ids) > 0:
|
|
|
+ assert(torch.cuda.is_available())
|
|
|
+ netG.cuda(gpu_ids[0])
|
|
|
+ netG.apply(weights_init)
|
|
|
+ return netG
|
|
|
+
|
|
|
+def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
|
|
|
+ norm_layer = get_norm_layer(norm_type=norm)
|
|
|
+ netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)
|
|
|
+ print(netD)
|
|
|
+ if len(gpu_ids) > 0:
|
|
|
+ assert(torch.cuda.is_available())
|
|
|
+ netD.cuda(gpu_ids[0])
|
|
|
+ netD.apply(weights_init)
|
|
|
+ return netD
|
|
|
+
|
|
|
+def print_network(net):
|
|
|
+ if isinstance(net, list):
|
|
|
+ net = net[0]
|
|
|
+ num_params = 0
|
|
|
+ for param in net.parameters():
|
|
|
+ num_params += param.numel()
|
|
|
+ print(net)
|
|
|
+ print('Total number of parameters: %d' % num_params)
|
|
|
+
|
|
|
+##############################################################################
|
|
|
+# Losses
|
|
|
+##############################################################################
|
|
|
+class GANLoss(nn.Module):
|
|
|
+ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
|
|
|
+ tensor=torch.FloatTensor, opt=None):
|
|
|
+ super(GANLoss, self).__init__()
|
|
|
+ self.real_label = target_real_label
|
|
|
+ self.fake_label = target_fake_label
|
|
|
+ self.real_label_tensor = None
|
|
|
+ self.fake_label_tensor = None
|
|
|
+ self.zero_tensor = None
|
|
|
+ self.Tensor = tensor
|
|
|
+ self.gan_mode = gan_mode
|
|
|
+ self.opt = opt
|
|
|
+ if gan_mode == 'ls':
|
|
|
+ pass
|
|
|
+ elif gan_mode == 'original':
|
|
|
+ pass
|
|
|
+ elif gan_mode == 'w':
|
|
|
+ pass
|
|
|
+ elif gan_mode == 'hinge':
|
|
|
+ pass
|
|
|
+ else:
|
|
|
+ raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
|
|
|
+
|
|
|
+ def get_target_tensor(self, input, target_is_real):
|
|
|
+ if target_is_real:
|
|
|
+ if self.real_label_tensor is None:
|
|
|
+ self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
|
|
|
+ self.real_label_tensor.requires_grad_(False)
|
|
|
+ return self.real_label_tensor.expand_as(input)
|
|
|
+ else:
|
|
|
+ if self.fake_label_tensor is None:
|
|
|
+ self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
|
|
|
+ self.fake_label_tensor.requires_grad_(False)
|
|
|
+ return self.fake_label_tensor.expand_as(input)
|
|
|
+
|
|
|
+ def get_zero_tensor(self, input):
|
|
|
+ if self.zero_tensor is None:
|
|
|
+ self.zero_tensor = self.Tensor(1).fill_(0)
|
|
|
+ self.zero_tensor.requires_grad_(False)
|
|
|
+ return self.zero_tensor.expand_as(input)
|
|
|
+
|
|
|
+ def loss(self, input, target_is_real, for_discriminator=True):
|
|
|
+ if self.gan_mode == 'original': # cross entropy loss
|
|
|
+ target_tensor = self.get_target_tensor(input, target_is_real)
|
|
|
+ loss = F.binary_cross_entropy_with_logits(input, target_tensor)
|
|
|
+ return loss
|
|
|
+ elif self.gan_mode == 'ls':
|
|
|
+ target_tensor = self.get_target_tensor(input, target_is_real)
|
|
|
+ return F.mse_loss(input, target_tensor)
|
|
|
+ elif self.gan_mode == 'hinge':
|
|
|
+ if for_discriminator:
|
|
|
+ if target_is_real:
|
|
|
+ minval = torch.min(input - 1, self.get_zero_tensor(input))
|
|
|
+ loss = -torch.mean(minval)
|
|
|
+ else:
|
|
|
+ minval = torch.min(-input - 1, self.get_zero_tensor(input))
|
|
|
+ loss = -torch.mean(minval)
|
|
|
+ else:
|
|
|
+ assert target_is_real, "The generator's hinge loss must be aiming for real"
|
|
|
+ loss = -torch.mean(input)
|
|
|
+ return loss
|
|
|
+ else:
|
|
|
+ # wgan
|
|
|
+ if target_is_real:
|
|
|
+ return -input.mean()
|
|
|
+ else:
|
|
|
+ return input.mean()
|
|
|
+
|
|
|
+ def __call__(self, input, target_is_real, for_discriminator=True):
|
|
|
+ # computing loss is a bit complicated because |input| may not be
|
|
|
+ # a tensor, but list of tensors in case of multiscale discriminator
|
|
|
+ if isinstance(input, list):
|
|
|
+ loss = 0
|
|
|
+ for pred_i in input:
|
|
|
+ if isinstance(pred_i, list):
|
|
|
+ pred_i = pred_i[-1]
|
|
|
+ loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
|
|
|
+ bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
|
|
|
+ new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
|
|
|
+ loss += new_loss
|
|
|
+ return loss / len(input)
|
|
|
+ else:
|
|
|
+ return self.loss(input, target_is_real, for_discriminator)
|
|
|
+
|
|
|
+class VGGLoss(nn.Module):
|
|
|
+ def __init__(self, gpu_ids):
|
|
|
+ super(VGGLoss, self).__init__()
|
|
|
+ self.vgg = Vgg19().cuda()
|
|
|
+ self.criterion = nn.L1Loss()
|
|
|
+ self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
|
|
|
+
|
|
|
+ def forward(self, x, y):
|
|
|
+ x_vgg, y_vgg = self.vgg(x), self.vgg(y)
|
|
|
+ loss = 0
|
|
|
+ for i in range(len(x_vgg)):
|
|
|
+ loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
|
|
|
+ return loss
|
|
|
+
|
|
|
+##############################################################################
|
|
|
+# Generator
|
|
|
+##############################################################################
|
|
|
+class LocalEnhancer(nn.Module):
|
|
|
+ def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3, n_blocks_global=9,
|
|
|
+ n_local_enhancers=1, n_blocks_local=3, norm_layer=nn.BatchNorm2d, padding_type='reflect'):
|
|
|
+ super(LocalEnhancer, self).__init__()
|
|
|
+ self.n_local_enhancers = n_local_enhancers
|
|
|
+
|
|
|
+ ###### global generator model #####
|
|
|
+ ngf_global = ngf * (2**n_local_enhancers)
|
|
|
+ model_global = GlobalGenerator(input_nc, output_nc, ngf_global, n_downsample_global, n_blocks_global, norm_layer).model
|
|
|
+ model_global = [model_global[i] for i in range(len(model_global)-3)] # get rid of final convolution layers
|
|
|
+ self.model = nn.Sequential(*model_global)
|
|
|
+
|
|
|
+ ###### local enhancer layers #####
|
|
|
+ for n in range(1, n_local_enhancers+1):
|
|
|
+ ### downsample
|
|
|
+ ngf_global = ngf * (2**(n_local_enhancers-n))
|
|
|
+ model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
|
|
|
+ norm_layer(ngf_global), nn.ReLU(True),
|
|
|
+ nn.Conv2d(ngf_global, ngf_global * 2, kernel_size=3, stride=2, padding=1),
|
|
|
+ norm_layer(ngf_global * 2), nn.ReLU(True)]
|
|
|
+ ### residual blocks
|
|
|
+ model_upsample = []
|
|
|
+ for i in range(n_blocks_local):
|
|
|
+ model_upsample += [ResnetBlock(ngf_global * 2, padding_type=padding_type, norm_layer=norm_layer)]
|
|
|
+
|
|
|
+ ### upsample
|
|
|
+ model_upsample += [nn.ConvTranspose2d(ngf_global * 2, ngf_global, kernel_size=3, stride=2, padding=1, output_padding=1),
|
|
|
+ norm_layer(ngf_global), nn.ReLU(True)]
|
|
|
+
|
|
|
+ ### final convolution
|
|
|
+ if n == n_local_enhancers:
|
|
|
+ model_upsample += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
|
|
|
+
|
|
|
+ setattr(self, 'model'+str(n)+'_1', nn.Sequential(*model_downsample))
|
|
|
+ setattr(self, 'model'+str(n)+'_2', nn.Sequential(*model_upsample))
|
|
|
+
|
|
|
+ self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
|
|
|
+
|
|
|
+ def forward(self, input):
|
|
|
+ ### create input pyramid
|
|
|
+ input_downsampled = [input]
|
|
|
+ for i in range(self.n_local_enhancers):
|
|
|
+ input_downsampled.append(self.downsample(input_downsampled[-1]))
|
|
|
+
|
|
|
+ ### output at coarest level
|
|
|
+ output_prev = self.model(input_downsampled[-1])
|
|
|
+ ### build up one layer at a time
|
|
|
+ for n_local_enhancers in range(1, self.n_local_enhancers+1):
|
|
|
+ model_downsample = getattr(self, 'model'+str(n_local_enhancers)+'_1')
|
|
|
+ model_upsample = getattr(self, 'model'+str(n_local_enhancers)+'_2')
|
|
|
+ input_i = input_downsampled[self.n_local_enhancers-n_local_enhancers]
|
|
|
+ output_prev = model_upsample(model_downsample(input_i) + output_prev)
|
|
|
+ return output_prev
|
|
|
+
|
|
|
+class GlobalGenerator(nn.Module):
|
|
|
+ def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
|
|
|
+ padding_type='reflect'):
|
|
|
+ assert(n_blocks >= 0)
|
|
|
+ super(GlobalGenerator, self).__init__()
|
|
|
+ activation = nn.ReLU(True)
|
|
|
+
|
|
|
+ model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
|
|
|
+ ### downsample
|
|
|
+ for i in range(n_downsampling):
|
|
|
+ mult = 2**i
|
|
|
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
|
|
|
+ norm_layer(ngf * mult * 2), activation]
|
|
|
+
|
|
|
+ ### resnet blocks
|
|
|
+ mult = 2**n_downsampling
|
|
|
+ for i in range(n_blocks):
|
|
|
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
|
|
|
+
|
|
|
+ ### upsample
|
|
|
+ for i in range(n_downsampling):
|
|
|
+ mult = 2**(n_downsampling - i)
|
|
|
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
|
|
|
+ norm_layer(int(ngf * mult / 2)), activation]
|
|
|
+ model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
|
|
|
+ self.model = nn.Sequential(*model)
|
|
|
+
|
|
|
+ def forward(self, input):
|
|
|
+ return self.model(input)
|
|
|
+
|
|
|
+# Define a resnet block
|
|
|
+class ResnetBlock(nn.Module):
|
|
|
+ def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
|
|
|
+ super(ResnetBlock, self).__init__()
|
|
|
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)
|
|
|
+
|
|
|
+ def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
|
|
|
+ conv_block = []
|
|
|
+ p = 0
|
|
|
+ if padding_type == 'reflect':
|
|
|
+ conv_block += [nn.ReflectionPad2d(1)]
|
|
|
+ elif padding_type == 'replicate':
|
|
|
+ conv_block += [nn.ReplicationPad2d(1)]
|
|
|
+ elif padding_type == 'zero':
|
|
|
+ p = 1
|
|
|
+ else:
|
|
|
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
|
|
+
|
|
|
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
|
|
|
+ norm_layer(dim),
|
|
|
+ activation]
|
|
|
+ if use_dropout:
|
|
|
+ conv_block += [nn.Dropout(0.5)]
|
|
|
+
|
|
|
+ p = 0
|
|
|
+ if padding_type == 'reflect':
|
|
|
+ conv_block += [nn.ReflectionPad2d(1)]
|
|
|
+ elif padding_type == 'replicate':
|
|
|
+ conv_block += [nn.ReplicationPad2d(1)]
|
|
|
+ elif padding_type == 'zero':
|
|
|
+ p = 1
|
|
|
+ else:
|
|
|
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
|
|
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
|
|
|
+ norm_layer(dim)]
|
|
|
+
|
|
|
+ return nn.Sequential(*conv_block)
|
|
|
+
|
|
|
+ def forward(self, x):
|
|
|
+ out = x + self.conv_block(x)
|
|
|
+ return out
|
|
|
+
|
|
|
+class InstanceNorm(nn.Module):
|
|
|
+ def __init__(self, epsilon=1e-8):
|
|
|
+ """
|
|
|
+ @notice: avoid in-place ops.
|
|
|
+ https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
|
|
|
+ """
|
|
|
+ super(InstanceNorm, self).__init__()
|
|
|
+ self.epsilon = epsilon
|
|
|
+
|
|
|
+ def forward(self, x):
|
|
|
+ x = x - torch.mean(x, (2, 3), True)
|
|
|
+ tmp = torch.mul(x, x) # or x ** 2
|
|
|
+ tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
|
|
|
+ return x * tmp
|
|
|
+
|
|
|
+class SpecificNorm(nn.Module):
|
|
|
+ def __init__(self, epsilon=1e-8):
|
|
|
+ """
|
|
|
+ @notice: avoid in-place ops.
|
|
|
+ https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
|
|
|
+ """
|
|
|
+ super(SpecificNorm, self).__init__()
|
|
|
+ self.mean = np.array([0.485, 0.456, 0.406])
|
|
|
+ self.mean = torch.from_numpy(self.mean).float().cuda()
|
|
|
+ self.mean = self.mean.view([1, 3, 1, 1])
|
|
|
+
|
|
|
+ self.std = np.array([0.229, 0.224, 0.225])
|
|
|
+ self.std = torch.from_numpy(self.std).float().cuda()
|
|
|
+ self.std = self.std.view([1, 3, 1, 1])
|
|
|
+
|
|
|
+ def forward(self, x):
|
|
|
+ mean = self.mean.expand([1, 3, x.shape[2], x.shape[3]])
|
|
|
+ std = self.std.expand([1, 3, x.shape[2], x.shape[3]])
|
|
|
+
|
|
|
+ x = (x - mean) / std
|
|
|
+
|
|
|
+ return x
|
|
|
+
|
|
|
+class ApplyStyle(nn.Module):
|
|
|
+ """
|
|
|
+ @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
|
|
|
+ """
|
|
|
+ def __init__(self, latent_size, channels):
|
|
|
+ super(ApplyStyle, self).__init__()
|
|
|
+ self.linear = nn.Linear(latent_size, channels * 2)
|
|
|
+
|
|
|
+ def forward(self, x, latent):
|
|
|
+ style = self.linear(latent) # style => [batch_size, n_channels*2]
|
|
|
+ shape = [-1, 2, x.size(1), 1, 1]
|
|
|
+ style = style.view(shape) # [batch_size, 2, n_channels, ...]
|
|
|
+ x = x * (style[:, 0] + 1.) + style[:, 1]
|
|
|
+ return x
|
|
|
+
|
|
|
+class ResnetBlock_Adain(nn.Module):
|
|
|
+ def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
|
|
|
+ super(ResnetBlock_Adain, self).__init__()
|
|
|
+
|
|
|
+ p = 0
|
|
|
+ conv1 = []
|
|
|
+ if padding_type == 'reflect':
|
|
|
+ conv1 += [nn.ReflectionPad2d(1)]
|
|
|
+ elif padding_type == 'replicate':
|
|
|
+ conv1 += [nn.ReplicationPad2d(1)]
|
|
|
+ elif padding_type == 'zero':
|
|
|
+ p = 1
|
|
|
+ else:
|
|
|
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
|
|
+ conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding = p), InstanceNorm()]
|
|
|
+ self.conv1 = nn.Sequential(*conv1)
|
|
|
+ self.style1 = ApplyStyle(latent_size, dim)
|
|
|
+ self.act1 = activation
|
|
|
+
|
|
|
+ p = 0
|
|
|
+ conv2 = []
|
|
|
+ if padding_type == 'reflect':
|
|
|
+ conv2 += [nn.ReflectionPad2d(1)]
|
|
|
+ elif padding_type == 'replicate':
|
|
|
+ conv2 += [nn.ReplicationPad2d(1)]
|
|
|
+ elif padding_type == 'zero':
|
|
|
+ p = 1
|
|
|
+ else:
|
|
|
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
|
|
+ conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
|
|
|
+ self.conv2 = nn.Sequential(*conv2)
|
|
|
+ self.style2 = ApplyStyle(latent_size, dim)
|
|
|
+
|
|
|
+
|
|
|
+ def forward(self, x, dlatents_in_slice):
|
|
|
+ y = self.conv1(x)
|
|
|
+ y = self.style1(y, dlatents_in_slice)
|
|
|
+ y = self.act1(y)
|
|
|
+ y = self.conv2(y)
|
|
|
+ y = self.style2(y, dlatents_in_slice)
|
|
|
+ out = x + y
|
|
|
+ return out
|
|
|
+
|
|
|
+class UpBlock_Adain(nn.Module):
|
|
|
+ def __init__(self, dim_in, dim_out, latent_size, padding_type, activation=nn.ReLU(True)):
|
|
|
+ super(UpBlock_Adain, self).__init__()
|
|
|
+
|
|
|
+ p = 0
|
|
|
+ conv1 = [nn.Upsample(scale_factor=2, mode='bilinear')]
|
|
|
+ if padding_type == 'reflect':
|
|
|
+ conv1 += [nn.ReflectionPad2d(1)]
|
|
|
+ elif padding_type == 'replicate':
|
|
|
+ conv1 += [nn.ReplicationPad2d(1)]
|
|
|
+ elif padding_type == 'zero':
|
|
|
+ p = 1
|
|
|
+ else:
|
|
|
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
|
|
+ conv1 += [nn.Conv2d(dim_in, dim_out, kernel_size=3, padding = p), InstanceNorm()]
|
|
|
+ self.conv1 = nn.Sequential(*conv1)
|
|
|
+ self.style1 = ApplyStyle(latent_size, dim_out)
|
|
|
+ self.act1 = activation
|
|
|
+
|
|
|
+
|
|
|
+ def forward(self, x, dlatents_in_slice):
|
|
|
+ y = self.conv1(x)
|
|
|
+ y = self.style1(y, dlatents_in_slice)
|
|
|
+ y = self.act1(y)
|
|
|
+ return y
|
|
|
+
|
|
|
+class Encoder(nn.Module):
|
|
|
+ def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
|
|
|
+ super(Encoder, self).__init__()
|
|
|
+ self.output_nc = output_nc
|
|
|
+
|
|
|
+ model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
|
|
|
+ norm_layer(ngf), nn.ReLU(True)]
|
|
|
+ ### downsample
|
|
|
+ for i in range(n_downsampling):
|
|
|
+ mult = 2**i
|
|
|
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
|
|
|
+ norm_layer(ngf * mult * 2), nn.ReLU(True)]
|
|
|
+
|
|
|
+ ### upsample
|
|
|
+ for i in range(n_downsampling):
|
|
|
+ mult = 2**(n_downsampling - i)
|
|
|
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
|
|
|
+ norm_layer(int(ngf * mult / 2)), nn.ReLU(True)]
|
|
|
+
|
|
|
+ model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
|
|
|
+ self.model = nn.Sequential(*model)
|
|
|
+
|
|
|
+ def forward(self, input, inst):
|
|
|
+ outputs = self.model(input)
|
|
|
+
|
|
|
+ # instance-wise average pooling
|
|
|
+ outputs_mean = outputs.clone()
|
|
|
+ inst_list = np.unique(inst.cpu().numpy().astype(int))
|
|
|
+ for i in inst_list:
|
|
|
+ for b in range(input.size()[0]):
|
|
|
+ indices = (inst[b:b+1] == int(i)).nonzero() # n x 4
|
|
|
+ for j in range(self.output_nc):
|
|
|
+ output_ins = outputs[indices[:,0] + b, indices[:,1] + j, indices[:,2], indices[:,3]]
|
|
|
+ mean_feat = torch.mean(output_ins).expand_as(output_ins)
|
|
|
+ outputs_mean[indices[:,0] + b, indices[:,1] + j, indices[:,2], indices[:,3]] = mean_feat
|
|
|
+ return outputs_mean
|
|
|
+
|
|
|
+
|
|
|
+class Generator_Adain(nn.Module):
|
|
|
+ def __init__(self, input_nc, output_nc, latent_size, ngf=64, n_downsampling=2, n_blocks=4, norm_layer=nn.BatchNorm2d,
|
|
|
+ padding_type='reflect'):
|
|
|
+ assert (n_blocks >= 0)
|
|
|
+ super(Generator_Adain, self).__init__()
|
|
|
+ activation = nn.ReLU(True)
|
|
|
+
|
|
|
+ Enc = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
|
|
|
+ ### downsample
|
|
|
+ for i in range(n_downsampling):
|
|
|
+ mult = 2 ** i
|
|
|
+ Enc += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
|
|
|
+ norm_layer(ngf * mult * 2), activation]
|
|
|
+ self.Encoder = nn.Sequential(*Enc)
|
|
|
+
|
|
|
+ ### resnet blocks
|
|
|
+ BN = []
|
|
|
+ mult = 2 ** n_downsampling
|
|
|
+ for i in range(n_blocks):
|
|
|
+ BN += [ResnetBlock_Adain(ngf*mult, latent_size=latent_size, padding_type=padding_type, activation=activation)]
|
|
|
+ self.BottleNeck = nn.Sequential(*BN)
|
|
|
+ '''self.ResBlockAdain1 = ResnetBlock_Adain(ngf * mult, latent_size=latent_size, padding_type=padding_type,
|
|
|
+ activation=activation)
|
|
|
+ self.ResBlockAdain2 = ResnetBlock_Adain(ngf * mult, latent_size=latent_size, padding_type=padding_type,
|
|
|
+ activation=activation)
|
|
|
+ self.ResBlockAdain3 = ResnetBlock_Adain(ngf * mult, latent_size=latent_size, padding_type=padding_type,
|
|
|
+ activation=activation)
|
|
|
+ self.ResBlockAdain4 = ResnetBlock_Adain(ngf * mult, latent_size=latent_size, padding_type=padding_type,
|
|
|
+ activation=activation)'''
|
|
|
+
|
|
|
+ ### upsample
|
|
|
+ Dec = []
|
|
|
+ for i in range(n_downsampling):
|
|
|
+ mult = 2 ** (n_downsampling - i)
|
|
|
+ Dec += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
|
|
|
+ output_padding=1),
|
|
|
+ norm_layer(int(ngf * mult / 2)), activation]
|
|
|
+ Dec += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
|
|
|
+
|
|
|
+ self.Decoder = nn.Sequential(*Dec)
|
|
|
+ #self.model = nn.Sequential(*model)
|
|
|
+ self.spNorm = SpecificNorm()
|
|
|
+
|
|
|
+ def forward(self, input, dlatents):
|
|
|
+ x = input
|
|
|
+ x = self.Encoder(x)
|
|
|
+
|
|
|
+
|
|
|
+ for i in range(len(self.BottleNeck)):
|
|
|
+ x = self.BottleNeck[i](x, dlatents)
|
|
|
+ '''x = self.ResBlockAdain1(x, dlatents)
|
|
|
+ x = self.ResBlockAdain2(x, dlatents)
|
|
|
+ x = self.ResBlockAdain3(x, dlatents)
|
|
|
+ x = self.ResBlockAdain4(x, dlatents)'''
|
|
|
+
|
|
|
+ x = self.Decoder(x)
|
|
|
+
|
|
|
+ x = (x + 1) / 2
|
|
|
+ x = self.spNorm(x)
|
|
|
+
|
|
|
+ return x
|
|
|
+
|
|
|
+class Generator_Adain_Mask(nn.Module):
|
|
|
+ def __init__(self, input_nc, output_nc, latent_size, ngf=64, n_downsampling=2, n_blocks=4, norm_layer=nn.BatchNorm2d,
|
|
|
+ padding_type='reflect'):
|
|
|
+ assert (n_blocks >= 0)
|
|
|
+ super(Generator_Adain_Mask, self).__init__()
|
|
|
+ activation = nn.ReLU(True)
|
|
|
+
|
|
|
+ Enc = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
|
|
|
+ ### downsample
|
|
|
+ for i in range(n_downsampling):
|
|
|
+ mult = 2 ** i
|
|
|
+ Enc += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
|
|
|
+ norm_layer(ngf * mult * 2), activation]
|
|
|
+ self.Encoder = nn.Sequential(*Enc)
|
|
|
+
|
|
|
+ ### resnet blocks
|
|
|
+ BN = []
|
|
|
+ mult = 2 ** n_downsampling
|
|
|
+ for i in range(n_blocks):
|
|
|
+ BN += [ResnetBlock_Adain(ngf*mult, latent_size=latent_size, padding_type=padding_type, activation=activation)]
|
|
|
+ self.BottleNeck = nn.Sequential(*BN)
|
|
|
+
|
|
|
+ ### upsample
|
|
|
+ Dec = []
|
|
|
+ for i in range(n_downsampling):
|
|
|
+ mult = 2 ** (n_downsampling - i)
|
|
|
+ Dec += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
|
|
|
+ output_padding=1),
|
|
|
+ norm_layer(int(ngf * mult / 2)), activation]
|
|
|
+ Fake_out = [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
|
|
|
+ Mast_out = [nn.ReflectionPad2d(3), nn.Conv2d(ngf, 1, kernel_size=7, padding=0), nn.Sigmoid()]
|
|
|
+
|
|
|
+ self.Decoder = nn.Sequential(*Dec)
|
|
|
+ #self.model = nn.Sequential(*model)
|
|
|
+ self.spNorm = SpecificNorm()
|
|
|
+ self.Fake_out = nn.Sequential(*Fake_out)
|
|
|
+ self.Mask_out = nn.Sequential(*Mast_out)
|
|
|
+
|
|
|
+ def forward(self, input, dlatents):
|
|
|
+ x = input
|
|
|
+ x = self.Encoder(x)
|
|
|
+
|
|
|
+
|
|
|
+ for i in range(len(self.BottleNeck)):
|
|
|
+ x = self.BottleNeck[i](x, dlatents)
|
|
|
+
|
|
|
+ x = self.Decoder(x)
|
|
|
+
|
|
|
+ fake_out = self.Fake_out(x)
|
|
|
+ mask_out = self.Mask_out(x)
|
|
|
+
|
|
|
+ fake_out = (fake_out + 1) / 2
|
|
|
+ fake_out = self.spNorm(fake_out)
|
|
|
+
|
|
|
+ generated = fake_out * mask_out + input * (1-mask_out)
|
|
|
+
|
|
|
+ return generated, mask_out
|
|
|
+
|
|
|
+class Generator_Adain_Upsample(nn.Module):
|
|
|
+ def __init__(self, input_nc, output_nc, latent_size, ngf=64, n_downsampling=2, n_blocks=4, norm_layer=nn.BatchNorm2d,
|
|
|
+ padding_type='reflect'):
|
|
|
+ assert (n_blocks >= 0)
|
|
|
+ super(Generator_Adain_Upsample, self).__init__()
|
|
|
+ activation = nn.ReLU(True)
|
|
|
+
|
|
|
+ Enc = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
|
|
|
+ ### downsample
|
|
|
+ for i in range(n_downsampling):
|
|
|
+ mult = 2 ** i
|
|
|
+ Enc += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
|
|
|
+ norm_layer(ngf * mult * 2), activation]
|
|
|
+ self.Encoder = nn.Sequential(*Enc)
|
|
|
+
|
|
|
+ ### resnet blocks
|
|
|
+ BN = []
|
|
|
+ mult = 2 ** n_downsampling
|
|
|
+ for i in range(n_blocks):
|
|
|
+ BN += [ResnetBlock_Adain(ngf*mult, latent_size=latent_size, padding_type=padding_type, activation=activation)]
|
|
|
+ self.BottleNeck = nn.Sequential(*BN)
|
|
|
+
|
|
|
+ ### upsample
|
|
|
+ Dec = []
|
|
|
+ for i in range(n_downsampling):
|
|
|
+ mult = 2 ** (n_downsampling - i)
|
|
|
+ '''Dec += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
|
|
|
+ output_padding=1),
|
|
|
+ norm_layer(int(ngf * mult / 2)), activation]'''
|
|
|
+ Dec += [nn.Upsample(scale_factor=2, mode='bilinear'),
|
|
|
+ nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=1),
|
|
|
+ norm_layer(int(ngf * mult / 2)), activation]
|
|
|
+ Dec += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
|
|
|
+
|
|
|
+ self.Decoder = nn.Sequential(*Dec)
|
|
|
+ self.spNorm = SpecificNorm()
|
|
|
+
|
|
|
+ def forward(self, input, dlatents):
|
|
|
+ x = input
|
|
|
+ x = self.Encoder(x)
|
|
|
+
|
|
|
+
|
|
|
+ for i in range(len(self.BottleNeck)):
|
|
|
+ x = self.BottleNeck[i](x, dlatents)
|
|
|
+
|
|
|
+ x = self.Decoder(x)
|
|
|
+
|
|
|
+ x = (x + 1) / 2
|
|
|
+ x = self.spNorm(x)
|
|
|
+
|
|
|
+ return x
|
|
|
+
|
|
|
+class Generator_Adain_2(nn.Module):
|
|
|
+ def __init__(self, input_nc, output_nc, latent_size, ngf=64, n_downsampling=2, n_blocks=4, norm_layer=nn.BatchNorm2d,
|
|
|
+ padding_type='reflect'):
|
|
|
+ assert (n_blocks >= 0)
|
|
|
+ super(Generator_Adain_2, self).__init__()
|
|
|
+ activation = nn.ReLU(True)
|
|
|
+
|
|
|
+ Enc = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
|
|
|
+ ### downsample
|
|
|
+ for i in range(n_downsampling):
|
|
|
+ mult = 2 ** i
|
|
|
+ Enc += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
|
|
|
+ norm_layer(ngf * mult * 2), activation]
|
|
|
+ self.Encoder = nn.Sequential(*Enc)
|
|
|
+
|
|
|
+ ### resnet blocks
|
|
|
+ BN = []
|
|
|
+ mult = 2 ** n_downsampling
|
|
|
+ for i in range(n_blocks):
|
|
|
+ BN += [ResnetBlock_Adain(ngf*mult, latent_size=latent_size, padding_type=padding_type, activation=activation)]
|
|
|
+ self.BottleNeck = nn.Sequential(*BN)
|
|
|
+
|
|
|
+ ### upsample
|
|
|
+ Dec = []
|
|
|
+ for i in range(n_downsampling):
|
|
|
+ mult = 2 ** (n_downsampling - i)
|
|
|
+ Dec += [UpBlock_Adain(dim_in=ngf * mult, dim_out=int(ngf * mult / 2), latent_size=latent_size, padding_type=padding_type)]
|
|
|
+ layer_out = [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
|
|
|
+
|
|
|
+ self.Decoder = nn.Sequential(*Dec)
|
|
|
+ #self.model = nn.Sequential(*model)
|
|
|
+ self.spNorm = SpecificNorm()
|
|
|
+ self.layer_out = nn.Sequential(*layer_out)
|
|
|
+
|
|
|
+ def forward(self, input, dlatents):
|
|
|
+ x = input
|
|
|
+ x = self.Encoder(x)
|
|
|
+
|
|
|
+
|
|
|
+ for i in range(len(self.BottleNeck)):
|
|
|
+ x = self.BottleNeck[i](x, dlatents)
|
|
|
+
|
|
|
+ for i in range(len(self.Decoder)):
|
|
|
+ x = self.Decoder[i](x, dlatents)
|
|
|
+
|
|
|
+ x = self.layer_out(x)
|
|
|
+
|
|
|
+ x = (x + 1) / 2
|
|
|
+ x = self.spNorm(x)
|
|
|
+
|
|
|
+ return x
|
|
|
+
|
|
|
+class MultiscaleDiscriminator(nn.Module):
|
|
|
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
|
|
|
+ use_sigmoid=False, num_D=3, getIntermFeat=False):
|
|
|
+ super(MultiscaleDiscriminator, self).__init__()
|
|
|
+ self.num_D = num_D
|
|
|
+ self.n_layers = n_layers
|
|
|
+ self.getIntermFeat = getIntermFeat
|
|
|
+
|
|
|
+ for i in range(num_D):
|
|
|
+ netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
|
|
|
+ if getIntermFeat:
|
|
|
+ for j in range(n_layers+2):
|
|
|
+ setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))
|
|
|
+ else:
|
|
|
+ setattr(self, 'layer'+str(i), netD.model)
|
|
|
+
|
|
|
+ self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
|
|
|
+
|
|
|
+ def singleD_forward(self, model, input):
|
|
|
+ if self.getIntermFeat:
|
|
|
+ result = [input]
|
|
|
+ for i in range(len(model)):
|
|
|
+ result.append(model[i](result[-1]))
|
|
|
+ return result[1:]
|
|
|
+ else:
|
|
|
+ return [model(input)]
|
|
|
+
|
|
|
+ def forward(self, input):
|
|
|
+ num_D = self.num_D
|
|
|
+ result = []
|
|
|
+ input_downsampled = input
|
|
|
+ for i in range(num_D):
|
|
|
+ if self.getIntermFeat:
|
|
|
+ model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
|
|
|
+ else:
|
|
|
+ model = getattr(self, 'layer'+str(num_D-1-i))
|
|
|
+ result.append(self.singleD_forward(model, input_downsampled))
|
|
|
+ if i != (num_D-1):
|
|
|
+ input_downsampled = self.downsample(input_downsampled)
|
|
|
+ return result
|
|
|
+
|
|
|
+# Defines the PatchGAN discriminator with the specified arguments.
|
|
|
+class NLayerDiscriminator(nn.Module):
|
|
|
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
|
|
|
+ super(NLayerDiscriminator, self).__init__()
|
|
|
+ self.getIntermFeat = getIntermFeat
|
|
|
+ self.n_layers = n_layers
|
|
|
+
|
|
|
+ kw = 4
|
|
|
+ padw = 1
|
|
|
+ sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
|
|
|
+
|
|
|
+ nf = ndf
|
|
|
+ for n in range(1, n_layers):
|
|
|
+ nf_prev = nf
|
|
|
+ nf = min(nf * 2, 512)
|
|
|
+ sequence += [[
|
|
|
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
|
|
|
+ norm_layer(nf), nn.LeakyReLU(0.2, True)
|
|
|
+ ]]
|
|
|
+
|
|
|
+ nf_prev = nf
|
|
|
+ nf = min(nf * 2, 512)
|
|
|
+ sequence += [[
|
|
|
+ nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
|
|
|
+ norm_layer(nf),
|
|
|
+ nn.LeakyReLU(0.2, True)
|
|
|
+ ]]
|
|
|
+
|
|
|
+ if use_sigmoid:
|
|
|
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw), nn.Sigmoid()]]
|
|
|
+ else:
|
|
|
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
|
|
|
+
|
|
|
+ if getIntermFeat:
|
|
|
+ for n in range(len(sequence)):
|
|
|
+ setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
|
|
|
+ else:
|
|
|
+ sequence_stream = []
|
|
|
+ for n in range(len(sequence)):
|
|
|
+ sequence_stream += sequence[n]
|
|
|
+ self.model = nn.Sequential(*sequence_stream)
|
|
|
+
|
|
|
+ def forward(self, input):
|
|
|
+ if self.getIntermFeat:
|
|
|
+ res = [input]
|
|
|
+ for n in range(self.n_layers+2):
|
|
|
+ model = getattr(self, 'model'+str(n))
|
|
|
+ res.append(model(res[-1]))
|
|
|
+ return res[1:]
|
|
|
+ else:
|
|
|
+ return self.model(input)
|
|
|
+
|
|
|
+from torchvision import models
|
|
|
+class Vgg19(torch.nn.Module):
|
|
|
+ def __init__(self, requires_grad=False):
|
|
|
+ super(Vgg19, self).__init__()
|
|
|
+ vgg_pretrained_features = models.vgg19(pretrained=True).features
|
|
|
+ self.slice1 = torch.nn.Sequential()
|
|
|
+ self.slice2 = torch.nn.Sequential()
|
|
|
+ self.slice3 = torch.nn.Sequential()
|
|
|
+ self.slice4 = torch.nn.Sequential()
|
|
|
+ self.slice5 = torch.nn.Sequential()
|
|
|
+ for x in range(2):
|
|
|
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
|
|
+ for x in range(2, 7):
|
|
|
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
|
|
+ for x in range(7, 12):
|
|
|
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
|
|
+ for x in range(12, 21):
|
|
|
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
|
|
+ for x in range(21, 30):
|
|
|
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
|
|
+ if not requires_grad:
|
|
|
+ for param in self.parameters():
|
|
|
+ param.requires_grad = False
|
|
|
+
|
|
|
+ def forward(self, X):
|
|
|
+ h_relu1 = self.slice1(X)
|
|
|
+ h_relu2 = self.slice2(h_relu1)
|
|
|
+ h_relu3 = self.slice3(h_relu2)
|
|
|
+ h_relu4 = self.slice4(h_relu3)
|
|
|
+ h_relu5 = self.slice5(h_relu4)
|
|
|
+ out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
|
|
+ return out
|