udion commited on
Commit
aea9234
·
1 Parent(s): 974236a

fixed utils for cuda->cpu

Browse files
Files changed (1) hide show
  1. utils.py +1 -1251
utils.py CHANGED
@@ -51,1254 +51,4 @@ def ensure_checkpoint_exists(model_weights_filename):
51
  print(
52
  model_weights_filename,
53
  " not found, you may need to manually download the model weights."
54
- )
55
-
56
- ########### DeblurGAN function
57
- def get_norm_layer(norm_type='instance'):
58
- if norm_type == 'batch':
59
- norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
60
- elif norm_type == 'instance':
61
- norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
62
- else:
63
- raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
64
- return norm_layer
65
-
66
- def _array_to_batch(x):
67
- x = np.transpose(x, (2, 0, 1))
68
- x = np.expand_dims(x, 0)
69
- return torch.from_numpy(x)
70
-
71
- def get_normalize():
72
- normalize = albu.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
73
- normalize = albu.Compose([normalize], additional_targets={'target': 'image'})
74
-
75
- def process(a, b):
76
- r = normalize(image=a, target=b)
77
- return r['image'], r['target']
78
-
79
- return process
80
-
81
- def preprocess(x: np.ndarray, mask: Optional[np.ndarray]):
82
- x, _ = get_normalize()(x, x)
83
- if mask is None:
84
- mask = np.ones_like(x, dtype=np.float32)
85
- else:
86
- mask = np.round(mask.astype('float32') / 255)
87
-
88
- h, w, _ = x.shape
89
- block_size = 32
90
- min_height = (h // block_size + 1) * block_size
91
- min_width = (w // block_size + 1) * block_size
92
-
93
- pad_params = {'mode': 'constant',
94
- 'constant_values': 0,
95
- 'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0))
96
- }
97
- x = np.pad(x, **pad_params)
98
- mask = np.pad(mask, **pad_params)
99
-
100
- return map(_array_to_batch, (x, mask)), h, w
101
-
102
- def postprocess(x: torch.Tensor) -> np.ndarray:
103
- x, = x
104
- x = x.detach().cpu().float().numpy()
105
- x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
106
- return x.astype('uint8')
107
-
108
- def sorted_glob(pattern):
109
- return sorted(glob(pattern))
110
- ###########
111
-
112
- def normalize(image: np.ndarray) -> np.ndarray:
113
- """Normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
114
- Args:
115
- image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
116
- Returns:
117
- Normalized image data. Data range [0, 1].
118
- """
119
- return image.astype(np.float64) / 255.0
120
-
121
-
122
- def unnormalize(image: np.ndarray) -> np.ndarray:
123
- """Un-normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
124
- Args:
125
- image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
126
- Returns:
127
- Denormalized image data. Data range [0, 255].
128
- """
129
- return image.astype(np.float64) * 255.0
130
-
131
-
132
- def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
133
- """Convert ``PIL.Image`` to Tensor.
134
- Args:
135
- image (np.ndarray): The image data read by ``PIL.Image``
136
- range_norm (bool): Scale [0, 1] data to between [-1, 1]
137
- half (bool): Whether to convert torch.float32 similarly to torch.half type.
138
- Returns:
139
- Normalized image data
140
- Examples:
141
- >>> image = Image.open("image.bmp")
142
- >>> tensor_image = image2tensor(image, range_norm=False, half=False)
143
- """
144
- tensor = F.to_tensor(image)
145
-
146
- if range_norm:
147
- tensor = tensor.mul_(2.0).sub_(1.0)
148
- if half:
149
- tensor = tensor.half()
150
-
151
- return tensor
152
-
153
-
154
- def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
155
- """Converts ``torch.Tensor`` to ``PIL.Image``.
156
- Args:
157
- tensor (torch.Tensor): The image that needs to be converted to ``PIL.Image``
158
- range_norm (bool): Scale [-1, 1] data to between [0, 1]
159
- half (bool): Whether to convert torch.float32 similarly to torch.half type.
160
- Returns:
161
- Convert image data to support PIL library
162
- Examples:
163
- >>> tensor = torch.randn([1, 3, 128, 128])
164
- >>> image = tensor2image(tensor, range_norm=False, half=False)
165
- """
166
- if range_norm:
167
- tensor = tensor.add_(1.0).div_(2.0)
168
- if half:
169
- tensor = tensor.half()
170
-
171
- image = tensor.squeeze_(0).permute(1, 2, 0).mul_(255).clamp_(0, 255).cpu().numpy().astype("uint8")
172
-
173
- return image
174
-
175
-
176
- def convert_rgb_to_y(image: Any) -> Any:
177
- """Convert RGB image or tensor image data to YCbCr(Y) format.
178
- Args:
179
- image: RGB image data read by ``PIL.Image''.
180
- Returns:
181
- Y image array data.
182
- """
183
- if type(image) == np.ndarray:
184
- return 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
185
- elif type(image) == torch.Tensor:
186
- if len(image.shape) == 4:
187
- image = image.squeeze_(0)
188
- return 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
189
- else:
190
- raise Exception("Unknown Type", type(image))
191
-
192
-
193
- def convert_rgb_to_ycbcr(image: Any) -> Any:
194
- """Convert RGB image or tensor image data to YCbCr format.
195
- Args:
196
- image: RGB image data read by ``PIL.Image''.
197
- Returns:
198
- YCbCr image array data.
199
- """
200
- if type(image) == np.ndarray:
201
- y = 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
202
- cb = 128. + (-37.945 * image[:, :, 0] - 74.494 * image[:, :, 1] + 112.439 * image[:, :, 2]) / 256.
203
- cr = 128. + (112.439 * image[:, :, 0] - 94.154 * image[:, :, 1] - 18.285 * image[:, :, 2]) / 256.
204
- return np.array([y, cb, cr]).transpose([1, 2, 0])
205
- elif type(image) == torch.Tensor:
206
- if len(image.shape) == 4:
207
- image = image.squeeze(0)
208
- y = 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
209
- cb = 128. + (-37.945 * image[0, :, :] - 74.494 * image[1, :, :] + 112.439 * image[2, :, :]) / 256.
210
- cr = 128. + (112.439 * image[0, :, :] - 94.154 * image[1, :, :] - 18.285 * image[2, :, :]) / 256.
211
- return torch.cat([y, cb, cr], 0).permute(1, 2, 0)
212
- else:
213
- raise Exception("Unknown Type", type(image))
214
-
215
-
216
- def convert_ycbcr_to_rgb(image: Any) -> Any:
217
- """Convert YCbCr format image to RGB format.
218
- Args:
219
- image: YCbCr image data read by ``PIL.Image''.
220
- Returns:
221
- RGB image array data.
222
- """
223
- if type(image) == np.ndarray:
224
- r = 298.082 * image[:, :, 0] / 256. + 408.583 * image[:, :, 2] / 256. - 222.921
225
- g = 298.082 * image[:, :, 0] / 256. - 100.291 * image[:, :, 1] / 256. - 208.120 * image[:, :, 2] / 256. + 135.576
226
- b = 298.082 * image[:, :, 0] / 256. + 516.412 * image[:, :, 1] / 256. - 276.836
227
- return np.array([r, g, b]).transpose([1, 2, 0])
228
- elif type(image) == torch.Tensor:
229
- if len(image.shape) == 4:
230
- image = image.squeeze(0)
231
- r = 298.082 * image[0, :, :] / 256. + 408.583 * image[2, :, :] / 256. - 222.921
232
- g = 298.082 * image[0, :, :] / 256. - 100.291 * image[1, :, :] / 256. - 208.120 * image[2, :, :] / 256. + 135.576
233
- b = 298.082 * image[0, :, :] / 256. + 516.412 * image[1, :, :] / 256. - 276.836
234
- return torch.cat([r, g, b], 0).permute(1, 2, 0)
235
- else:
236
- raise Exception("Unknown Type", type(image))
237
-
238
-
239
- def center_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
240
- """Cut ``PIL.Image`` in the center area of the image.
241
- Args:
242
- lr: Low-resolution image data read by ``PIL.Image``.
243
- hr: High-resolution image data read by ``PIL.Image``.
244
- image_size (int): The size of the captured image area. It should be the size of the high-resolution image.
245
- upscale_factor (int): magnification factor.
246
- Returns:
247
- Randomly cropped low-resolution images and high-resolution images.
248
- """
249
- w, h = hr.size
250
-
251
- left = (w - image_size) // 2
252
- top = (h - image_size) // 2
253
- right = left + image_size
254
- bottom = top + image_size
255
-
256
- lr = lr.crop((left // upscale_factor,
257
- top // upscale_factor,
258
- right // upscale_factor,
259
- bottom // upscale_factor))
260
- hr = hr.crop((left, top, right, bottom))
261
-
262
- return lr, hr
263
-
264
-
265
- def random_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
266
- """Will ``PIL.Image`` randomly capture the specified area of the image.
267
- Args:
268
- lr: Low-resolution image data read by ``PIL.Image``.
269
- hr: High-resolution image data read by ``PIL.Image``.
270
- image_size (int): The size of the captured image area. It should be the size of the high-resolution image.
271
- upscale_factor (int): magnification factor.
272
- Returns:
273
- Randomly cropped low-resolution images and high-resolution images.
274
- """
275
- w, h = hr.size
276
- left = torch.randint(0, w - image_size + 1, size=(1,)).item()
277
- top = torch.randint(0, h - image_size + 1, size=(1,)).item()
278
- right = left + image_size
279
- bottom = top + image_size
280
-
281
- lr = lr.crop((left // upscale_factor,
282
- top // upscale_factor,
283
- right // upscale_factor,
284
- bottom // upscale_factor))
285
- hr = hr.crop((left, top, right, bottom))
286
-
287
- return lr, hr
288
-
289
-
290
- def random_rotate(lr: Any, hr: Any, angle: int) -> [Any, Any]:
291
- """Will ``PIL.Image`` randomly rotate the image.
292
- Args:
293
- lr: Low-resolution image data read by ``PIL.Image``.
294
- hr: High-resolution image data read by ``PIL.Image``.
295
- angle (int): rotation angle, clockwise and counterclockwise rotation.
296
- Returns:
297
- Randomly rotated low-resolution images and high-resolution images.
298
- """
299
- angle = random.choice((+angle, -angle))
300
- lr = F.rotate(lr, angle)
301
- hr = F.rotate(hr, angle)
302
-
303
- return lr, hr
304
-
305
-
306
- def random_horizontally_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
307
- """Flip the ``PIL.Image`` image horizontally randomly.
308
- Args:
309
- lr: Low-resolution image data read by ``PIL.Image``.
310
- hr: High-resolution image data read by ``PIL.Image``.
311
- p (optional, float): rollover probability. (Default: 0.5)
312
- Returns:
313
- Low-resolution image and high-resolution image after random horizontal flip.
314
- """
315
- if torch.rand(1).item() > p:
316
- lr = F.hflip(lr)
317
- hr = F.hflip(hr)
318
-
319
- return lr, hr
320
-
321
-
322
- def random_vertically_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
323
- """Turn the ``PIL.Image`` image upside down randomly.
324
- Args:
325
- lr: Low-resolution image data read by ``PIL.Image``.
326
- hr: High-resolution image data read by ``PIL.Image``.
327
- p (optional, float): rollover probability. (Default: 0.5)
328
- Returns:
329
- Randomly rotated up and down low-resolution images and high-resolution images.
330
- """
331
- if torch.rand(1).item() > p:
332
- lr = F.vflip(lr)
333
- hr = F.vflip(hr)
334
-
335
- return lr, hr
336
-
337
-
338
- def random_adjust_brightness(lr: Any, hr: Any) -> [Any, Any]:
339
- """Set ``PIL.Image`` to randomly adjust the image brightness.
340
- Args:
341
- lr: Low-resolution image data read by ``PIL.Image``.
342
- hr: High-resolution image data read by ``PIL.Image``.
343
- Returns:
344
- Low-resolution image and high-resolution image with randomly adjusted brightness.
345
- """
346
- # Randomly adjust the brightness gain range.
347
- factor = random.uniform(0.5, 2)
348
- lr = F.adjust_brightness(lr, factor)
349
- hr = F.adjust_brightness(hr, factor)
350
-
351
- return lr, hr
352
-
353
-
354
- def random_adjust_contrast(lr: Any, hr: Any) -> [Any, Any]:
355
- """Set ``PIL.Image`` to randomly adjust the image contrast.
356
- Args:
357
- lr: Low-resolution image data read by ``PIL.Image``.
358
- hr: High-resolution image data read by ``PIL.Image``.
359
- Returns:
360
- Low-resolution image and high-resolution image with randomly adjusted contrast.
361
- """
362
- # Randomly adjust the contrast gain range.
363
- factor = random.uniform(0.5, 2)
364
- lr = F.adjust_contrast(lr, factor)
365
- hr = F.adjust_contrast(hr, factor)
366
-
367
- return lr, hr
368
-
369
- #### metrics to compute -- assumes single images, i.e., tensor of 3 dims
370
- def img_mae(x1, x2):
371
- m = torch.abs(x1-x2).mean()
372
- return m
373
-
374
- def img_mse(x1, x2):
375
- m = torch.pow(torch.abs(x1-x2),2).mean()
376
- return m
377
-
378
- def img_psnr(x1, x2):
379
- m = kornia.metrics.psnr(x1, x2, 1)
380
- return m
381
-
382
- def img_ssim(x1, x2):
383
- m = kornia.metrics.ssim(x1.unsqueeze(0), x2.unsqueeze(0), 5)
384
- m = m.mean()
385
- return m
386
-
387
- def show_SR_w_uncer(xLR, xHR, xSR, xSRvar, elim=(0,0.01), ulim=(0,0.15)):
388
- '''
389
- xLR/SR/HR: 3xHxW
390
- xSRvar: 1xHxW
391
- '''
392
- plt.figure(figsize=(30,10))
393
-
394
- plt.subplot(1,5,1)
395
- plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
396
- plt.axis('off')
397
-
398
- plt.subplot(1,5,2)
399
- plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
400
- plt.axis('off')
401
-
402
- plt.subplot(1,5,3)
403
- plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
404
- plt.axis('off')
405
-
406
- plt.subplot(1,5,4)
407
- error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)
408
- print('error', error_map.min(), error_map.max())
409
- plt.imshow(error_map.transpose(0,2).transpose(0,1), cmap='jet')
410
- plt.clim(elim[0], elim[1])
411
- plt.axis('off')
412
-
413
- plt.subplot(1,5,5)
414
- print('uncer', xSRvar.min(), xSRvar.max())
415
- plt.imshow(xSRvar.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
416
- plt.clim(ulim[0], ulim[1])
417
- plt.axis('off')
418
-
419
- plt.subplots_adjust(wspace=0, hspace=0)
420
- plt.show()
421
-
422
- def show_SR_w_err(xLR, xHR, xSR, elim=(0,0.01), task=None, xMask=None):
423
- '''
424
- xLR/SR/HR: 3xHxW
425
- '''
426
- plt.figure(figsize=(30,10))
427
-
428
- if task != 'm':
429
- plt.subplot(1,4,1)
430
- plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
431
- plt.axis('off')
432
-
433
- plt.subplot(1,4,2)
434
- plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
435
- plt.axis('off')
436
-
437
- plt.subplot(1,4,3)
438
- plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
439
- plt.axis('off')
440
- else:
441
- plt.subplot(1,4,1)
442
- plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
443
- plt.clim(0,0.9)
444
- plt.axis('off')
445
-
446
- plt.subplot(1,4,2)
447
- plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
448
- plt.clim(0,0.9)
449
- plt.axis('off')
450
-
451
- plt.subplot(1,4,3)
452
- plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
453
- plt.clim(0,0.9)
454
- plt.axis('off')
455
-
456
- plt.subplot(1,4,4)
457
- if task == 'inpainting':
458
- error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)*xMask.to('cpu').data
459
- else:
460
- error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)
461
- print('error', error_map.min(), error_map.max())
462
- plt.imshow(error_map.transpose(0,2).transpose(0,1), cmap='jet')
463
- plt.clim(elim[0], elim[1])
464
- plt.axis('off')
465
-
466
- plt.subplots_adjust(wspace=0, hspace=0)
467
- plt.show()
468
-
469
- def show_uncer4(xSRvar1, xSRvar2, xSRvar3, xSRvar4, ulim=(0,0.15)):
470
- '''
471
- xSRvar: 1xHxW
472
- '''
473
- plt.figure(figsize=(30,10))
474
-
475
- plt.subplot(1,4,1)
476
- print('uncer', xSRvar1.min(), xSRvar1.max())
477
- plt.imshow(xSRvar1.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
478
- plt.clim(ulim[0], ulim[1])
479
- plt.axis('off')
480
-
481
- plt.subplot(1,4,2)
482
- print('uncer', xSRvar2.min(), xSRvar2.max())
483
- plt.imshow(xSRvar2.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
484
- plt.clim(ulim[0], ulim[1])
485
- plt.axis('off')
486
-
487
- plt.subplot(1,4,3)
488
- print('uncer', xSRvar3.min(), xSRvar3.max())
489
- plt.imshow(xSRvar3.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
490
- plt.clim(ulim[0], ulim[1])
491
- plt.axis('off')
492
-
493
- plt.subplot(1,4,4)
494
- print('uncer', xSRvar4.min(), xSRvar4.max())
495
- plt.imshow(xSRvar4.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
496
- plt.clim(ulim[0], ulim[1])
497
- plt.axis('off')
498
-
499
- plt.subplots_adjust(wspace=0, hspace=0)
500
- plt.show()
501
-
502
- def get_UCE(list_err, list_yout_var, num_bins=100):
503
- err_min = np.min(list_err)
504
- err_max = np.max(list_err)
505
- err_len = (err_max-err_min)/num_bins
506
- num_points = len(list_err)
507
-
508
- bin_stats = {}
509
- for i in range(num_bins):
510
- bin_stats[i] = {
511
- 'start_idx': err_min + i*err_len,
512
- 'end_idx': err_min + (i+1)*err_len,
513
- 'num_points': 0,
514
- 'mean_err': 0,
515
- 'mean_var': 0,
516
- }
517
-
518
- for e,v in zip(list_err, list_yout_var):
519
- for i in range(num_bins):
520
- if e>=bin_stats[i]['start_idx'] and e<bin_stats[i]['end_idx']:
521
- bin_stats[i]['num_points'] += 1
522
- bin_stats[i]['mean_err'] += e
523
- bin_stats[i]['mean_var'] += v
524
-
525
- uce = 0
526
- eps = 1e-8
527
- for i in range(num_bins):
528
- bin_stats[i]['mean_err'] /= bin_stats[i]['num_points'] + eps
529
- bin_stats[i]['mean_var'] /= bin_stats[i]['num_points'] + eps
530
- bin_stats[i]['uce_bin'] = (bin_stats[i]['num_points']/num_points) \
531
- *(np.abs(bin_stats[i]['mean_err'] - bin_stats[i]['mean_var']))
532
- uce += bin_stats[i]['uce_bin']
533
-
534
- list_x, list_y = [], []
535
- for i in range(num_bins):
536
- if bin_stats[i]['num_points']>0:
537
- list_x.append(bin_stats[i]['mean_err'])
538
- list_y.append(bin_stats[i]['mean_var'])
539
-
540
- # sns.set_style('darkgrid')
541
- # sns.scatterplot(x=list_x, y=list_y)
542
- # sns.regplot(x=list_x, y=list_y, order=1)
543
- # plt.xlabel('MSE', fontsize=34)
544
- # plt.ylabel('Uncertainty', fontsize=34)
545
- # plt.plot(list_x, list_x, color='r')
546
- # plt.xlim(np.min(list_x), np.max(list_x))
547
- # plt.ylim(np.min(list_err), np.max(list_x))
548
- # plt.show()
549
-
550
- return bin_stats, uce
551
-
552
- ##################### training BayesCap
553
- def train_BayesCap(
554
- NetC,
555
- NetG,
556
- train_loader,
557
- eval_loader,
558
- Cri = TempCombLoss(),
559
- device='cuda',
560
- dtype=torch.cuda.FloatTensor(),
561
- init_lr=1e-4,
562
- num_epochs=100,
563
- eval_every=1,
564
- ckpt_path='../ckpt/BayesCap',
565
- T1=1e0,
566
- T2=5e-2,
567
- task=None,
568
- ):
569
- NetC.to(device)
570
- NetC.train()
571
- NetG.to(device)
572
- NetG.eval()
573
- optimizer = torch.optim.Adam(list(NetC.parameters()), lr=init_lr)
574
- optim_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)
575
-
576
- score = -1e8
577
- all_loss = []
578
- for eph in range(num_epochs):
579
- eph_loss = 0
580
- with tqdm(train_loader, unit='batch') as tepoch:
581
- for (idx, batch) in enumerate(tepoch):
582
- if idx>2000:
583
- break
584
- tepoch.set_description('Epoch {}'.format(eph))
585
- ##
586
- xLR, xHR = batch[0].to(device), batch[1].to(device)
587
- xLR, xHR = xLR.type(dtype), xHR.type(dtype)
588
- if task == 'inpainting':
589
- xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
590
- xMask = xMask.to(device).type(dtype)
591
- # pass them through the network
592
- with torch.no_grad():
593
- if task == 'inpainting':
594
- _, xSR1 = NetG(xLR, xMask)
595
- elif task == 'depth':
596
- xSR1 = NetG(xLR)[("disp", 0)]
597
- else:
598
- xSR1 = NetG(xLR)
599
- # with torch.autograd.set_detect_anomaly(True):
600
- xSR = xSR1.clone()
601
- xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
602
- # print(xSRC_alpha)
603
- optimizer.zero_grad()
604
- if task == 'depth':
605
- loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, xSR, T1=T1, T2=T2)
606
- else:
607
- loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, xHR, T1=T1, T2=T2)
608
- # print(loss)
609
- loss.backward()
610
- optimizer.step()
611
- ##
612
- eph_loss += loss.item()
613
- tepoch.set_postfix(loss=loss.item())
614
- eph_loss /= len(train_loader)
615
- all_loss.append(eph_loss)
616
- print('Avg. loss: {}'.format(eph_loss))
617
- # evaluate and save the models
618
- torch.save(NetC.state_dict(), ckpt_path+'_last.pth')
619
- if eph%eval_every == 0:
620
- curr_score = eval_BayesCap(
621
- NetC,
622
- NetG,
623
- eval_loader,
624
- device=device,
625
- dtype=dtype,
626
- task=task,
627
- )
628
- print('current score: {} | Last best score: {}'.format(curr_score, score))
629
- if curr_score >= score:
630
- score = curr_score
631
- torch.save(NetC.state_dict(), ckpt_path+'_best.pth')
632
- optim_scheduler.step()
633
-
634
- #### get different uncertainty maps
635
- def get_uncer_BayesCap(
636
- NetC,
637
- NetG,
638
- xin,
639
- task=None,
640
- xMask=None,
641
- ):
642
- with torch.no_grad():
643
- if task == 'inpainting':
644
- _, xSR = NetG(xin, xMask)
645
- else:
646
- xSR = NetG(xin)
647
- xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
648
- a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
649
- b_map = xSRC_beta.to('cpu').data
650
- xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
651
-
652
- return xSRvar
653
-
654
- def get_uncer_TTDAp(
655
- NetG,
656
- xin,
657
- p_mag=0.05,
658
- num_runs=50,
659
- task=None,
660
- xMask=None,
661
- ):
662
- list_xSR = []
663
- with torch.no_grad():
664
- for z in range(num_runs):
665
- if task == 'inpainting':
666
- _, xSRz = NetG(xin+p_mag*xin.max()*torch.randn_like(xin), xMask)
667
- else:
668
- xSRz = NetG(xin+p_mag*xin.max()*torch.randn_like(xin))
669
- list_xSR.append(xSRz)
670
- xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
671
- xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
672
- return xSRvar
673
-
674
- def get_uncer_DO(
675
- NetG,
676
- xin,
677
- dop=0.2,
678
- num_runs=50,
679
- task=None,
680
- xMask=None,
681
- ):
682
- list_xSR = []
683
- with torch.no_grad():
684
- for z in range(num_runs):
685
- if task == 'inpainting':
686
- _, xSRz = NetG(xin, xMask, dop=dop)
687
- else:
688
- xSRz = NetG(xin, dop=dop)
689
- list_xSR.append(xSRz)
690
- xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
691
- xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
692
- return xSRvar
693
-
694
- ################### Different eval functions
695
-
696
- def eval_BayesCap(
697
- NetC,
698
- NetG,
699
- eval_loader,
700
- device='cuda',
701
- dtype=torch.cuda.FloatTensor,
702
- task=None,
703
- xMask=None,
704
- ):
705
- NetC.to(device)
706
- NetC.eval()
707
- NetG.to(device)
708
- NetG.eval()
709
-
710
- mean_ssim = 0
711
- mean_psnr = 0
712
- mean_mse = 0
713
- mean_mae = 0
714
- num_imgs = 0
715
- list_error = []
716
- list_var = []
717
- with tqdm(eval_loader, unit='batch') as tepoch:
718
- for (idx, batch) in enumerate(tepoch):
719
- tepoch.set_description('Validating ...')
720
- ##
721
- xLR, xHR = batch[0].to(device), batch[1].to(device)
722
- xLR, xHR = xLR.type(dtype), xHR.type(dtype)
723
- if task == 'inpainting':
724
- if xMask==None:
725
- xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
726
- xMask = xMask.to(device).type(dtype)
727
- else:
728
- xMask = xMask.to(device).type(dtype)
729
- # pass them through the network
730
- with torch.no_grad():
731
- if task == 'inpainting':
732
- _, xSR = NetG(xLR, xMask)
733
- elif task == 'depth':
734
- xSR = NetG(xLR)[("disp", 0)]
735
- else:
736
- xSR = NetG(xLR)
737
- xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
738
- a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
739
- b_map = xSRC_beta.to('cpu').data
740
- xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
741
- n_batch = xSRC_mu.shape[0]
742
- if task == 'depth':
743
- xHR = xSR
744
- for j in range(n_batch):
745
- num_imgs += 1
746
- mean_ssim += img_ssim(xSRC_mu[j], xHR[j])
747
- mean_psnr += img_psnr(xSRC_mu[j], xHR[j])
748
- mean_mse += img_mse(xSRC_mu[j], xHR[j])
749
- mean_mae += img_mae(xSRC_mu[j], xHR[j])
750
-
751
- show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
752
-
753
- error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
754
- var_map = xSRvar[j].to('cpu').data.reshape(-1)
755
- list_error.extend(list(error_map.numpy()))
756
- list_var.extend(list(var_map.numpy()))
757
- ##
758
- mean_ssim /= num_imgs
759
- mean_psnr /= num_imgs
760
- mean_mse /= num_imgs
761
- mean_mae /= num_imgs
762
- print(
763
- 'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
764
- (
765
- mean_ssim, mean_psnr, mean_mse, mean_mae
766
- )
767
- )
768
- # print(len(list_error), len(list_var))
769
- # print('UCE: ', get_UCE(list_error[::10], list_var[::10], num_bins=500)[1])
770
- # print('C.Coeff: ', np.corrcoef(np.array(list_error[::10]), np.array(list_var[::10])))
771
- return mean_ssim
772
-
773
- def eval_TTDA_p(
774
- NetG,
775
- eval_loader,
776
- device='cuda',
777
- dtype=torch.cuda.FloatTensor,
778
- p_mag=0.05,
779
- num_runs=50,
780
- task = None,
781
- xMask = None,
782
- ):
783
- NetG.to(device)
784
- NetG.eval()
785
-
786
- mean_ssim = 0
787
- mean_psnr = 0
788
- mean_mse = 0
789
- mean_mae = 0
790
- num_imgs = 0
791
- with tqdm(eval_loader, unit='batch') as tepoch:
792
- for (idx, batch) in enumerate(tepoch):
793
- tepoch.set_description('Validating ...')
794
- ##
795
- xLR, xHR = batch[0].to(device), batch[1].to(device)
796
- xLR, xHR = xLR.type(dtype), xHR.type(dtype)
797
- # pass them through the network
798
- list_xSR = []
799
- with torch.no_grad():
800
- if task=='inpainting':
801
- _, xSR = NetG(xLR, xMask)
802
- else:
803
- xSR = NetG(xLR)
804
- for z in range(num_runs):
805
- xSRz = NetG(xLR+p_mag*xLR.max()*torch.randn_like(xLR))
806
- list_xSR.append(xSRz)
807
- xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
808
- xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
809
- n_batch = xSR.shape[0]
810
- for j in range(n_batch):
811
- num_imgs += 1
812
- mean_ssim += img_ssim(xSR[j], xHR[j])
813
- mean_psnr += img_psnr(xSR[j], xHR[j])
814
- mean_mse += img_mse(xSR[j], xHR[j])
815
- mean_mae += img_mae(xSR[j], xHR[j])
816
-
817
- show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
818
-
819
- mean_ssim /= num_imgs
820
- mean_psnr /= num_imgs
821
- mean_mse /= num_imgs
822
- mean_mae /= num_imgs
823
- print(
824
- 'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
825
- (
826
- mean_ssim, mean_psnr, mean_mse, mean_mae
827
- )
828
- )
829
-
830
- return mean_ssim
831
-
832
- def eval_DO(
833
- NetG,
834
- eval_loader,
835
- device='cuda',
836
- dtype=torch.cuda.FloatTensor,
837
- dop=0.2,
838
- num_runs=50,
839
- task=None,
840
- xMask=None,
841
- ):
842
- NetG.to(device)
843
- NetG.eval()
844
-
845
- mean_ssim = 0
846
- mean_psnr = 0
847
- mean_mse = 0
848
- mean_mae = 0
849
- num_imgs = 0
850
- with tqdm(eval_loader, unit='batch') as tepoch:
851
- for (idx, batch) in enumerate(tepoch):
852
- tepoch.set_description('Validating ...')
853
- ##
854
- xLR, xHR = batch[0].to(device), batch[1].to(device)
855
- xLR, xHR = xLR.type(dtype), xHR.type(dtype)
856
- # pass them through the network
857
- list_xSR = []
858
- with torch.no_grad():
859
- if task == 'inpainting':
860
- _, xSR = NetG(xLR, xMask)
861
- else:
862
- xSR = NetG(xLR)
863
- for z in range(num_runs):
864
- xSRz = NetG(xLR, dop=dop)
865
- list_xSR.append(xSRz)
866
- xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
867
- xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
868
- n_batch = xSR.shape[0]
869
- for j in range(n_batch):
870
- num_imgs += 1
871
- mean_ssim += img_ssim(xSR[j], xHR[j])
872
- mean_psnr += img_psnr(xSR[j], xHR[j])
873
- mean_mse += img_mse(xSR[j], xHR[j])
874
- mean_mae += img_mae(xSR[j], xHR[j])
875
-
876
- show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
877
- ##
878
- mean_ssim /= num_imgs
879
- mean_psnr /= num_imgs
880
- mean_mse /= num_imgs
881
- mean_mae /= num_imgs
882
- print(
883
- 'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
884
- (
885
- mean_ssim, mean_psnr, mean_mse, mean_mae
886
- )
887
- )
888
-
889
- return mean_ssim
890
-
891
-
892
- ############### compare all function
893
- def compare_all(
894
- NetC,
895
- NetG,
896
- eval_loader,
897
- p_mag = 0.05,
898
- dop = 0.2,
899
- num_runs = 100,
900
- device='cuda',
901
- dtype=torch.cuda.FloatTensor,
902
- task=None,
903
- ):
904
- NetC.to(device)
905
- NetC.eval()
906
- NetG.to(device)
907
- NetG.eval()
908
-
909
- with tqdm(eval_loader, unit='batch') as tepoch:
910
- for (idx, batch) in enumerate(tepoch):
911
- tepoch.set_description('Comparing ...')
912
- ##
913
- xLR, xHR = batch[0].to(device), batch[1].to(device)
914
- xLR, xHR = xLR.type(dtype), xHR.type(dtype)
915
- if task == 'inpainting':
916
- xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
917
- xMask = xMask.to(device).type(dtype)
918
- # pass them through the network
919
- with torch.no_grad():
920
- if task == 'inpainting':
921
- _, xSR = NetG(xLR, xMask)
922
- else:
923
- xSR = NetG(xLR)
924
- xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
925
-
926
- if task == 'inpainting':
927
- xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs, task='inpainting', xMask=xMask)
928
- xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs, task='inpainting', xMask=xMask)
929
- xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR, task='inpainting', xMask=xMask)
930
- else:
931
- xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs)
932
- xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs)
933
- xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR)
934
-
935
- print('bdg', xSRvar1.shape, xSRvar2.shape, xSRvar3.shape)
936
-
937
- n_batch = xSR.shape[0]
938
- for j in range(n_batch):
939
- if task=='s':
940
- show_SR_w_err(xLR[j], xHR[j], xSR[j])
941
- show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
942
- show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j])
943
- if task=='d':
944
- show_SR_w_err(xLR[j], xHR[j], 0.5*xSR[j]+0.5*xHR[j])
945
- show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
946
- show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
947
- if task=='inpainting':
948
- show_SR_w_err(xLR[j]*(1-xMask[j]), xHR[j], xSR[j], elim=(0,0.25), task='inpainting', xMask=xMask[j])
949
- show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.45), torch.pow(xSRvar1[j], 0.4))
950
- show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
951
- if task=='m':
952
- show_SR_w_err(xLR[j], xHR[j], xSR[j], elim=(0,0.04), task='m')
953
- show_uncer4(0.4*xSRvar1[j]+0.6*xSRvar2[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42), ulim=(0.02,0.15))
954
- show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j], ulim=(0.02,0.15))
955
-
956
-
957
- ################# Degrading Identity
958
- def degrage_BayesCap_p(
959
- NetC,
960
- NetG,
961
- eval_loader,
962
- device='cuda',
963
- dtype=torch.cuda.FloatTensor,
964
- num_runs=50,
965
- ):
966
- NetC.to(device)
967
- NetC.eval()
968
- NetG.to(device)
969
- NetG.eval()
970
-
971
- p_mag_list = [0, 0.05, 0.1, 0.15, 0.2]
972
- list_s = []
973
- list_p = []
974
- list_u1 = []
975
- list_u2 = []
976
- list_c = []
977
- for p_mag in p_mag_list:
978
- mean_ssim = 0
979
- mean_psnr = 0
980
- mean_mse = 0
981
- mean_mae = 0
982
- num_imgs = 0
983
- list_error = []
984
- list_error2 = []
985
- list_var = []
986
-
987
- with tqdm(eval_loader, unit='batch') as tepoch:
988
- for (idx, batch) in enumerate(tepoch):
989
- tepoch.set_description('Validating ...')
990
- ##
991
- xLR, xHR = batch[0].to(device), batch[1].to(device)
992
- xLR, xHR = xLR.type(dtype), xHR.type(dtype)
993
- # pass them through the network
994
- with torch.no_grad():
995
- xSR = NetG(xLR)
996
- xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR + p_mag*xSR.max()*torch.randn_like(xSR))
997
- a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
998
- b_map = xSRC_beta.to('cpu').data
999
- xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
1000
- n_batch = xSRC_mu.shape[0]
1001
- for j in range(n_batch):
1002
- num_imgs += 1
1003
- mean_ssim += img_ssim(xSRC_mu[j], xSR[j])
1004
- mean_psnr += img_psnr(xSRC_mu[j], xSR[j])
1005
- mean_mse += img_mse(xSRC_mu[j], xSR[j])
1006
- mean_mae += img_mae(xSRC_mu[j], xSR[j])
1007
-
1008
- error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
1009
- error_map2 = torch.mean(torch.pow(torch.abs(xSRC_mu[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
1010
- var_map = xSRvar[j].to('cpu').data.reshape(-1)
1011
- list_error.extend(list(error_map.numpy()))
1012
- list_error2.extend(list(error_map2.numpy()))
1013
- list_var.extend(list(var_map.numpy()))
1014
- ##
1015
- mean_ssim /= num_imgs
1016
- mean_psnr /= num_imgs
1017
- mean_mse /= num_imgs
1018
- mean_mae /= num_imgs
1019
- print(
1020
- 'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
1021
- (
1022
- mean_ssim, mean_psnr, mean_mse, mean_mae
1023
- )
1024
- )
1025
- uce1 = get_UCE(list_error[::100], list_var[::100], num_bins=200)[1]
1026
- uce2 = get_UCE(list_error2[::100], list_var[::100], num_bins=200)[1]
1027
- print('UCE1: ', uce1)
1028
- print('UCE2: ', uce2)
1029
- list_s.append(mean_ssim.item())
1030
- list_p.append(mean_psnr.item())
1031
- list_u1.append(uce1)
1032
- list_u2.append(uce2)
1033
-
1034
- plt.plot(list_s)
1035
- plt.show()
1036
- plt.plot(list_p)
1037
- plt.show()
1038
-
1039
- plt.plot(list_u1, label='wrt SR output')
1040
- plt.plot(list_u2, label='wrt BayesCap output')
1041
- plt.legend()
1042
- plt.show()
1043
-
1044
- sns.set_style('darkgrid')
1045
- fig,ax = plt.subplots()
1046
- # make a plot
1047
- ax.plot(p_mag_list, list_s, color="red", marker="o")
1048
- # set x-axis label
1049
- ax.set_xlabel("Reducing faithfulness of BayesCap Reconstruction",fontsize=10)
1050
- # set y-axis label
1051
- ax.set_ylabel("SSIM btwn BayesCap and SRGAN outputs", color="red",fontsize=10)
1052
-
1053
- # twin object for two different y-axis on the sample plot
1054
- ax2=ax.twinx()
1055
- # make a plot with different y-axis using second axis object
1056
- ax2.plot(p_mag_list, list_u1, color="blue", marker="o", label='UCE wrt to error btwn SRGAN output and GT')
1057
- ax2.plot(p_mag_list, list_u2, color="orange", marker="o", label='UCE wrt to error btwn BayesCap output and GT')
1058
- ax2.set_ylabel("UCE", color="green", fontsize=10)
1059
- plt.legend(fontsize=10)
1060
- plt.tight_layout()
1061
- plt.show()
1062
-
1063
- ################# DeepFill_v2
1064
-
1065
- # ----------------------------------------
1066
- # PATH processing
1067
- # ----------------------------------------
1068
- def text_readlines(filename):
1069
- # Try to read a txt file and return a list.Return [] if there was a mistake.
1070
- try:
1071
- file = open(filename, 'r')
1072
- except IOError:
1073
- error = []
1074
- return error
1075
- content = file.readlines()
1076
- # This for loop deletes the EOF (like \n)
1077
- for i in range(len(content)):
1078
- content[i] = content[i][:len(content[i])-1]
1079
- file.close()
1080
- return content
1081
-
1082
- def savetxt(name, loss_log):
1083
- np_loss_log = np.array(loss_log)
1084
- np.savetxt(name, np_loss_log)
1085
-
1086
- def get_files(path):
1087
- # read a folder, return the complete path
1088
- ret = []
1089
- for root, dirs, files in os.walk(path):
1090
- for filespath in files:
1091
- ret.append(os.path.join(root, filespath))
1092
- return ret
1093
-
1094
- def get_names(path):
1095
- # read a folder, return the image name
1096
- ret = []
1097
- for root, dirs, files in os.walk(path):
1098
- for filespath in files:
1099
- ret.append(filespath)
1100
- return ret
1101
-
1102
- def text_save(content, filename, mode = 'a'):
1103
- # save a list to a txt
1104
- # Try to save a list variable in txt file.
1105
- file = open(filename, mode)
1106
- for i in range(len(content)):
1107
- file.write(str(content[i]) + '\n')
1108
- file.close()
1109
-
1110
- def check_path(path):
1111
- if not os.path.exists(path):
1112
- os.makedirs(path)
1113
-
1114
- # ----------------------------------------
1115
- # Validation and Sample at training
1116
- # ----------------------------------------
1117
- def save_sample_png(sample_folder, sample_name, img_list, name_list, pixel_max_cnt = 255):
1118
- # Save image one-by-one
1119
- for i in range(len(img_list)):
1120
- img = img_list[i]
1121
- # Recover normalization: * 255 because last layer is sigmoid activated
1122
- img = img * 255
1123
- # Process img_copy and do not destroy the data of img
1124
- img_copy = img.clone().data.permute(0, 2, 3, 1)[0, :, :, :].cpu().numpy()
1125
- img_copy = np.clip(img_copy, 0, pixel_max_cnt)
1126
- img_copy = img_copy.astype(np.uint8)
1127
- img_copy = cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR)
1128
- # Save to certain path
1129
- save_img_name = sample_name + '_' + name_list[i] + '.jpg'
1130
- save_img_path = os.path.join(sample_folder, save_img_name)
1131
- cv2.imwrite(save_img_path, img_copy)
1132
-
1133
- def psnr(pred, target, pixel_max_cnt = 255):
1134
- mse = torch.mul(target - pred, target - pred)
1135
- rmse_avg = (torch.mean(mse).item()) ** 0.5
1136
- p = 20 * np.log10(pixel_max_cnt / rmse_avg)
1137
- return p
1138
-
1139
- def grey_psnr(pred, target, pixel_max_cnt = 255):
1140
- pred = torch.sum(pred, dim = 0)
1141
- target = torch.sum(target, dim = 0)
1142
- mse = torch.mul(target - pred, target - pred)
1143
- rmse_avg = (torch.mean(mse).item()) ** 0.5
1144
- p = 20 * np.log10(pixel_max_cnt * 3 / rmse_avg)
1145
- return p
1146
-
1147
- def ssim(pred, target):
1148
- pred = pred.clone().data.permute(0, 2, 3, 1).cpu().numpy()
1149
- target = target.clone().data.permute(0, 2, 3, 1).cpu().numpy()
1150
- target = target[0]
1151
- pred = pred[0]
1152
- ssim = skimage.measure.compare_ssim(target, pred, multichannel = True)
1153
- return ssim
1154
-
1155
- ## for contextual attention
1156
-
1157
- def extract_image_patches(images, ksizes, strides, rates, padding='same'):
1158
- """
1159
- Extract patches from images and put them in the C output dimension.
1160
- :param padding:
1161
- :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
1162
- :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
1163
- each dimension of images
1164
- :param strides: [stride_rows, stride_cols]
1165
- :param rates: [dilation_rows, dilation_cols]
1166
- :return: A Tensor
1167
- """
1168
- assert len(images.size()) == 4
1169
- assert padding in ['same', 'valid']
1170
- batch_size, channel, height, width = images.size()
1171
-
1172
- if padding == 'same':
1173
- images = same_padding(images, ksizes, strides, rates)
1174
- elif padding == 'valid':
1175
- pass
1176
- else:
1177
- raise NotImplementedError('Unsupported padding type: {}.\
1178
- Only "same" or "valid" are supported.'.format(padding))
1179
-
1180
- unfold = torch.nn.Unfold(kernel_size=ksizes,
1181
- dilation=rates,
1182
- padding=0,
1183
- stride=strides)
1184
- patches = unfold(images)
1185
- return patches # [N, C*k*k, L], L is the total number of such blocks
1186
-
1187
- def same_padding(images, ksizes, strides, rates):
1188
- assert len(images.size()) == 4
1189
- batch_size, channel, rows, cols = images.size()
1190
- out_rows = (rows + strides[0] - 1) // strides[0]
1191
- out_cols = (cols + strides[1] - 1) // strides[1]
1192
- effective_k_row = (ksizes[0] - 1) * rates[0] + 1
1193
- effective_k_col = (ksizes[1] - 1) * rates[1] + 1
1194
- padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
1195
- padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
1196
- # Pad the input
1197
- padding_top = int(padding_rows / 2.)
1198
- padding_left = int(padding_cols / 2.)
1199
- padding_bottom = padding_rows - padding_top
1200
- padding_right = padding_cols - padding_left
1201
- paddings = (padding_left, padding_right, padding_top, padding_bottom)
1202
- images = torch.nn.ZeroPad2d(paddings)(images)
1203
- return images
1204
-
1205
- def reduce_mean(x, axis=None, keepdim=False):
1206
- if not axis:
1207
- axis = range(len(x.shape))
1208
- for i in sorted(axis, reverse=True):
1209
- x = torch.mean(x, dim=i, keepdim=keepdim)
1210
- return x
1211
-
1212
-
1213
- def reduce_std(x, axis=None, keepdim=False):
1214
- if not axis:
1215
- axis = range(len(x.shape))
1216
- for i in sorted(axis, reverse=True):
1217
- x = torch.std(x, dim=i, keepdim=keepdim)
1218
- return x
1219
-
1220
-
1221
- def reduce_sum(x, axis=None, keepdim=False):
1222
- if not axis:
1223
- axis = range(len(x.shape))
1224
- for i in sorted(axis, reverse=True):
1225
- x = torch.sum(x, dim=i, keepdim=keepdim)
1226
- return x
1227
-
1228
- def random_mask(num_batch=1, mask_shape=(256,256)):
1229
- list_mask = []
1230
- for _ in range(num_batch):
1231
- # rectangle mask
1232
- image_height = mask_shape[0]
1233
- image_width = mask_shape[1]
1234
- max_delta_height = image_height//8
1235
- max_delta_width = image_width//8
1236
- height = image_height//4
1237
- width = image_width//4
1238
- max_t = image_height - height
1239
- max_l = image_width - width
1240
- t = random.randint(0, max_t)
1241
- l = random.randint(0, max_l)
1242
- # bbox = (t, l, height, width)
1243
- h = random.randint(0, max_delta_height//2)
1244
- w = random.randint(0, max_delta_width//2)
1245
- mask = torch.zeros((1, 1, image_height, image_width))
1246
- mask[:, :, t+h:t+height-h, l+w:l+width-w] = 1
1247
- rect_mask = mask
1248
-
1249
- # brush mask
1250
- min_num_vertex = 4
1251
- max_num_vertex = 12
1252
- mean_angle = 2 * math.pi / 5
1253
- angle_range = 2 * math.pi / 15
1254
- min_width = 12
1255
- max_width = 40
1256
- H, W = image_height, image_width
1257
- average_radius = math.sqrt(H*H+W*W) / 8
1258
- mask = Image.new('L', (W, H), 0)
1259
-
1260
- for _ in range(np.random.randint(1, 4)):
1261
- num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
1262
- angle_min = mean_angle - np.random.uniform(0, angle_range)
1263
- angle_max = mean_angle + np.random.uniform(0, angle_range)
1264
- angles = []
1265
- vertex = []
1266
- for i in range(num_vertex):
1267
- if i % 2 == 0:
1268
- angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
1269
- else:
1270
- angles.append(np.random.uniform(angle_min, angle_max))
1271
-
1272
- h, w = mask.size
1273
- vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
1274
- for i in range(num_vertex):
1275
- r = np.clip(
1276
- np.random.normal(loc=average_radius, scale=average_radius//2),
1277
- 0, 2*average_radius)
1278
- new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
1279
- new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
1280
- vertex.append((int(new_x), int(new_y)))
1281
-
1282
- draw = ImageDraw.Draw(mask)
1283
- width = int(np.random.uniform(min_width, max_width))
1284
- draw.line(vertex, fill=255, width=width)
1285
- for v in vertex:
1286
- draw.ellipse((v[0] - width//2,
1287
- v[1] - width//2,
1288
- v[0] + width//2,
1289
- v[1] + width//2),
1290
- fill=255)
1291
-
1292
- if np.random.normal() > 0:
1293
- mask.transpose(Image.FLIP_LEFT_RIGHT)
1294
- if np.random.normal() > 0:
1295
- mask.transpose(Image.FLIP_TOP_BOTTOM)
1296
-
1297
- mask = transforms.ToTensor()(mask)
1298
- mask = mask.reshape((1, 1, H, W))
1299
- brush_mask = mask
1300
-
1301
- mask = torch.cat([rect_mask, brush_mask], dim=1).max(dim=1, keepdim=True)[0]
1302
- list_mask.append(mask)
1303
- mask = torch.cat(list_mask, dim=0)
1304
- return mask
 
51
  print(
52
  model_weights_filename,
53
  " not found, you may need to manually download the model weights."
54
+ )