# pix2pixHD_model.py

import numpy as np
import torch
import os
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks

class Pix2PixHDModel(BaseModel):
    def name(self):
        return 'Pix2PixHDModel'

    def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, True)
        def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
            return [l for (l, f) in zip((g_gan, g_gan_feat, g_vgg, d_real, d_fake), flags) if f]
        return loss_filter
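
    # Worked example: with use_gan_feat_loss=False and use_vgg_loss=True the flags are
    # (True, False, True, True, True), so loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake)
    # returns [g_gan, g_vgg, d_real, d_fake]. The same filter is applied to self.loss_names in
    # initialize(), so loss names and loss values always stay aligned.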

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM
            torch.backends.cudnn.benchmark = True
        self.isTrain = opt.isTrain
        self.use_features = opt.instance_feat or opt.label_feat
        self.gen_features = self.use_features and not self.opt.load_features
        input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc

        ##### define networks
        # Generator network
        netG_input_nc = input_nc
        if not opt.no_instance:
            netG_input_nc += 1
        if self.use_features:
            netG_input_nc += opt.feat_num
        self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
                                      opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)

        # Discriminator network
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            netD_input_nc = input_nc + opt.output_nc
            if not opt.no_instance:
                netD_input_nc += 1
            self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
                                          opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)

        ### Encoder network
        if self.gen_features:
            self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
                                          opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)
        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if not self.isTrain or opt.continue_train or opt.load_pretrain:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
            if self.gen_features:
                self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
                raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
            self.fake_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr

            # define loss functions
            self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionFeat = torch.nn.L1Loss()
            if not opt.no_vgg_loss:
                self.criterionVGG = networks.VGGLoss(self.gpu_ids)

            # Names so we can breakout loss
            self.loss_names = self.loss_filter('G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake')

            # initialize optimizers
            # optimizer G
            if opt.niter_fix_global > 0:
                import sys
                if sys.version_info >= (3, 0):
                    finetune_list = set()
                else:
                    from sets import Set
                    finetune_list = Set()
                params_dict = dict(self.netG.named_parameters())
                params = []
                for key, value in params_dict.items():
                    if key.startswith('model' + str(opt.n_local_enhancers)):
                        params += [value]
                        finetune_list.add(key.split('.')[0])
                print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
                print('The layers that are finetuned are ', sorted(finetune_list))
            else:
                params = list(self.netG.parameters())
            if self.gen_features:
                params += list(self.netE.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
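
    # Note on the optimizers: when opt.niter_fix_global > 0, optimizer_G is built only from the
    # local-enhancer parameters (keys starting with 'model' + str(opt.n_local_enhancers)), so the
    # global generator stays frozen until update_fixed_params() rebuilds the optimizer with all
    # generator (and, if used, encoder) parameters. Both optimizers share the same lr and beta1.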

    def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
        if self.opt.label_nc == 0:
            input_label = label_map.data.cuda()
        else:
            # create one-hot vector for label map
            size = label_map.size()
            oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
            input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
            input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
            if self.opt.data_type == 16:
                input_label = input_label.half()

        # get edges from instance map
        if not self.opt.no_instance:
            inst_map = inst_map.data.cuda()
            edge_map = self.get_edges(inst_map)
            input_label = torch.cat((input_label, edge_map), dim=1)
        input_label = Variable(input_label, volatile=infer)

        # real images for training
        if real_image is not None:
            real_image = Variable(real_image.data.cuda())

        # instance map for feature encoding
        if self.use_features:
            # get precomputed feature maps
            if self.opt.load_features:
                feat_map = Variable(feat_map.data.cuda())
            if self.opt.label_feat:
                inst_map = label_map.cuda()

        return input_label, inst_map, real_image, feat_map
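
    # Shape example (assuming the default Cityscapes-style setup with opt.label_nc = 35 and
    # opt.no_instance = False): a label_map of shape (N, 1, H, W) holding integer class ids is
    # scattered into a one-hot tensor of shape (N, 35, H, W); the instance edge map then adds one
    # more channel, so the generator receives an input_label of shape (N, 36, H, W).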

    def discriminate(self, input_label, test_image, use_pool=False):
        input_concat = torch.cat((input_label, test_image.detach()), dim=1)
        if use_pool:
            fake_query = self.fake_pool.query(input_concat)
            return self.netD.forward(fake_query)
        else:
            return self.netD.forward(input_concat)
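
    # test_image is detached on purpose: discriminate() is used to compute the discriminator
    # losses, so no gradient should flow back into the generator here. With use_pool=True the
    # fake label/image pair is routed through ImagePool, which (for pool_size > 0) may return an
    # older fake sample instead, a common trick for stabilizing discriminator training.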

    def forward(self, label, inst, image, feat, infer=False):
        # Encode Inputs
        input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)

        # Fake Generation
        if self.use_features:
            if not self.opt.load_features:
                feat_map = self.netE.forward(real_image, inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label
        fake_image = self.netG.forward(input_concat)

        # Fake Detection and Loss
        pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
        loss_D_fake = self.criterionGAN(pred_fake_pool, False)

        # Real Detection and Loss
        pred_real = self.discriminate(input_label, real_image)
        loss_D_real = self.criterionGAN(pred_real, True)

        # GAN loss (Fake Passability Loss)
        pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
        loss_G_GAN = self.criterionGAN(pred_fake, True)

        # GAN feature matching loss
        loss_G_GAN_Feat = 0
        if not self.opt.no_ganFeat_loss:
            feat_weights = 4.0 / (self.opt.n_layers_D + 1)
            D_weights = 1.0 / self.opt.num_D
            for i in range(self.opt.num_D):
                for j in range(len(pred_fake[i]) - 1):
                    loss_G_GAN_Feat += D_weights * feat_weights * \
                        self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat

        # VGG feature matching loss
        loss_G_VGG = 0
        if not self.opt.no_vgg_loss:
            loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat

        # Only return the fake_B image if necessary to save BW
        return [self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake),
                None if not infer else fake_image]
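
    # A hedged sketch of how the filtered losses are typically combined in the accompanying
    # training loop (this class does not enforce it):
    #   loss_dict = dict(zip(model.loss_names, losses))
    #   loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
    #   loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)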

    def inference(self, label, inst, image=None):
        # Encode Inputs
        image = Variable(image) if image is not None else None
        input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)

        # Fake Generation
        if self.use_features:
            if self.opt.use_encoded_image:
                # encode the real image to get feature map
                feat_map = self.netE.forward(real_image, inst_map)
            else:
                # sample clusters from precomputed features
                feat_map = self.sample_features(inst_map)
            input_concat = torch.cat((input_label, feat_map), dim=1)
        else:
            input_concat = input_label

        if torch.__version__.startswith('0.4'):
            with torch.no_grad():
                fake_image = self.netG.forward(input_concat)
        else:
            fake_image = self.netG.forward(input_concat)
        return fake_image
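
    # Caveat: the version check above only matches the 0.4 series, so on later PyTorch releases
    # this call runs without a no_grad() guard (gradients are simply never used). A possible
    # alternative, not what the original code does, would be to use the guard whenever it is
    # available, e.g. `if hasattr(torch, 'no_grad'): ...`.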

    def sample_features(self, inst):
        # read precomputed feature clusters
        cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
        features_clustered = np.load(cluster_path, encoding='latin1').item()

        # randomly sample from the feature clusters
        inst_np = inst.cpu().numpy().astype(int)
        feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3])
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            if label in features_clustered:
                feat = features_clustered[label]
                cluster_idx = np.random.randint(0, feat.shape[0])
                idx = (inst == int(i)).nonzero()
                for k in range(self.opt.feat_num):
                    feat_map[idx[:, 0], idx[:, 1] + k, idx[:, 2], idx[:, 3]] = feat[cluster_idx, k]
        if self.opt.data_type == 16:
            feat_map = feat_map.half()
        return feat_map
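
    # The cluster file is expected to hold a dict mapping semantic label -> an array whose rows
    # are cluster centers; only the first feat_num columns are read here. The `i // 1000` rule
    # follows the Cityscapes-style convention where an instance id encodes label * 1000 + index,
    # so each instance falls back to the clusters of its semantic label.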

    def encode_features(self, image, inst):
        image = Variable(image.cuda(), volatile=True)
        feat_num = self.opt.feat_num
        h, w = inst.size()[2], inst.size()[3]
        block_num = 32
        feat_map = self.netE.forward(image, inst.cuda())
        inst_np = inst.cpu().numpy().astype(int)
        feature = {}
        for i in range(self.opt.label_nc):
            feature[i] = np.zeros((0, feat_num + 1))
        for i in np.unique(inst_np):
            label = i if i < 1000 else i // 1000
            idx = (inst == int(i)).nonzero()
            num = idx.size()[0]
            idx = idx[num // 2, :]
            val = np.zeros((1, feat_num + 1))
            for k in range(feat_num):
                val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0]
            val[0, feat_num] = float(num) / (h * w // block_num)
            feature[label] = np.append(feature[label], val, axis=0)
        return feature
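
    # encode_features() returns, per semantic label, rows of [feat_1, ..., feat_feat_num, size]:
    # the encoder output sampled at one middle pixel of each instance, plus the instance's area
    # normalized by (h * w // block_num). Clustering these rows offline produces the file that
    # sample_features() reads at inference time.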

    def get_edges(self, t):
        edge = torch.cuda.ByteTensor(t.size()).zero_()
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        if self.opt.data_type == 16:
            return edge.half()
        else:
            return edge.float()
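
    # Worked example for a single row of an instance map: t = [5, 5, 7, 7] yields
    # edge = [0, 1, 1, 0], i.e. the pixels on both sides of an instance boundary are marked.
    # The same comparison runs along the vertical axis, so the result traces all instance
    # contours as a one-channel edge map.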

    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
        if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
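
    # Linear decay example: with opt.lr = 0.0002 and opt.niter_decay = 100, every call lowers the
    # learning rate of both optimizers by 0.000002, reaching roughly zero after 100 decay epochs.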

class InferenceModel(Pix2PixHDModel):
    def forward(self, inp):
        label, inst = inp
        return self.inference(label, inst)
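
# Typical usage, as a hedged sketch of how the surrounding repository drives this class (not
# something this module enforces): during training, forward() is called with label, instance,
# image and feature tensors and returns the filtered loss list plus, when infer=True, the
# generated image; at test time InferenceModel lets the network be called through the standard
# nn.Module interface with a single packed argument, e.g.
#   fake_image = model((label_tensor, instance_tensor))
# which is convenient for wrappers that expect one positional input (such as model export).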