culture committed on
Commit 2cce121 · 1 Parent(s): e0804a6

Upload gfpgan/models/gfpgan_model.py

Files changed (1)
  1. gfpgan/models/gfpgan_model.py +580 -0
gfpgan/models/gfpgan_model.py ADDED
import math
import os.path as osp
import torch
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.losses.losses import r1_penalty
from basicsr.metrics import calculate_metric
from basicsr.models.base_model import BaseModel
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY
from collections import OrderedDict
from torch.nn import functional as F
from torchvision.ops import roi_align
from tqdm import tqdm


@MODEL_REGISTRY.register()
class GFPGANModel(BaseModel):
    """The GFPGAN model for Towards Real-World Blind Face Restoration with Generative Facial Prior."""

    def __init__(self, opt):
        super(GFPGANModel, self).__init__(opt)
        self.idx = 0  # used when saving intermediate data for inspection (see feed_data)

        # define network
        self.net_g = build_network(opt['network_g'])
        self.net_g = self.model_to_device(self.net_g)
        self.print_network(self.net_g)

        # load pretrained model
        load_path = self.opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            param_key = self.opt['path'].get('param_key_g', 'params')
            self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)

        self.log_size = int(math.log(self.opt['network_g']['out_size'], 2))

        if self.is_train:
            self.init_training_settings()

    def init_training_settings(self):
        train_opt = self.opt['train']

        # ----------- define net_d ----------- #
        self.net_d = build_network(self.opt['network_d'])
        self.net_d = self.model_to_device(self.net_d)
        self.print_network(self.net_d)
        # load pretrained model
        load_path = self.opt['path'].get('pretrain_network_d', None)
        if load_path is not None:
            self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))

        # ----------- define net_g with Exponential Moving Average (EMA) ----------- #
        # net_g_ema is only used for testing on one GPU and for saving; there is no need to wrap it
        # with DistributedDataParallel
        self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
        # load pretrained model
        load_path = self.opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
        else:
            self.model_ema(0)  # copy net_g weight

        self.net_g.train()
        self.net_d.train()
        self.net_g_ema.eval()

        # ----------- facial component networks ----------- #
        if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
            self.use_facial_disc = True
        else:
            self.use_facial_disc = False

        if self.use_facial_disc:
            # left eye
            self.net_d_left_eye = build_network(self.opt['network_d_left_eye'])
            self.net_d_left_eye = self.model_to_device(self.net_d_left_eye)
            self.print_network(self.net_d_left_eye)
            load_path = self.opt['path'].get('pretrain_network_d_left_eye')
            if load_path is not None:
                self.load_network(self.net_d_left_eye, load_path, True, 'params')
            # right eye
            self.net_d_right_eye = build_network(self.opt['network_d_right_eye'])
            self.net_d_right_eye = self.model_to_device(self.net_d_right_eye)
            self.print_network(self.net_d_right_eye)
            load_path = self.opt['path'].get('pretrain_network_d_right_eye')
            if load_path is not None:
                self.load_network(self.net_d_right_eye, load_path, True, 'params')
            # mouth
            self.net_d_mouth = build_network(self.opt['network_d_mouth'])
            self.net_d_mouth = self.model_to_device(self.net_d_mouth)
            self.print_network(self.net_d_mouth)
            load_path = self.opt['path'].get('pretrain_network_d_mouth')
            if load_path is not None:
                self.load_network(self.net_d_mouth, load_path, True, 'params')

            self.net_d_left_eye.train()
            self.net_d_right_eye.train()
            self.net_d_mouth.train()

            # ----------- define facial component gan loss ----------- #
            self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)

        # ----------- define losses ----------- #
        # pixel loss
        if train_opt.get('pixel_opt'):
            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
        else:
            self.cri_pix = None

        # perceptual loss
        if train_opt.get('perceptual_opt'):
            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
        else:
            self.cri_perceptual = None

        # L1 loss is used in pyramid loss, component style loss and identity loss
        self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)

        # gan loss (wgan)
        self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)

        # ----------- define identity loss ----------- #
        if 'network_identity' in self.opt:
            self.use_identity = True
        else:
            self.use_identity = False

        if self.use_identity:
            # define identity network
            self.network_identity = build_network(self.opt['network_identity'])
            self.network_identity = self.model_to_device(self.network_identity)
            self.print_network(self.network_identity)
            load_path = self.opt['path'].get('pretrain_network_identity')
            if load_path is not None:
                self.load_network(self.network_identity, load_path, True, None)
            self.network_identity.eval()
            for param in self.network_identity.parameters():
                param.requires_grad = False

        # regularization weights
        self.r1_reg_weight = train_opt['r1_reg_weight']  # for discriminator
        self.net_d_iters = train_opt.get('net_d_iters', 1)
        self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
        self.net_d_reg_every = train_opt['net_d_reg_every']

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def setup_optimizers(self):
        train_opt = self.opt['train']

        # ----------- optimizer g ----------- #
        net_g_reg_ratio = 1
        normal_params = []
        for _, param in self.net_g.named_parameters():
            normal_params.append(param)
        optim_params_g = [{  # add normal params first
            'params': normal_params,
            'lr': train_opt['optim_g']['lr']
        }]
        optim_type = train_opt['optim_g'].pop('type')
        lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
        betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
        self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
        self.optimizers.append(self.optimizer_g)

        # ----------- optimizer d ----------- #
        net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
        normal_params = []
        for _, param in self.net_d.named_parameters():
            normal_params.append(param)
        optim_params_d = [{  # add normal params first
            'params': normal_params,
            'lr': train_opt['optim_d']['lr']
        }]
        optim_type = train_opt['optim_d'].pop('type')
        lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
        betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
        self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
        self.optimizers.append(self.optimizer_d)
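        # Note: the lr/beta scaling above is the StyleGAN2-style "lazy regularization" correction.
        # Illustrative numbers (assuming net_d_reg_every = 16, the usual StyleGAN2 setting):
        #   net_d_reg_ratio = 16 / 17 ≈ 0.941
        #   effective lr    = base_lr * 0.941
        #   betas           = (0**0.941, 0.99**0.941) ≈ (0.0, 0.9906)
        # This compensates for the R1 penalty being applied only every net_d_reg_every iterations.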

        # ----------- optimizers for facial component networks ----------- #
        if self.use_facial_disc:
            # setup optimizers for facial component discriminators
            optim_type = train_opt['optim_component'].pop('type')
            lr = train_opt['optim_component']['lr']
            # left eye
            self.optimizer_d_left_eye = self.get_optimizer(
                optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99))
            self.optimizers.append(self.optimizer_d_left_eye)
            # right eye
            self.optimizer_d_right_eye = self.get_optimizer(
                optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99))
            self.optimizers.append(self.optimizer_d_right_eye)
            # mouth
            self.optimizer_d_mouth = self.get_optimizer(
                optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99))
            self.optimizers.append(self.optimizer_d_mouth)

    def feed_data(self, data):
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)

        if 'loc_left_eye' in data:
            # get facial component locations, shape (batch, 4)
            self.loc_left_eyes = data['loc_left_eye']
            self.loc_right_eyes = data['loc_right_eye']
            self.loc_mouths = data['loc_mouth']

        # uncomment to check data
        # import torchvision
        # if self.opt['rank'] == 0:
        #     import os
        #     os.makedirs('tmp/gt', exist_ok=True)
        #     os.makedirs('tmp/lq', exist_ok=True)
        #     print(self.idx)
        #     torchvision.utils.save_image(
        #         self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
        #     torchvision.utils.save_image(
        #         self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
        #     self.idx = self.idx + 1

    def construct_img_pyramid(self):
        """Construct image pyramid for intermediate restoration loss"""
        pyramid_gt = [self.gt]
        down_img = self.gt
        for _ in range(0, self.log_size - 3):
            down_img = F.interpolate(down_img, scale_factor=0.5, mode='bilinear', align_corners=False)
            pyramid_gt.insert(0, down_img)
        return pyramid_gt
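
    # Shape sketch for construct_img_pyramid (assuming out_size=512, hence log_size=9): six bilinear
    # halvings yield pyramid_gt spatial sizes [8, 16, 32, 64, 128, 256, 512], coarsest first, matching
    # the intermediate RGB outputs that the pyramid loss compares against in optimize_parameters().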

    def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
        face_ratio = int(self.opt['network_g']['out_size'] / 512)
        eye_out_size *= face_ratio
        mouth_out_size *= face_ratio

        rois_eyes = []
        rois_mouths = []
        for b in range(self.loc_left_eyes.size(0)):  # loop over the batch
            # left eye and right eye
            img_inds = self.loc_left_eyes.new_full((2, 1), b)
            bbox = torch.stack([self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], dim=0)  # shape: (2, 4)
            rois = torch.cat([img_inds, bbox], dim=-1)  # shape: (2, 5)
            rois_eyes.append(rois)
            # mouth
            img_inds = self.loc_left_eyes.new_full((1, 1), b)
            rois = torch.cat([img_inds, self.loc_mouths[b:b + 1, :]], dim=-1)  # shape: (1, 5)
            rois_mouths.append(rois)

        rois_eyes = torch.cat(rois_eyes, 0).to(self.device)
        rois_mouths = torch.cat(rois_mouths, 0).to(self.device)

        # real images
        all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
        self.left_eyes_gt = all_eyes[0::2, :, :, :]
        self.right_eyes_gt = all_eyes[1::2, :, :, :]
        self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
        # output
        all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
        self.left_eyes = all_eyes[0::2, :, :, :]
        self.right_eyes = all_eyes[1::2, :, :, :]
        self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
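
    # ROI sketch (assuming out_size=512, so face_ratio=1, and batch size n): roi_align receives
    # boxes of shape (2n, 5) for the eyes and (n, 5) for the mouths, each row being
    # (batch_index, x1, y1, x2, y2); the crops come out as left/right eyes of shape (n, 3, 80, 80)
    # and mouths of shape (n, 3, 120, 120), interleaved left/right along the batch dimension.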

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram
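
    # Illustrative shapes for _gram_mat: a feature map of shape (2, 64, 16, 16) is flattened to
    # (2, 64, 256) and batch-multiplied with its transpose, yielding a (2, 64, 64) Gram matrix of
    # normalized channel correlations, the standard style-loss statistic.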

    def gray_resize_for_identity(self, out, size=128):
        out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
        out_gray = out_gray.unsqueeze(1)
        out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
        return out_gray
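
    # The 0.2989/0.5870/0.1140 weights are the ITU-R BT.601 luma coefficients, so this converts the
    # RGB output to single-channel grayscale before resizing to size x size for the identity network.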

    def optimize_parameters(self, current_iter):
        # optimize net_g
        for p in self.net_d.parameters():
            p.requires_grad = False
        self.optimizer_g.zero_grad()

        # do not update facial component net_d
        if self.use_facial_disc:
            for p in self.net_d_left_eye.parameters():
                p.requires_grad = False
            for p in self.net_d_right_eye.parameters():
                p.requires_grad = False
            for p in self.net_d_mouth.parameters():
                p.requires_grad = False

        # image pyramid loss weight
        if current_iter < self.opt['train'].get('remove_pyramid_loss', float('inf')):
            pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 1)
        else:
            pyramid_loss_weight = 1e-12  # very small weight, effectively disabling the pyramid loss
        if pyramid_loss_weight > 0:
            self.output, out_rgbs = self.net_g(self.lq, return_rgb=True)
            pyramid_gt = self.construct_img_pyramid()
        else:
            self.output, out_rgbs = self.net_g(self.lq, return_rgb=False)

        # get roi-align regions
        if self.use_facial_disc:
            self.get_roi_regions(eye_out_size=80, mouth_out_size=120)

        l_g_total = 0
        loss_dict = OrderedDict()
        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, self.gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix

            # image pyramid loss
            if pyramid_loss_weight > 0:
                for i in range(0, self.log_size - 2):
                    l_pyramid = self.cri_l1(out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight
                    l_g_total += l_pyramid
                    loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid

            # perceptual loss
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict['l_g_percep'] = l_g_percep
                if l_g_style is not None:
                    l_g_total += l_g_style
                    loss_dict['l_g_style'] = l_g_style

            # gan loss
            fake_g_pred = self.net_d(self.output)
            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan'] = l_g_gan

            # facial component loss
            if self.use_facial_disc:
                # left eye
                fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(self.left_eyes, return_feats=True)
                l_g_gan = self.cri_component(fake_left_eye, True, is_disc=False)
                l_g_total += l_g_gan
                loss_dict['l_g_gan_left_eye'] = l_g_gan
                # right eye
                fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(self.right_eyes, return_feats=True)
                l_g_gan = self.cri_component(fake_right_eye, True, is_disc=False)
                l_g_total += l_g_gan
                loss_dict['l_g_gan_right_eye'] = l_g_gan
                # mouth
                fake_mouth, fake_mouth_feats = self.net_d_mouth(self.mouths, return_feats=True)
                l_g_gan = self.cri_component(fake_mouth, True, is_disc=False)
                l_g_total += l_g_gan
                loss_dict['l_g_gan_mouth'] = l_g_gan

                if self.opt['train'].get('comp_style_weight', 0) > 0:
                    # get gt feat
                    _, real_left_eye_feats = self.net_d_left_eye(self.left_eyes_gt, return_feats=True)
                    _, real_right_eye_feats = self.net_d_right_eye(self.right_eyes_gt, return_feats=True)
                    _, real_mouth_feats = self.net_d_mouth(self.mouths_gt, return_feats=True)

                    def _comp_style(feat, feat_gt, criterion):
                        return criterion(self._gram_mat(feat[0]), self._gram_mat(
                            feat_gt[0].detach())) * 0.5 + criterion(
                                self._gram_mat(feat[1]), self._gram_mat(feat_gt[1].detach()))

                    # facial component style loss
                    comp_style_loss = 0
                    comp_style_loss += _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_l1)
                    comp_style_loss += _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_l1)
                    comp_style_loss += _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_l1)
                    comp_style_loss = comp_style_loss * self.opt['train']['comp_style_weight']
                    l_g_total += comp_style_loss
                    loss_dict['l_g_comp_style_loss'] = comp_style_loss

            # identity loss
            if self.use_identity:
                identity_weight = self.opt['train']['identity_weight']
                # get gray images and resize
                out_gray = self.gray_resize_for_identity(self.output)
                gt_gray = self.gray_resize_for_identity(self.gt)

                identity_gt = self.network_identity(gt_gray).detach()
                identity_out = self.network_identity(out_gray)
                l_identity = self.cri_l1(identity_out, identity_gt) * identity_weight
                l_g_total += l_identity
                loss_dict['l_identity'] = l_identity

            l_g_total.backward()
            self.optimizer_g.step()

        # EMA
        self.model_ema(decay=0.5**(32 / (10 * 1000)))
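        # The EMA decay 0.5**(32 / (10 * 1000)) ≈ 0.9978 follows the StyleGAN2 convention of a
        # weight half-life of 10k images at batch size 32, smoothing net_g into net_g_ema.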

        # ----------- optimize net_d ----------- #
        for p in self.net_d.parameters():
            p.requires_grad = True
        self.optimizer_d.zero_grad()
        if self.use_facial_disc:
            for p in self.net_d_left_eye.parameters():
                p.requires_grad = True
            for p in self.net_d_right_eye.parameters():
                p.requires_grad = True
            for p in self.net_d_mouth.parameters():
                p.requires_grad = True
            self.optimizer_d_left_eye.zero_grad()
            self.optimizer_d_right_eye.zero_grad()
            self.optimizer_d_mouth.zero_grad()

        fake_d_pred = self.net_d(self.output.detach())
        real_d_pred = self.net_d(self.gt)
        l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
        loss_dict['l_d'] = l_d
        # In WGAN, real_score should be positive and fake_score should be negative
        loss_dict['real_score'] = real_d_pred.detach().mean()
        loss_dict['fake_score'] = fake_d_pred.detach().mean()
        l_d.backward()

        # regularization loss
        if current_iter % self.net_d_reg_every == 0:
            self.gt.requires_grad = True
            real_pred = self.net_d(self.gt)
            l_d_r1 = r1_penalty(real_pred, self.gt)
            l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
            loss_dict['l_d_r1'] = l_d_r1.detach().mean()
            l_d_r1.backward()
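            # Lazy regularization note: the R1 gradient penalty runs only every net_d_reg_every
            # iterations and is therefore scaled by net_d_reg_every to preserve its expected
            # strength; the `0 * real_pred[0]` term appears to be the usual StyleGAN2 trick that
            # ties the penalty to the discriminator output without changing the loss value.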

        self.optimizer_d.step()

        # optimize facial component discriminators
        if self.use_facial_disc:
            # use the facial component GAN loss for both the real and fake predictions
            # left eye
            fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach())
            real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt)
            l_d_left_eye = self.cri_component(
                real_d_pred, True, is_disc=True) + self.cri_component(
                    fake_d_pred, False, is_disc=True)
            loss_dict['l_d_left_eye'] = l_d_left_eye
            l_d_left_eye.backward()
            # right eye
            fake_d_pred, _ = self.net_d_right_eye(self.right_eyes.detach())
            real_d_pred, _ = self.net_d_right_eye(self.right_eyes_gt)
            l_d_right_eye = self.cri_component(
                real_d_pred, True, is_disc=True) + self.cri_component(
                    fake_d_pred, False, is_disc=True)
            loss_dict['l_d_right_eye'] = l_d_right_eye
            l_d_right_eye.backward()
            # mouth
            fake_d_pred, _ = self.net_d_mouth(self.mouths.detach())
            real_d_pred, _ = self.net_d_mouth(self.mouths_gt)
            l_d_mouth = self.cri_component(
                real_d_pred, True, is_disc=True) + self.cri_component(
                    fake_d_pred, False, is_disc=True)
            loss_dict['l_d_mouth'] = l_d_mouth
            l_d_mouth.backward()

            self.optimizer_d_left_eye.step()
            self.optimizer_d_right_eye.step()
            self.optimizer_d_mouth.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

    def test(self):
        with torch.no_grad():
            if hasattr(self, 'net_g_ema'):
                self.net_g_ema.eval()
                self.output, _ = self.net_g_ema(self.lq)
            else:
                logger = get_root_logger()
                logger.warning('Model does not have self.net_g_ema; using self.net_g instead.')
                self.net_g.eval()
                self.output, _ = self.net_g(self.lq)
                self.net_g.train()

    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        if self.opt['rank'] == 0:
            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        dataset_name = dataloader.dataset.opt['name']
        with_metrics = self.opt['val'].get('metrics') is not None
        use_pbar = self.opt['val'].get('pbar', False)

        if with_metrics:
            if not hasattr(self, 'metric_results'):  # only execute in the first run
                self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
            # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
            self._initialize_best_metric_results(dataset_name)
            # zero self.metric_results
            self.metric_results = {metric: 0 for metric in self.metric_results}

        metric_data = dict()
        if use_pbar:
            pbar = tqdm(total=len(dataloader), unit='image')

        for idx, val_data in enumerate(dataloader):
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            self.feed_data(val_data)
            self.test()

            sr_img = tensor2img(self.output.detach().cpu(), min_max=(-1, 1))
            metric_data['img'] = sr_img
            if hasattr(self, 'gt'):
                gt_img = tensor2img(self.gt.detach().cpu(), min_max=(-1, 1))
                metric_data['img2'] = gt_img
                del self.gt

            # free GPU memory between validation images
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    save_img_path = osp.join(self.opt['path']['visualization'], img_name,
                                             f'{img_name}_{current_iter}.png')
                else:
                    if self.opt['val']['suffix']:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["val"]["suffix"]}.png')
                    else:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["name"]}.png')
                imwrite(sr_img, save_img_path)

            if with_metrics:
                # calculate metrics
                for name, opt_ in self.opt['val']['metrics'].items():
                    self.metric_results[name] += calculate_metric(metric_data, opt_)
            if use_pbar:
                pbar.update(1)
                pbar.set_description(f'Test {img_name}')
        if use_pbar:
            pbar.close()

        if with_metrics:
            for metric in self.metric_results.keys():
                self.metric_results[metric] /= (idx + 1)
                # update the best metric result
                self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)

            self._log_validation_metric_values(current_iter, dataset_name, tb_logger)

    def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
        log_str = f'Validation {dataset_name}\n'
        for metric, value in self.metric_results.items():
            log_str += f'\t # {metric}: {value:.4f}'
            if hasattr(self, 'best_metric_results'):
                log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
                            f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
            log_str += '\n'

        logger = get_root_logger()
        logger.info(log_str)
        if tb_logger:
            for metric, value in self.metric_results.items():
                tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)

    def save(self, epoch, current_iter):
        # save net_g and net_d
        self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
        self.save_network(self.net_d, 'net_d', current_iter)
        # save component discriminators
        if self.use_facial_disc:
            self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
            self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
            self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)
        # save training state
        self.save_training_state(epoch, current_iter)
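
# Minimal usage sketch (illustrative, not part of this file): with a parsed BasicSR option dict
# `opt` containing the network_g/network_d/train/path/val sections from a GFPGAN training config,
# the registry resolves this class by name. The hypothetical loop below outlines how BasicSR's
# training driver would exercise it.
#
# from basicsr.models import build_model
# model = build_model(opt)  # returns GFPGANModel via MODEL_REGISTRY
# for current_iter, data in enumerate(train_loader, start=1):
#     model.feed_data(data)                    # move lq/gt and component locations to device
#     model.optimize_parameters(current_iter)  # one G step, EMA update, one D step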