sczhou commited on
Commit
f730388
·
1 Parent(s): 8e2cd9b

fix some bugs in training code.

Browse files
.gitignore CHANGED
@@ -123,5 +123,9 @@ venv.bak/
123
 
124
  # project
125
  results/
 
 
 
 
126
  *_old*
127
 
 
123
 
124
  # project
125
  results/
126
+ experiments/
127
+ tb_logger/
128
+ run.sh
129
+ *debug*
130
  *_old*
131
 
basicsr/data/ffhq_blind_dataset.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import random
4
+ import numpy as np
5
+ import os.path as osp
6
+ from scipy.io import loadmat
7
+ from PIL import Image
8
+ import torch
9
+ import torch.utils.data as data
10
+ from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
11
+ adjust_hue, adjust_saturation, normalize)
12
+ from basicsr.data import gaussian_kernels as gaussian_kernels
13
+ from basicsr.data.transforms import augment
14
+ from basicsr.data.data_util import paths_from_folder, brush_stroke_mask, random_ff_mask
15
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
16
+ from basicsr.utils.registry import DATASET_REGISTRY
17
+
18
+ @DATASET_REGISTRY.register()
19
+ class FFHQBlindDataset(data.Dataset):
20
+
21
+ def __init__(self, opt):
22
+ super(FFHQBlindDataset, self).__init__()
23
+ logger = get_root_logger()
24
+ self.opt = opt
25
+ # file client (io backend)
26
+ self.file_client = None
27
+ self.io_backend_opt = opt['io_backend']
28
+
29
+ self.gt_folder = opt['dataroot_gt']
30
+ self.gt_size = opt.get('gt_size', 512)
31
+ self.in_size = opt.get('in_size', 512)
32
+ assert self.gt_size >= self.in_size, 'Wrong setting.'
33
+
34
+ self.mean = opt.get('mean', [0.5, 0.5, 0.5])
35
+ self.std = opt.get('std', [0.5, 0.5, 0.5])
36
+
37
+ self.component_path = opt.get('component_path', None)
38
+ self.latent_gt_path = opt.get('latent_gt_path', None)
39
+
40
+ if self.component_path is not None:
41
+ self.crop_components = True
42
+ self.components_dict = torch.load(self.component_path)
43
+ self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
44
+ self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
45
+ self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)
46
+ else:
47
+ self.crop_components = False
48
+
49
+ if self.latent_gt_path is not None:
50
+ self.load_latent_gt = True
51
+ self.latent_gt_dict = torch.load(self.latent_gt_path)
52
+ else:
53
+ self.load_latent_gt = False
54
+
55
+ if self.io_backend_opt['type'] == 'lmdb':
56
+ self.io_backend_opt['db_paths'] = self.gt_folder
57
+ if not self.gt_folder.endswith('.lmdb'):
58
+ raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}')
59
+ with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
60
+ self.paths = [line.split('.')[0] for line in fin]
61
+ else:
62
+ self.paths = paths_from_folder(self.gt_folder)
63
+
64
+ # inpainting mask
65
+ self.gen_inpaint_mask = opt.get('gen_inpaint_mask', False)
66
+ if self.gen_inpaint_mask:
67
+ logger.info(f'generate mask ...')
68
+ # self.mask_max_angle = opt.get('mask_max_angle', 10)
69
+ # self.mask_max_len = opt.get('mask_max_len', 150)
70
+ # self.mask_max_width = opt.get('mask_max_width', 50)
71
+ # self.mask_draw_times = opt.get('mask_draw_times', 10)
72
+ # # print
73
+ # logger.info(f'mask_max_angle: {self.mask_max_angle}')
74
+ # logger.info(f'mask_max_len: {self.mask_max_len}')
75
+ # logger.info(f'mask_max_width: {self.mask_max_width}')
76
+ # logger.info(f'mask_draw_times: {self.mask_draw_times}')
77
+
78
+ # perform corrupt
79
+ self.use_corrupt = opt.get('use_corrupt', True)
80
+ self.use_motion_kernel = False
81
+ # self.use_motion_kernel = opt.get('use_motion_kernel', True)
82
+
83
+ if self.use_motion_kernel:
84
+ self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
85
+ motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
86
+ self.motion_kernels = torch.load(motion_kernel_path)
87
+
88
+ if self.use_corrupt and not self.gen_inpaint_mask:
89
+ # degradation configurations
90
+ self.blur_kernel_size = opt['blur_kernel_size']
91
+ self.blur_sigma = opt['blur_sigma']
92
+ self.kernel_list = opt['kernel_list']
93
+ self.kernel_prob = opt['kernel_prob']
94
+ self.downsample_range = opt['downsample_range']
95
+ self.noise_range = opt['noise_range']
96
+ self.jpeg_range = opt['jpeg_range']
97
+ # print
98
+ logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
99
+ logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
100
+ logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
101
+ logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
102
+
103
+ # color jitter
104
+ self.color_jitter_prob = opt.get('color_jitter_prob', None)
105
+ self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
106
+ self.color_jitter_shift = opt.get('color_jitter_shift', 20)
107
+ if self.color_jitter_prob is not None:
108
+ logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
109
+
110
+ # to gray
111
+ self.gray_prob = opt.get('gray_prob', 0.0)
112
+ if self.gray_prob is not None:
113
+ logger.info(f'Use random gray. Prob: {self.gray_prob}')
114
+ self.color_jitter_shift /= 255.
115
+
116
+ @staticmethod
117
+ def color_jitter(img, shift):
118
+ """jitter color: randomly jitter the RGB values, in numpy formats"""
119
+ jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
120
+ img = img + jitter_val
121
+ img = np.clip(img, 0, 1)
122
+ return img
123
+
124
+ @staticmethod
125
+ def color_jitter_pt(img, brightness, contrast, saturation, hue):
126
+ """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
127
+ fn_idx = torch.randperm(4)
128
+ for fn_id in fn_idx:
129
+ if fn_id == 0 and brightness is not None:
130
+ brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
131
+ img = adjust_brightness(img, brightness_factor)
132
+
133
+ if fn_id == 1 and contrast is not None:
134
+ contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
135
+ img = adjust_contrast(img, contrast_factor)
136
+
137
+ if fn_id == 2 and saturation is not None:
138
+ saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
139
+ img = adjust_saturation(img, saturation_factor)
140
+
141
+ if fn_id == 3 and hue is not None:
142
+ hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
143
+ img = adjust_hue(img, hue_factor)
144
+ return img
145
+
146
+
147
+ def get_component_locations(self, name, status):
148
+ components_bbox = self.components_dict[name]
149
+ if status[0]: # hflip
150
+ # exchange right and left eye
151
+ tmp = components_bbox['left_eye']
152
+ components_bbox['left_eye'] = components_bbox['right_eye']
153
+ components_bbox['right_eye'] = tmp
154
+ # modify the width coordinate
155
+ components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0]
156
+ components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0]
157
+ components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0]
158
+ components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0]
159
+
160
+ locations_gt = {}
161
+ locations_in = {}
162
+ for part in ['left_eye', 'right_eye', 'nose', 'mouth']:
163
+ mean = components_bbox[part][0:2]
164
+ half_len = components_bbox[part][2]
165
+ if 'eye' in part:
166
+ half_len *= self.eye_enlarge_ratio
167
+ elif part == 'nose':
168
+ half_len *= self.nose_enlarge_ratio
169
+ elif part == 'mouth':
170
+ half_len *= self.mouth_enlarge_ratio
171
+ loc = np.hstack((mean - half_len + 1, mean + half_len))
172
+ loc = torch.from_numpy(loc).float()
173
+ locations_gt[part] = loc
174
+ loc_in = loc/(self.gt_size//self.in_size)
175
+ locations_in[part] = loc_in
176
+ return locations_gt, locations_in
177
+
178
+
179
+ def __getitem__(self, index):
180
+ if self.file_client is None:
181
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
182
+
183
+ # load gt image
184
+ gt_path = self.paths[index]
185
+ name = osp.basename(gt_path)[:-4]
186
+ img_bytes = self.file_client.get(gt_path)
187
+ img_gt = imfrombytes(img_bytes, float32=True)
188
+
189
+ # random horizontal flip
190
+ img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
191
+
192
+ if self.load_latent_gt:
193
+ if status[0]:
194
+ latent_gt = self.latent_gt_dict['hflip'][name]
195
+ else:
196
+ latent_gt = self.latent_gt_dict['orig'][name]
197
+
198
+ if self.crop_components:
199
+ locations_gt, locations_in = self.get_component_locations(name, status)
200
+
201
+ # generate in image
202
+ img_in = img_gt
203
+ if self.use_corrupt and not self.gen_inpaint_mask:
204
+ # motion blur
205
+ if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
206
+ m_i = random.randint(0,31)
207
+ k = self.motion_kernels[f'{m_i:02d}']
208
+ img_in = cv2.filter2D(img_in,-1,k)
209
+
210
+ # gaussian blur
211
+ kernel = gaussian_kernels.random_mixed_kernels(
212
+ self.kernel_list,
213
+ self.kernel_prob,
214
+ self.blur_kernel_size,
215
+ self.blur_sigma,
216
+ self.blur_sigma,
217
+ [-math.pi, math.pi],
218
+ noise_range=None)
219
+ img_in = cv2.filter2D(img_in, -1, kernel)
220
+
221
+ # downsample
222
+ scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
223
+ img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
224
+
225
+ # noise
226
+ if self.noise_range is not None:
227
+ noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.)
228
+ noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma
229
+ img_in = img_in + noise
230
+ img_in = np.clip(img_in, 0, 1)
231
+
232
+ # jpeg
233
+ if self.jpeg_range is not None:
234
+ jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1])
235
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_p]
236
+ _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param)
237
+ img_in = np.float32(cv2.imdecode(encimg, 1)) / 255.
238
+
239
+ # resize to in_size
240
+ img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
241
+
242
+ # if self.gen_inpaint_mask:
243
+ # inpaint_mask = random_ff_mask(shape=(self.gt_size,self.gt_size),
244
+ # max_angle = self.mask_max_angle, max_len = self.mask_max_len,
245
+ # max_width = self.mask_max_width, times = self.mask_draw_times)
246
+ # img_in = img_in * (1 - inpaint_mask.reshape(self.gt_size,self.gt_size,1)) + \
247
+ # 1.0 * inpaint_mask.reshape(self.gt_size,self.gt_size,1)
248
+
249
+ # inpaint_mask = torch.from_numpy(inpaint_mask).view(1,self.gt_size,self.gt_size)
250
+
251
+ if self.gen_inpaint_mask:
252
+ img_in = (img_in*255).astype('uint8')
253
+ img_in = brush_stroke_mask(Image.fromarray(img_in))
254
+ img_in = np.array(img_in) / 255.
255
+
256
+ # random color jitter (only for lq)
257
+ if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
258
+ img_in = self.color_jitter(img_in, self.color_jitter_shift)
259
+ # random to gray (only for lq)
260
+ if self.gray_prob and np.random.uniform() < self.gray_prob:
261
+ img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
262
+ img_in = np.tile(img_in[:, :, None], [1, 1, 3])
263
+
264
+ # BGR to RGB, HWC to CHW, numpy to tensor
265
+ img_in, img_gt = img2tensor([img_in, img_gt], bgr2rgb=True, float32=True)
266
+
267
+ # random color jitter (pytorch version) (only for lq)
268
+ if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
269
+ brightness = self.opt.get('brightness', (0.5, 1.5))
270
+ contrast = self.opt.get('contrast', (0.5, 1.5))
271
+ saturation = self.opt.get('saturation', (0, 1.5))
272
+ hue = self.opt.get('hue', (-0.1, 0.1))
273
+ img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)
274
+
275
+ # round and clip
276
+ img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255.
277
+
278
+ # Set vgg range_norm=True if use the normalization here
279
+ # normalize
280
+ normalize(img_in, self.mean, self.std, inplace=True)
281
+ normalize(img_gt, self.mean, self.std, inplace=True)
282
+
283
+ return_dict = {'in': img_in, 'gt': img_gt, 'gt_path': gt_path}
284
+
285
+ if self.crop_components:
286
+ return_dict['locations_in'] = locations_in
287
+ return_dict['locations_gt'] = locations_gt
288
+
289
+ if self.load_latent_gt:
290
+ return_dict['latent_gt'] = latent_gt
291
+
292
+ # if self.gen_inpaint_mask:
293
+ # return_dict['inpaint_mask'] = inpaint_mask
294
+
295
+ return return_dict
296
+
297
+
298
+ def __len__(self):
299
+ return len(self.paths)
basicsr/data/ffhq_blind_joint_dataset.py CHANGED
@@ -4,17 +4,14 @@ import random
4
  import numpy as np
5
  import os.path as osp
6
  from scipy.io import loadmat
7
- from PIL import Image, ImageDraw
8
  import torch
9
  import torch.utils.data as data
10
  from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
11
  adjust_hue, adjust_saturation, normalize)
12
  from basicsr.data import gaussian_kernels as gaussian_kernels
 
13
  from basicsr.data.data_util import paths_from_folder
14
- from basicsr.data.transforms import augment, img_rotate
15
- from basicsr.metrics.psnr_ssim import calculate_psnr
16
  from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
17
- from basicsr.utils.matlab_functions import imresize
18
  from basicsr.utils.registry import DATASET_REGISTRY
19
 
20
  @DATASET_REGISTRY.register()
 
4
  import numpy as np
5
  import os.path as osp
6
  from scipy.io import loadmat
 
7
  import torch
8
  import torch.utils.data as data
9
  from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
10
  adjust_hue, adjust_saturation, normalize)
11
  from basicsr.data import gaussian_kernels as gaussian_kernels
12
+ from basicsr.data.transforms import augment
13
  from basicsr.data.data_util import paths_from_folder
 
 
14
  from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
 
15
  from basicsr.utils.registry import DATASET_REGISTRY
16
 
17
  @DATASET_REGISTRY.register()
basicsr/models/codeformer_idx_model.py CHANGED
@@ -43,14 +43,18 @@ class CodeFormerIdxModel(SRModel):
43
  self.model_ema(0) # copy net_g weight
44
  self.net_g_ema.eval()
45
 
46
- if self.opt.get('network_vqgan', None) is not None and self.opt['datasets'].get('latent_gt_path') is None:
 
 
47
  self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
48
  self.hq_vqgan_fix.eval()
49
  self.generate_idx_gt = True
50
  for param in self.hq_vqgan_fix.parameters():
51
  param.requires_grad = False
52
  else:
53
- self.generate_idx_gt = False
 
 
54
 
55
  self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
56
  self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
 
43
  self.model_ema(0) # copy net_g weight
44
  self.net_g_ema.eval()
45
 
46
+ if self.opt['datasets']['train'].get('latent_gt_path', None) is not None:
47
+ self.generate_idx_gt = False
48
+ elif self.opt.get('network_vqgan', None) is not None:
49
  self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
50
  self.hq_vqgan_fix.eval()
51
  self.generate_idx_gt = True
52
  for param in self.hq_vqgan_fix.parameters():
53
  param.requires_grad = False
54
  else:
55
+ raise NotImplementedError(f'Shoule have network_vqgan config or pre-calculated latent code.')
56
+
57
+ logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}')
58
 
59
  self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
60
  self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
basicsr/models/codeformer_joint_model.py CHANGED
@@ -46,15 +46,19 @@ class CodeFormerJointModel(SRModel):
46
  self.model_ema(0) # copy net_g weight
47
  self.net_g_ema.eval()
48
 
49
- if self.opt.get('network_vqgan', None) is not None and self.opt['datasets'].get('latent_gt_path') is None:
 
 
50
  self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
51
  self.hq_vqgan_fix.eval()
52
  self.generate_idx_gt = True
53
  for param in self.hq_vqgan_fix.parameters():
54
  param.requires_grad = False
55
  else:
56
- self.generate_idx_gt = False
57
-
 
 
58
  self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
59
  self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
60
  self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
 
46
  self.model_ema(0) # copy net_g weight
47
  self.net_g_ema.eval()
48
 
49
+ if self.opt['datasets']['train'].get('latent_gt_path', None) is not None:
50
+ self.generate_idx_gt = False
51
+ elif self.opt.get('network_vqgan', None) is not None:
52
  self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
53
  self.hq_vqgan_fix.eval()
54
  self.generate_idx_gt = True
55
  for param in self.hq_vqgan_fix.parameters():
56
  param.requires_grad = False
57
  else:
58
+ raise NotImplementedError(f'Shoule have network_vqgan config or pre-calculated latent code.')
59
+
60
+ logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}')
61
+
62
  self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
63
  self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
64
  self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
options/CodeFormer_colorization.yml CHANGED
@@ -20,7 +20,6 @@ datasets:
20
  std: [0.5, 0.5, 0.5]
21
  use_hflip: true
22
  use_corrupt: true
23
- latent_gt_path: ~
24
 
25
  # large degradation in stageII
26
  blur_kernel_size: 41
@@ -39,7 +38,8 @@ datasets:
39
  color_jitter_pt_prob: 0.3
40
  gray_prob: 0.01
41
 
42
- latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth'
 
43
 
44
  # data loader
45
  num_worker_per_gpu: 2
@@ -69,6 +69,14 @@ network_g:
69
  fix_modules: ['quantize','generator']
70
  vqgan_path: './experiments/pretrained_models/vqgan/vqgan_code1024.pth' # pretrained VQGAN
71
 
 
 
 
 
 
 
 
 
72
  # path
73
  path:
74
  pretrain_network_g: ~
 
20
  std: [0.5, 0.5, 0.5]
21
  use_hflip: true
22
  use_corrupt: true
 
23
 
24
  # large degradation in stageII
25
  blur_kernel_size: 41
 
38
  color_jitter_pt_prob: 0.3
39
  gray_prob: 0.01
40
 
41
+ latent_gt_path: ~ # without pre-calculated latent code
42
+ # latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth'
43
 
44
  # data loader
45
  num_worker_per_gpu: 2
 
69
  fix_modules: ['quantize','generator']
70
  vqgan_path: './experiments/pretrained_models/vqgan/vqgan_code1024.pth' # pretrained VQGAN
71
 
72
+ network_vqgan: # this config is needed if no pre-calculated latent
73
+ type: VQAutoEncoder
74
+ img_size: 512
75
+ nf: 64
76
+ ch_mult: [1, 2, 2, 4, 4, 8]
77
+ quantizer: 'nearest'
78
+ codebook_size: 1024
79
+
80
  # path
81
  path:
82
  pretrain_network_g: ~
options/CodeFormer_inpainting.yml CHANGED
@@ -22,8 +22,8 @@ datasets:
22
  use_corrupt: false
23
  gen_inpaint_mask: true
24
 
25
-
26
- latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth'
27
 
28
  # data loader
29
  num_worker_per_gpu: 2
@@ -53,6 +53,14 @@ network_g:
53
  fix_modules: ['quantize','generator']
54
  vqgan_path: './experiments/pretrained_models/vqgan/vqgan_code1024.pth' # pretrained VQGAN
55
 
 
 
 
 
 
 
 
 
56
  network_d:
57
  type: VQGANDiscriminator
58
  nc: 3
 
22
  use_corrupt: false
23
  gen_inpaint_mask: true
24
 
25
+ latent_gt_path: ~ # without pre-calculated latent code
26
+ # latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth'
27
 
28
  # data loader
29
  num_worker_per_gpu: 2
 
53
  fix_modules: ['quantize','generator']
54
  vqgan_path: './experiments/pretrained_models/vqgan/vqgan_code1024.pth' # pretrained VQGAN
55
 
56
+ network_vqgan: # this config is needed if no pre-calculated latent
57
+ type: VQAutoEncoder
58
+ img_size: 512
59
+ nf: 64
60
+ ch_mult: [1, 2, 2, 4, 4, 8]
61
+ quantizer: 'nearest'
62
+ codebook_size: 1024
63
+
64
  network_d:
65
  type: VQGANDiscriminator
66
  nc: 3
options/CodeFormer_stage2.yml CHANGED
@@ -20,7 +20,6 @@ datasets:
20
  std: [0.5, 0.5, 0.5]
21
  use_hflip: true
22
  use_corrupt: true
23
- latent_gt_path: ~
24
 
25
  # large degradation in stageII
26
  blur_kernel_size: 41
@@ -33,7 +32,8 @@ datasets:
33
  noise_range: [0, 20]
34
  jpeg_range: [30, 80]
35
 
36
- latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth'
 
37
 
38
  # data loader
39
  num_worker_per_gpu: 2
@@ -63,6 +63,14 @@ network_g:
63
  fix_modules: ['quantize','generator']
64
  vqgan_path: './experiments/pretrained_models/vqgan/vqgan_code1024.pth' # pretrained VQGAN
65
 
 
 
 
 
 
 
 
 
66
  # path
67
  path:
68
  pretrain_network_g: ~
 
20
  std: [0.5, 0.5, 0.5]
21
  use_hflip: true
22
  use_corrupt: true
 
23
 
24
  # large degradation in stageII
25
  blur_kernel_size: 41
 
32
  noise_range: [0, 20]
33
  jpeg_range: [30, 80]
34
 
35
+ latent_gt_path: ~ # without pre-calculated latent code
36
+ # latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth'
37
 
38
  # data loader
39
  num_worker_per_gpu: 2
 
63
  fix_modules: ['quantize','generator']
64
  vqgan_path: './experiments/pretrained_models/vqgan/vqgan_code1024.pth' # pretrained VQGAN
65
 
66
+ network_vqgan: # this config is needed if no pre-calculated latent
67
+ type: VQAutoEncoder
68
+ img_size: 512
69
+ nf: 64
70
+ ch_mult: [1, 2, 2, 4, 4, 8]
71
+ quantizer: 'nearest'
72
+ codebook_size: 1024
73
+
74
  # path
75
  path:
76
  pretrain_network_g: ~
options/CodeFormer_stage3.yml CHANGED
@@ -20,7 +20,6 @@ datasets:
20
  std: [0.5, 0.5, 0.5]
21
  use_hflip: true
22
  use_corrupt: true
23
- latent_gt_path: ~
24
 
25
  blur_kernel_size: 41
26
  use_motion_kernel: false
@@ -38,8 +37,8 @@ datasets:
38
  noise_range_large: [0, 20]
39
  jpeg_range_large: [30, 80]
40
 
41
-
42
- latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth'
43
 
44
  # data loader
45
  num_worker_per_gpu: 1
@@ -68,6 +67,14 @@ network_g:
68
  connect_list: ['32', '64', '128', '256']
69
  fix_modules: ['quantize','generator']
70
 
 
 
 
 
 
 
 
 
71
  network_d:
72
  type: VQGANDiscriminator
73
  nc: 3
 
20
  std: [0.5, 0.5, 0.5]
21
  use_hflip: true
22
  use_corrupt: true
 
23
 
24
  blur_kernel_size: 41
25
  use_motion_kernel: false
 
37
  noise_range_large: [0, 20]
38
  jpeg_range_large: [30, 80]
39
 
40
+ latent_gt_path: ~ # without pre-calculated latent code
41
+ # latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth'
42
 
43
  # data loader
44
  num_worker_per_gpu: 1
 
67
  connect_list: ['32', '64', '128', '256']
68
  fix_modules: ['quantize','generator']
69
 
70
+ network_vqgan: # this config is needed if no pre-calculated latent
71
+ type: VQAutoEncoder
72
+ img_size: 512
73
+ nf: 64
74
+ ch_mult: [1, 2, 2, 4, 4, 8]
75
+ quantizer: 'nearest'
76
+ codebook_size: 1024
77
+
78
  network_d:
79
  type: VQGANDiscriminator
80
  nc: 3